From 717be5940f8ec01e6917058118ff459cda519530 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Sat, 5 Oct 2019 19:44:29 +0300 Subject: [PATCH] Reorganize code a bit, add current directory to sys.path (#198) * reorganize code a bit * add current directory to sys.path * remove PYTHONPATH mention from the docs * linting --- README.md | 6 - mypy_django_plugin/django/context.py | 332 +++++++++--------- mypy_django_plugin/lib/fullnames.py | 2 + mypy_django_plugin/lib/helpers.py | 8 + mypy_django_plugin/main.py | 8 +- mypy_django_plugin/transformers/fields.py | 2 +- .../transformers/init_create.py | 55 +-- mypy_django_plugin/transformers/models.py | 6 +- .../transformers/orm_lookups.py | 51 +++ mypy_django_plugin/transformers/querysets.py | 12 +- .../managers/querysets/test_filter.yml | 43 ++- 11 files changed, 287 insertions(+), 238 deletions(-) create mode 100644 mypy_django_plugin/transformers/orm_lookups.py diff --git a/README.md b/README.md index 83016b5..405205f 100644 --- a/README.md +++ b/README.md @@ -52,12 +52,6 @@ django_settings_module = mysettings Where `mysettings` is a value of `DJANGO_SETTINGS_MODULE` (with or without quotes) -You might also need to explicitly tweak your `PYTHONPATH` the very same way `django` does it internally in case you have troubles with mypy / django plugin not finding your settings module. Try adding the root path of your project to your `PYTHONPATH` environment variable like so: - -```bash -PYTHONPATH=${PYTHONPATH}:${PWD} -``` - Current implementation uses Django runtime to extract models information, so it will crash, if your installed apps `models.py` is not correct. For this same reason, you cannot use `reveal_type` inside global scope of any Python file that will be executed for `django.setup()`. In other words, if your `manage.py runserver` crashes, mypy will crash too. diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index a68cb09..ff1ddd0 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -1,4 +1,5 @@ import os +import sys from collections import defaultdict from contextlib import contextmanager from typing import ( @@ -48,6 +49,9 @@ def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']: with temp_environ(): os.environ['DJANGO_SETTINGS_MODULE'] = settings_module + # add current directory to sys.path + sys.path.append(os.getcwd()) + def noop_class_getitem(cls, key): return cls @@ -73,176 +77,12 @@ def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']: return apps, settings -class DjangoFieldsContext: - def __init__(self, django_context: 'DjangoContext') -> None: - self.django_context = django_context - - def get_attname(self, field: Field) -> str: - attname = field.attname - return attname - - def get_field_nullability(self, field: Field, method: Optional[str]) -> bool: - nullable = field.null - if not nullable and isinstance(field, CharField) and field.blank: - return True - if method == '__init__': - if field.primary_key or isinstance(field, ForeignKey): - return True - if method == 'create': - if isinstance(field, AutoField): - return True - if field.has_default(): - return True - return nullable - - def get_field_set_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType: - """ Get a type of __set__ for this specific Django field. """ - target_field = field - if isinstance(field, ForeignKey): - target_field = field.target_field - - field_info = helpers.lookup_class_typeinfo(api, target_field.__class__) - if field_info is None: - return AnyType(TypeOfAny.from_error) - - field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type', - is_nullable=self.get_field_nullability(field, method)) - if isinstance(target_field, ArrayField): - argument_field_type = self.get_field_set_type(api, target_field.base_field, method=method) - field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type) - return field_set_type - - def get_field_get_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType: - """ Get a type of __get__ for this specific Django field. """ - field_info = helpers.lookup_class_typeinfo(api, field.__class__) - if field_info is None: - return AnyType(TypeOfAny.unannotated) - - is_nullable = self.get_field_nullability(field, method) - if isinstance(field, RelatedField): - related_model_cls = self.django_context.fields_context.get_related_model_cls(field) - - if method == 'values': - primary_key_field = self.django_context.get_primary_key_field(related_model_cls) - return self.get_field_get_type(api, primary_key_field, method=method) - - model_info = helpers.lookup_class_typeinfo(api, related_model_cls) - if model_info is None: - return AnyType(TypeOfAny.unannotated) - - return Instance(model_info, []) - else: - return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', - is_nullable=is_nullable) - - def get_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) -> Type[Model]: - if isinstance(field, RelatedField): - related_model_cls = field.remote_field.model - else: - related_model_cls = field.field.model - - if isinstance(related_model_cls, str): - if related_model_cls == 'self': - # same model - related_model_cls = field.model - elif '.' not in related_model_cls: - # same file model - related_model_fullname = field.model.__module__ + '.' + related_model_cls - related_model_cls = self.django_context.get_model_class_by_fullname(related_model_fullname) - else: - related_model_cls = self.django_context.apps_registry.get_model(related_model_cls) - - return related_model_cls - - class LookupsAreUnsupported(Exception): pass -class DjangoLookupsContext: - def __init__(self, django_context: 'DjangoContext'): - self.django_context = django_context - - def _resolve_field_from_parts(self, field_parts: Iterable[str], model_cls: Type[Model]) -> Field: - currently_observed_model = model_cls - field = None - for field_part in field_parts: - if field_part == 'pk': - field = self.django_context.get_primary_key_field(currently_observed_model) - continue - - field = currently_observed_model._meta.get_field(field_part) - if isinstance(field, RelatedField): - currently_observed_model = field.related_model - model_name = currently_observed_model._meta.model_name - if (model_name is not None - and field_part == (model_name + '_id')): - field = self.django_context.get_primary_key_field(currently_observed_model) - - if isinstance(field, ForeignObjectRel): - currently_observed_model = field.related_model - - assert field is not None - return field - - def resolve_lookup_info_field(self, model_cls: Type[Model], lookup: str) -> Field: - query = Query(model_cls) - lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup) - if lookup_parts: - raise LookupsAreUnsupported() - - return self._resolve_field_from_parts(field_parts, model_cls) - - def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType: - query = Query(model_cls) - try: - lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup) - if is_expression: - return AnyType(TypeOfAny.explicit) - except FieldError as exc: - ctx.api.fail(exc.args[0], ctx.context) - return AnyType(TypeOfAny.from_error) - - field = self._resolve_field_from_parts(field_parts, model_cls) - - lookup_cls = None - if lookup_parts: - lookup = lookup_parts[-1] - lookup_cls = field.get_lookup(lookup) - if lookup_cls is None: - # unknown lookup - return AnyType(TypeOfAny.explicit) - - if lookup_cls is None or isinstance(lookup_cls, Exact): - return self.django_context.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field) - - assert lookup_cls is not None - - lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls) - if lookup_info is None: - return AnyType(TypeOfAny.explicit) - - for lookup_base in helpers.iter_bases(lookup_info): - if lookup_base.args and isinstance(lookup_base.args[0], Instance): - lookup_type: MypyType = lookup_base.args[0] - # if it's Field, consider lookup_type a __get__ of current field - if (isinstance(lookup_type, Instance) - and lookup_type.type.fullname() == fullnames.FIELD_FULLNAME): - field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__) - if field_info is None: - return AnyType(TypeOfAny.explicit) - lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', - is_nullable=field.null) - return lookup_type - - return AnyType(TypeOfAny.explicit) - - class DjangoContext: def __init__(self, django_settings_module: str) -> None: - self.fields_context = DjangoFieldsContext(self) - self.lookups_context = DjangoLookupsContext(self) - self.django_settings_module = django_settings_module apps, settings = initialize_django(self.django_settings_module) @@ -288,7 +128,7 @@ class DjangoContext: if isinstance(field, (RelatedField, ForeignObjectRel)): related_model_cls = field.related_model primary_key_field = self.get_primary_key_field(related_model_cls) - primary_key_type = self.fields_context.get_field_get_type(api, primary_key_field, method='init') + primary_key_type = self.get_field_get_type(api, primary_key_field, method='init') rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls) if rel_model_info is None: @@ -319,13 +159,13 @@ class DjangoContext: # add pk if not abstract=True if not model_cls._meta.abstract: primary_key_field = self.get_primary_key_field(model_cls) - field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method) + field_set_type = self.get_field_set_type(api, primary_key_field, method=method) expected_types['pk'] = field_set_type for field in model_cls._meta.get_fields(): if isinstance(field, Field): field_name = field.attname - field_set_type = self.fields_context.get_field_set_type(api, field, method=method) + field_set_type = self.get_field_set_type(api, field, method=method) expected_types[field_name] = field_set_type if isinstance(field, ForeignKey): @@ -336,7 +176,7 @@ class DjangoContext: expected_types[field_name] = AnyType(TypeOfAny.unannotated) continue - related_model = self.fields_context.get_related_model_cls(field) + related_model = self.get_field_related_model_cls(field) if related_model._meta.proxy_for_model is not None: related_model = related_model._meta.proxy_for_model @@ -345,7 +185,7 @@ class DjangoContext: expected_types[field_name] = AnyType(TypeOfAny.unannotated) continue - is_nullable = self.fields_context.get_field_nullability(field, method) + is_nullable = self.get_field_nullability(field, method) foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info, '_pyi_private_set_type', is_nullable=is_nullable) @@ -375,3 +215,157 @@ class DjangoContext: all_model_bases.add(helpers.get_class_fullname(base_cls)) return all_model_bases + + def get_attname(self, field: Field) -> str: + attname = field.attname + return attname + + def get_field_nullability(self, field: Field, method: Optional[str]) -> bool: + nullable = field.null + if not nullable and isinstance(field, CharField) and field.blank: + return True + if method == '__init__': + if field.primary_key or isinstance(field, ForeignKey): + return True + if method == 'create': + if isinstance(field, AutoField): + return True + if field.has_default(): + return True + return nullable + + def get_field_set_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType: + """ Get a type of __set__ for this specific Django field. """ + target_field = field + if isinstance(field, ForeignKey): + target_field = field.target_field + + field_info = helpers.lookup_class_typeinfo(api, target_field.__class__) + if field_info is None: + return AnyType(TypeOfAny.from_error) + + field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type', + is_nullable=self.get_field_nullability(field, method)) + if isinstance(target_field, ArrayField): + argument_field_type = self.get_field_set_type(api, target_field.base_field, method=method) + field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type) + return field_set_type + + def get_field_get_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType: + """ Get a type of __get__ for this specific Django field. """ + field_info = helpers.lookup_class_typeinfo(api, field.__class__) + if field_info is None: + return AnyType(TypeOfAny.unannotated) + + is_nullable = self.get_field_nullability(field, method) + if isinstance(field, RelatedField): + related_model_cls = self.get_field_related_model_cls(field) + + if method == 'values': + primary_key_field = self.get_primary_key_field(related_model_cls) + return self.get_field_get_type(api, primary_key_field, method=method) + + model_info = helpers.lookup_class_typeinfo(api, related_model_cls) + if model_info is None: + return AnyType(TypeOfAny.unannotated) + + return Instance(model_info, []) + else: + return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', + is_nullable=is_nullable) + + def get_field_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) -> Type[Model]: + if isinstance(field, RelatedField): + related_model_cls = field.remote_field.model + else: + related_model_cls = field.field.model + + if isinstance(related_model_cls, str): + if related_model_cls == 'self': + # same model + related_model_cls = field.model + elif '.' not in related_model_cls: + # same file model + related_model_fullname = field.model.__module__ + '.' + related_model_cls + related_model_cls = self.get_model_class_by_fullname(related_model_fullname) + else: + related_model_cls = self.apps_registry.get_model(related_model_cls) + + return related_model_cls + + def _resolve_field_from_parts(self, field_parts: Iterable[str], model_cls: Type[Model]) -> Field: + currently_observed_model = model_cls + field = None + for field_part in field_parts: + if field_part == 'pk': + field = self.get_primary_key_field(currently_observed_model) + continue + + field = currently_observed_model._meta.get_field(field_part) + if isinstance(field, RelatedField): + currently_observed_model = field.related_model + model_name = currently_observed_model._meta.model_name + if (model_name is not None + and field_part == (model_name + '_id')): + field = self.get_primary_key_field(currently_observed_model) + + if isinstance(field, ForeignObjectRel): + currently_observed_model = field.related_model + + assert field is not None + return field + + def resolve_lookup_into_field(self, model_cls: Type[Model], lookup: str) -> Field: + query = Query(model_cls) + lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup) + if lookup_parts: + raise LookupsAreUnsupported() + + return self._resolve_field_from_parts(field_parts, model_cls) + + def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType: + query = Query(model_cls) + try: + lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup) + if is_expression: + return AnyType(TypeOfAny.explicit) + except FieldError as exc: + ctx.api.fail(exc.args[0], ctx.context) + return AnyType(TypeOfAny.from_error) + + field = self._resolve_field_from_parts(field_parts, model_cls) + + lookup_cls = None + if lookup_parts: + lookup = lookup_parts[-1] + lookup_cls = field.get_lookup(lookup) + if lookup_cls is None: + # unknown lookup + return AnyType(TypeOfAny.explicit) + + if lookup_cls is None or isinstance(lookup_cls, Exact): + return self.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field) + + assert lookup_cls is not None + + lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls) + if lookup_info is None: + return AnyType(TypeOfAny.explicit) + + for lookup_base in helpers.iter_bases(lookup_info): + if lookup_base.args and isinstance(lookup_base.args[0], Instance): + lookup_type: MypyType = lookup_base.args[0] + # if it's Field, consider lookup_type a __get__ of current field + if (isinstance(lookup_type, Instance) + and lookup_type.type.fullname() == fullnames.FIELD_FULLNAME): + field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__) + if field_info is None: + return AnyType(TypeOfAny.explicit) + lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', + is_nullable=field.null) + return lookup_type + + return AnyType(TypeOfAny.explicit) + + def resolve_f_expression_type(self, f_expression_type: Instance) -> MypyType: + return AnyType(TypeOfAny.explicit) diff --git a/mypy_django_plugin/lib/fullnames.py b/mypy_django_plugin/lib/fullnames.py index 15e5ce8..1b80197 100644 --- a/mypy_django_plugin/lib/fullnames.py +++ b/mypy_django_plugin/lib/fullnames.py @@ -37,3 +37,5 @@ RELATED_FIELDS_CLASSES = { MIGRATION_CLASS_FULLNAME = 'django.db.migrations.migration.Migration' OPTIONS_CLASS_FULLNAME = 'django.db.models.options.Options' HTTPREQUEST_CLASS_FULLNAME = 'django.http.request.HttpRequest' + +F_EXPRESSION_FULLNAME = 'django.db.models.expressions.F' diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index c1adda3..40f0a63 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -284,3 +284,11 @@ def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionCont def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool: return (info.fullname() in django_context.model_base_classes or info.has_base(fullnames.MODEL_CLASS_FULLNAME)) + + +def check_types_compatible(ctx: Union[FunctionContext, MethodContext], + *, expected_type: MypyType, actual_type: MypyType, error_message: str) -> None: + api = get_typechecker_api(ctx) + api.check_subtype(actual_type, expected_type, + ctx.context, error_message, + 'got', 'expected') diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 24fb025..e59e87b 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -11,6 +11,7 @@ from mypy.plugin import ( ) from mypy.types import Type as MypyType +import mypy_django_plugin.transformers.orm_lookups from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.transformers import ( @@ -148,13 +149,13 @@ class NewSemanalDjangoPlugin(Plugin): # forward relations for field in self.django_context.get_model_fields(model_class): if isinstance(field, RelatedField): - related_model_cls = self.django_context.fields_context.get_related_model_cls(field) + related_model_cls = self.django_context.get_field_related_model_cls(field) related_model_module = related_model_cls.__module__ if related_model_module != file.fullname(): deps.add(self._new_dependency(related_model_module)) # reverse relations for relation in model_class._meta.related_objects: - related_model_cls = self.django_context.fields_context.get_related_model_cls(relation) + related_model_cls = self.django_context.get_field_related_model_cls(relation) related_model_module = related_model_cls.__module__ if related_model_module != file.fullname(): deps.add(self._new_dependency(related_model_module)) @@ -210,7 +211,8 @@ class NewSemanalDjangoPlugin(Plugin): if class_fullname in manager_classes and method_name == 'create': return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context) if class_fullname in manager_classes and method_name in {'filter', 'get', 'exclude'}: - return partial(init_create.typecheck_queryset_filter, django_context=self.django_context) + return partial(mypy_django_plugin.transformers.orm_lookups.typecheck_queryset_filter, + django_context=self.django_context) return None def get_base_class_hook(self, fullname: str diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index 19d04ff..78b5bc5 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -44,7 +44,7 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context assert isinstance(current_field, RelatedField) - related_model_cls = django_context.fields_context.get_related_model_cls(current_field) + related_model_cls = django_context.get_field_related_model_cls(current_field) related_model = related_model_cls related_model_to_set = related_model_cls diff --git a/mypy_django_plugin/transformers/init_create.py b/mypy_django_plugin/transformers/init_create.py index 33eb13f..2edfdef 100644 --- a/mypy_django_plugin/transformers/init_create.py +++ b/mypy_django_plugin/transformers/init_create.py @@ -6,7 +6,7 @@ from mypy.types import Instance from mypy.types import Type as MypyType from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import fullnames, helpers +from mypy_django_plugin.lib import helpers def get_actual_types(ctx: Union[MethodContext, FunctionContext], @@ -30,12 +30,6 @@ def get_actual_types(ctx: Union[MethodContext, FunctionContext], return actual_types -def check_types_compatible(ctx, *, expected_type, actual_type, error_message): - ctx.api.check_subtype(actual_type, expected_type, - ctx.context, error_message, - 'got', 'expected') - - def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext, model_cls: Type[Model], method: str) -> MypyType: typechecker_api = helpers.get_typechecker_api(ctx) @@ -48,11 +42,11 @@ def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_co model_cls.__name__), ctx.context) continue - check_types_compatible(ctx, - expected_type=expected_types[actual_name], - actual_type=actual_type, - error_message='Incompatible type for "{}" of "{}"'.format(actual_name, - model_cls.__name__)) + helpers.check_types_compatible(ctx, + expected_type=expected_types[actual_name], + actual_type=actual_type, + error_message='Incompatible type for "{}" of "{}"'.format(actual_name, + model_cls.__name__)) return ctx.default_return_type @@ -79,40 +73,3 @@ def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: Djan return ctx.default_return_type return typecheck_model_method(ctx, django_context, model_cls, 'create') - - -def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType: - lookup_kwargs = ctx.arg_names[1] - provided_lookup_types = ctx.arg_types[1] - - assert isinstance(ctx.type, Instance) - - if not ctx.type.args or not isinstance(ctx.type.args[0], Instance): - return ctx.default_return_type - - model_cls_fullname = ctx.type.args[0].type.fullname() - model_cls = django_context.get_model_class_by_fullname(model_cls_fullname) - if model_cls is None: - return ctx.default_return_type - - for lookup_kwarg, provided_type in zip(lookup_kwargs, provided_lookup_types): - if lookup_kwarg is None: - continue - # Combinables are not supported yet - if (isinstance(provided_type, Instance) - and provided_type.type.has_base('django.db.models.expressions.Combinable')): - continue - - lookup_type = django_context.lookups_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg) - # Managers as provided_type is not supported yet - if (isinstance(provided_type, Instance) - and helpers.has_any_of_bases(provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME, - fullnames.QUERYSET_CLASS_FULLNAME))): - return ctx.default_return_type - - check_types_compatible(ctx, - expected_type=lookup_type, - actual_type=provided_type, - error_message=f'Incompatible type for lookup {lookup_kwarg!r}:') - - return ctx.default_return_type diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 9155aa3..1a618a6 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -99,10 +99,10 @@ class AddRelatedModelsId(ModelClassInitializer): def run_with_model_cls(self, model_cls: Type[Model]) -> None: for field in model_cls._meta.get_fields(): if isinstance(field, ForeignKey): - related_model_cls = self.django_context.fields_context.get_related_model_cls(field) + related_model_cls = self.django_context.get_field_related_model_cls(field) rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls) field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__) - is_nullable = self.django_context.fields_context.get_field_nullability(field, None) + is_nullable = self.django_context.get_field_nullability(field, None) set_type, get_type = get_field_descriptor_types(field_info, is_nullable) self.add_new_node_to_model_class(field.attname, Instance(field_info, [set_type, get_type])) @@ -162,7 +162,7 @@ class AddRelatedManagers(ModelClassInitializer): # no reverse accessor continue - related_model_cls = self.django_context.fields_context.get_related_model_cls(relation) + related_model_cls = self.django_context.get_field_related_model_cls(relation) related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls) if isinstance(relation, OneToOneRel): diff --git a/mypy_django_plugin/transformers/orm_lookups.py b/mypy_django_plugin/transformers/orm_lookups.py new file mode 100644 index 0000000..e1ed913 --- /dev/null +++ b/mypy_django_plugin/transformers/orm_lookups.py @@ -0,0 +1,51 @@ +from mypy.plugin import MethodContext +from mypy.types import AnyType, Instance +from mypy.types import Type as MypyType +from mypy.types import TypeOfAny + +from mypy_django_plugin.django.context import DjangoContext +from mypy_django_plugin.lib import fullnames, helpers + + +def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType: + lookup_kwargs = ctx.arg_names[1] + provided_lookup_types = ctx.arg_types[1] + + assert isinstance(ctx.type, Instance) + + if not ctx.type.args or not isinstance(ctx.type.args[0], Instance): + return ctx.default_return_type + + model_cls_fullname = ctx.type.args[0].type.fullname() + model_cls = django_context.get_model_class_by_fullname(model_cls_fullname) + if model_cls is None: + return ctx.default_return_type + + for lookup_kwarg, provided_type in zip(lookup_kwargs, provided_lookup_types): + if lookup_kwarg is None: + continue + if (isinstance(provided_type, Instance) + and provided_type.type.has_base('django.db.models.expressions.Combinable')): + provided_type = resolve_combinable_type(provided_type, django_context) + + lookup_type = django_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg) + # Managers as provided_type is not supported yet + if (isinstance(provided_type, Instance) + and helpers.has_any_of_bases(provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME, + fullnames.QUERYSET_CLASS_FULLNAME))): + return ctx.default_return_type + + helpers.check_types_compatible(ctx, + expected_type=lookup_type, + actual_type=provided_type, + error_message=f'Incompatible type for lookup {lookup_kwarg!r}:') + + return ctx.default_return_type + + +def resolve_combinable_type(combinable_type: Instance, django_context: DjangoContext) -> MypyType: + if combinable_type.type.fullname() != fullnames.F_EXPRESSION_FULLNAME: + # Combinables aside from F expressions are unsupported + return AnyType(TypeOfAny.explicit) + + return django_context.resolve_f_expression_type(combinable_type) diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index 93c7d97..b22bf4a 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -40,7 +40,7 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType: def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], *, method: str, lookup: str) -> Optional[MypyType]: try: - lookup_field = django_context.lookups_context.resolve_lookup_info_field(model_cls, lookup) + lookup_field = django_context.resolve_lookup_into_field(model_cls, lookup) except FieldError as exc: ctx.api.fail(exc.args[0], ctx.context) return None @@ -48,11 +48,11 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext return AnyType(TypeOfAny.explicit) if isinstance(lookup_field, RelatedField) and lookup_field.column == lookup: - related_model_cls = django_context.fields_context.get_related_model_cls(lookup_field) + related_model_cls = django_context.get_field_related_model_cls(lookup_field) lookup_field = django_context.get_primary_key_field(related_model_cls) - field_get_type = django_context.fields_context.get_field_get_type(helpers.get_typechecker_api(ctx), - lookup_field, method=method) + field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx), + lookup_field, method=method) return field_get_type @@ -73,8 +73,8 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, elif named: column_types: 'OrderedDict[str, MypyType]' = OrderedDict() for field in django_context.get_model_fields(model_cls): - column_type = django_context.fields_context.get_field_get_type(typechecker_api, field, - method='values_list') + column_type = django_context.get_field_get_type(typechecker_api, field, + method='values_list') column_types[field.attname] = column_type return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) else: diff --git a/test-data/typecheck/managers/querysets/test_filter.yml b/test-data/typecheck/managers/querysets/test_filter.yml index af01853..a084195 100644 --- a/test-data/typecheck/managers/querysets/test_filter.yml +++ b/test-data/typecheck/managers/querysets/test_filter.yml @@ -212,4 +212,45 @@ class User(models.Model): pass class Profile(models.Model): - user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile') \ No newline at end of file + user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile') + + +# TODO +- case: f_expression_simple_case + main: | + from myapp.models import User + from django.db import models + User.objects.filter(username=models.F('username2')) + User.objects.filter(username=models.F('age')) + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + + class User(models.Model): + username = models.TextField() + username2 = models.TextField() + + age = models.IntegerField() + + +# TODO +- case: f_expression_with_expression_math_is_not_supported + main: | + from myapp.models import User + from django.db import models + User.objects.filter(username=models.F('username2') + 'hello') + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class User(models.Model): + username = models.TextField() + username2 = models.TextField() + age = models.IntegerField()