diff --git a/jedi/plugins/django.py b/jedi/plugins/django.py index cacb9e75..03a14c5d 100644 --- a/jedi/plugins/django.py +++ b/jedi/plugins/django.py @@ -1,5 +1,5 @@ """ -Module provides infering of Django model fields. +Module is used to infer Django model fields. """ from jedi.inference.base_value import LazyValueWrapper from jedi.inference.utils import safe_property @@ -10,9 +10,8 @@ from jedi.inference.value.instance import TreeInstance def new_dict_filter(cls): filter_ = ParserTreeFilter(parent_context=cls.as_context()) - return [DictFilter({ - f.string_name: _infer_field(cls, f) for f in filter_.values() - })] + res = {f.string_name: _infer_field(cls, f) for f in filter_.values()} + return [DictFilter({x: y for x, y in res.items() if y is not None})] class DjangoModelField(LazyValueWrapper): @@ -30,6 +29,7 @@ class DjangoModelField(LazyValueWrapper): obj, = self._cls.execute_with_values() return obj + mapping = { 'IntegerField': (None, 'int'), 'BigIntegerField': (None, 'int'), @@ -48,38 +48,41 @@ mapping = { 'DateTimeField': ('datetime', 'datetime'), } + +def _infer_scalar_field(cls, field, field_tree_instance): + if field_tree_instance.name.string_name not in mapping: + return None + + module_name, attribute_name = mapping[field_tree_instance.name.string_name] + if module_name is None: + module = cls.inference_state.builtins_module + else: + module = cls.inference_state.import_module((module_name,)) + + attribute, = module.py__getattribute__(attribute_name) + return DjangoModelField(attribute, field).name + + def _infer_field(cls, field): field_tree_instance, = field.infer() - - try: - module_name, attribute_name = mapping[field_tree_instance.name.string_name] - except KeyError: - pass - else: - if module_name is None: - module = cls.inference_state.builtins_module - else: - module = cls.inference_state.import_module((module_name,)) - attribute, = module.py__getattribute__(attribute_name) - return DjangoModelField(attribute, field).name + scalar_field = _infer_scalar_field(cls, field, field_tree_instance) + if scalar_field: + return scalar_field if field_tree_instance.name.string_name == 'ForeignKey': if isinstance(field_tree_instance, TreeInstance): argument_iterator = field_tree_instance._arguments.unpack() key, lazy_values = next(argument_iterator, (None, None)) if key is None and lazy_values is not None: - # TODO: it has only one element in current state. Handle rest of elements. for value in lazy_values.infer(): - string = value.get_safe_value(default=None) if value.name.string_name == 'str': - foreign_key_class_name = value._compiled_obj.get_safe_value() - # TODO: it has only one element in current state. Handle rest of elements. + foreign_key_class_name = value.get_safe_value() for v in cls.parent_context.py__getattribute__(foreign_key_class_name): return DjangoModelField(v, field).name else: return DjangoModelField(value, field).name - raise Exception('Should be handled') + print('TODO: {}'.format(field)) def get_metaclass_filters(func): diff --git a/test/test_inference/test_django.py b/test/test_inference/test_django.py new file mode 100644 index 00000000..16c9ee52 --- /dev/null +++ b/test/test_inference/test_django.py @@ -0,0 +1,64 @@ +import pytest +import datetime +import decimal + +source_tpl_basic_types = ''' +from django.db import models + +class BusinessModel(models.Model): + {0} = {1} + +p1 = BusinessModel() +p1_field = p1.{0} +p1_field.''' + + +source_tpl_foreign_key = ''' +from django.db import models + +class Category(models.Model): + category_name = models.CharField() + +class BusinessModel(models.Model): + category = models.ForeignKey(Category) + +p1 = BusinessModel() +p1_field = p1.category +p1_field.''' + + +@pytest.mark.parametrize('field_name, field_model_type, expected_fields', [ + ('integer_field', 'models.IntegerField()', dir(int)), + ('big_integer_field', 'models.BigIntegerField()', dir(int)), + ('positive_integer_field', 'models.PositiveIntegerField()', dir(int)), + ('small_integer_field', 'models.SmallIntegerField()', dir(int)), + ('char_field', 'models.CharField()', dir(str)), + ('text_field', 'models.TextField()', dir(str)), + ('email_field', 'models.EmailField()', dir(str)), + ('float_field', 'models.FloatField()', dir(float)), + ('binary_field', 'models.BinaryField()', dir(bytes)), + ('boolean_field', 'models.BooleanField()', dir(bool)), + ('decimal_field', 'models.DecimalField()', dir(decimal.Decimal)), + ('time_field', 'models.TimeField()', dir(datetime.time)), + ('duration_field', 'models.DurationField()', dir(datetime.timedelta)), + ('date_field', 'models.DateField()', dir(datetime.date)), + ('date_time_field', 'models.DateTimeField()', dir(datetime.datetime)), +]) +def test_basic_types( + field_name, + field_model_type, + expected_fields, + Script, +): + source = source_tpl_basic_types.format(field_name, field_model_type) + result = Script(source).complete() + result = {x.name for x in result} + expected_fields_public = [x for x in expected_fields if x[0] != '_'] + for field in expected_fields_public: + assert field in result + + +def test_foreign_key(Script): + result = Script(source_tpl_foreign_key).complete() + result = {x.name for x in result} + assert 'category_name' in result