mirror of
https://github.com/davidhalter/jedi.git
synced 2025-12-06 22:14:27 +08:00
Simple tests of Django plugin are added.
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Module provides infering of Django model fields.
|
Module is used to infer Django model fields.
|
||||||
"""
|
"""
|
||||||
from jedi.inference.base_value import LazyValueWrapper
|
from jedi.inference.base_value import LazyValueWrapper
|
||||||
from jedi.inference.utils import safe_property
|
from jedi.inference.utils import safe_property
|
||||||
@@ -10,9 +10,8 @@ from jedi.inference.value.instance import TreeInstance
|
|||||||
|
|
||||||
def new_dict_filter(cls):
|
def new_dict_filter(cls):
|
||||||
filter_ = ParserTreeFilter(parent_context=cls.as_context())
|
filter_ = ParserTreeFilter(parent_context=cls.as_context())
|
||||||
return [DictFilter({
|
res = {f.string_name: _infer_field(cls, f) for f in filter_.values()}
|
||||||
f.string_name: _infer_field(cls, f) for f in filter_.values()
|
return [DictFilter({x: y for x, y in res.items() if y is not None})]
|
||||||
})]
|
|
||||||
|
|
||||||
|
|
||||||
class DjangoModelField(LazyValueWrapper):
|
class DjangoModelField(LazyValueWrapper):
|
||||||
@@ -30,6 +29,7 @@ class DjangoModelField(LazyValueWrapper):
|
|||||||
obj, = self._cls.execute_with_values()
|
obj, = self._cls.execute_with_values()
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
mapping = {
|
mapping = {
|
||||||
'IntegerField': (None, 'int'),
|
'IntegerField': (None, 'int'),
|
||||||
'BigIntegerField': (None, 'int'),
|
'BigIntegerField': (None, 'int'),
|
||||||
@@ -48,38 +48,41 @@ mapping = {
|
|||||||
'DateTimeField': ('datetime', 'datetime'),
|
'DateTimeField': ('datetime', 'datetime'),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _infer_field(cls, field):
|
|
||||||
field_tree_instance, = field.infer()
|
|
||||||
|
|
||||||
try:
|
def _infer_scalar_field(cls, field, field_tree_instance):
|
||||||
|
if field_tree_instance.name.string_name not in mapping:
|
||||||
|
return None
|
||||||
|
|
||||||
module_name, attribute_name = mapping[field_tree_instance.name.string_name]
|
module_name, attribute_name = mapping[field_tree_instance.name.string_name]
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
if module_name is None:
|
if module_name is None:
|
||||||
module = cls.inference_state.builtins_module
|
module = cls.inference_state.builtins_module
|
||||||
else:
|
else:
|
||||||
module = cls.inference_state.import_module((module_name,))
|
module = cls.inference_state.import_module((module_name,))
|
||||||
|
|
||||||
attribute, = module.py__getattribute__(attribute_name)
|
attribute, = module.py__getattribute__(attribute_name)
|
||||||
return DjangoModelField(attribute, field).name
|
return DjangoModelField(attribute, field).name
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_field(cls, field):
|
||||||
|
field_tree_instance, = field.infer()
|
||||||
|
scalar_field = _infer_scalar_field(cls, field, field_tree_instance)
|
||||||
|
if scalar_field:
|
||||||
|
return scalar_field
|
||||||
|
|
||||||
if field_tree_instance.name.string_name == 'ForeignKey':
|
if field_tree_instance.name.string_name == 'ForeignKey':
|
||||||
if isinstance(field_tree_instance, TreeInstance):
|
if isinstance(field_tree_instance, TreeInstance):
|
||||||
argument_iterator = field_tree_instance._arguments.unpack()
|
argument_iterator = field_tree_instance._arguments.unpack()
|
||||||
key, lazy_values = next(argument_iterator, (None, None))
|
key, lazy_values = next(argument_iterator, (None, None))
|
||||||
if key is None and lazy_values is not None:
|
if key is None and lazy_values is not None:
|
||||||
# TODO: it has only one element in current state. Handle rest of elements.
|
|
||||||
for value in lazy_values.infer():
|
for value in lazy_values.infer():
|
||||||
string = value.get_safe_value(default=None)
|
|
||||||
if value.name.string_name == 'str':
|
if value.name.string_name == 'str':
|
||||||
foreign_key_class_name = value._compiled_obj.get_safe_value()
|
foreign_key_class_name = value.get_safe_value()
|
||||||
# TODO: it has only one element in current state. Handle rest of elements.
|
|
||||||
for v in cls.parent_context.py__getattribute__(foreign_key_class_name):
|
for v in cls.parent_context.py__getattribute__(foreign_key_class_name):
|
||||||
return DjangoModelField(v, field).name
|
return DjangoModelField(v, field).name
|
||||||
else:
|
else:
|
||||||
return DjangoModelField(value, field).name
|
return DjangoModelField(value, field).name
|
||||||
|
|
||||||
raise Exception('Should be handled')
|
print('TODO: {}'.format(field))
|
||||||
|
|
||||||
|
|
||||||
def get_metaclass_filters(func):
|
def get_metaclass_filters(func):
|
||||||
|
|||||||
64
test/test_inference/test_django.py
Normal file
64
test/test_inference/test_django.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import pytest
|
||||||
|
import datetime
|
||||||
|
import decimal
|
||||||
|
|
||||||
|
source_tpl_basic_types = '''
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class BusinessModel(models.Model):
|
||||||
|
{0} = {1}
|
||||||
|
|
||||||
|
p1 = BusinessModel()
|
||||||
|
p1_field = p1.{0}
|
||||||
|
p1_field.'''
|
||||||
|
|
||||||
|
|
||||||
|
source_tpl_foreign_key = '''
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class Category(models.Model):
|
||||||
|
category_name = models.CharField()
|
||||||
|
|
||||||
|
class BusinessModel(models.Model):
|
||||||
|
category = models.ForeignKey(Category)
|
||||||
|
|
||||||
|
p1 = BusinessModel()
|
||||||
|
p1_field = p1.category
|
||||||
|
p1_field.'''
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('field_name, field_model_type, expected_fields', [
|
||||||
|
('integer_field', 'models.IntegerField()', dir(int)),
|
||||||
|
('big_integer_field', 'models.BigIntegerField()', dir(int)),
|
||||||
|
('positive_integer_field', 'models.PositiveIntegerField()', dir(int)),
|
||||||
|
('small_integer_field', 'models.SmallIntegerField()', dir(int)),
|
||||||
|
('char_field', 'models.CharField()', dir(str)),
|
||||||
|
('text_field', 'models.TextField()', dir(str)),
|
||||||
|
('email_field', 'models.EmailField()', dir(str)),
|
||||||
|
('float_field', 'models.FloatField()', dir(float)),
|
||||||
|
('binary_field', 'models.BinaryField()', dir(bytes)),
|
||||||
|
('boolean_field', 'models.BooleanField()', dir(bool)),
|
||||||
|
('decimal_field', 'models.DecimalField()', dir(decimal.Decimal)),
|
||||||
|
('time_field', 'models.TimeField()', dir(datetime.time)),
|
||||||
|
('duration_field', 'models.DurationField()', dir(datetime.timedelta)),
|
||||||
|
('date_field', 'models.DateField()', dir(datetime.date)),
|
||||||
|
('date_time_field', 'models.DateTimeField()', dir(datetime.datetime)),
|
||||||
|
])
|
||||||
|
def test_basic_types(
|
||||||
|
field_name,
|
||||||
|
field_model_type,
|
||||||
|
expected_fields,
|
||||||
|
Script,
|
||||||
|
):
|
||||||
|
source = source_tpl_basic_types.format(field_name, field_model_type)
|
||||||
|
result = Script(source).complete()
|
||||||
|
result = {x.name for x in result}
|
||||||
|
expected_fields_public = [x for x in expected_fields if x[0] != '_']
|
||||||
|
for field in expected_fields_public:
|
||||||
|
assert field in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_foreign_key(Script):
|
||||||
|
result = Script(source_tpl_foreign_key).complete()
|
||||||
|
result = {x.name for x in result}
|
||||||
|
assert 'category_name' in result
|
||||||
Reference in New Issue
Block a user