From 4c218556415099118bf751923e7dc5f648df9700 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Thu, 25 Jul 2019 18:52:51 +0300 Subject: [PATCH] fix mypy errors --- django-stubs/db/models/fields/related.pyi | 4 +- mypy_django_plugin/django/context.py | 50 +++++---- mypy_django_plugin/lib/helpers.py | 42 ++++++-- mypy_django_plugin/main.py | 12 +-- mypy_django_plugin/transformers/fields.py | 27 +++-- .../transformers/init_create.py | 17 +-- mypy_django_plugin/transformers/meta.py | 8 +- mypy_django_plugin/transformers/models.py | 4 +- mypy_django_plugin/transformers/querysets.py | 101 ++++++++---------- mypy_django_plugin/transformers/request.py | 2 +- mypy_django_plugin/transformers/settings.py | 14 ++- .../managers/querysets/test_values_list.yml | 5 + 12 files changed, 168 insertions(+), 118 deletions(-) diff --git a/django-stubs/db/models/fields/related.pyi b/django-stubs/db/models/fields/related.pyi index 0e86e14..6192cd3 100644 --- a/django-stubs/db/models/fields/related.pyi +++ b/django-stubs/db/models/fields/related.pyi @@ -45,9 +45,9 @@ class RelatedField(FieldCacheMixin, Field[_ST, _GT]): one_to_one: bool = ... many_to_many: bool = ... many_to_one: bool = ... - @property - def related_model(self) -> Union[Type[Model], str]: ... opts: Any = ... + @property + def related_model(self) -> Type[Model]: ... def get_forward_related_filter(self, obj: Model) -> Dict[str, Union[int, UUID]]: ... def get_reverse_related_filter(self, obj: Model) -> Q: ... @property diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index fd7bb6c..da2618a 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -3,6 +3,8 @@ from collections import defaultdict from contextlib import contextmanager from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Type +from mypy.nodes import TypeInfo + from django.contrib.postgres.fields import ArrayField from django.core.exceptions import FieldError from django.db.models.base import Model @@ -12,7 +14,7 @@ from django.db.models.fields.reverse_related import ForeignObjectRel from django.db.models.sql.query import Query from django.utils.functional import cached_property from mypy.checker import TypeChecker -from mypy.types import Instance +from mypy.types import Instance, AnyType, TypeOfAny from mypy.types import Type as MypyType from mypy_django_plugin.lib import helpers @@ -42,14 +44,14 @@ def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']: from django.db import models - models.QuerySet.__class_getitem__ = classmethod(noop_class_getitem) - models.Manager.__class_getitem__ = classmethod(noop_class_getitem) + models.QuerySet.__class_getitem__ = classmethod(noop_class_getitem) # type: ignore + models.Manager.__class_getitem__ = classmethod(noop_class_getitem) # type: ignore from django.conf import settings from django.apps import apps - apps.get_models.cache_clear() - apps.get_swappable_settings_name.cache_clear() + apps.get_models.cache_clear() # type: ignore + apps.get_swappable_settings_name.cache_clear() # type: ignore if not settings.configured: settings._setup() @@ -84,28 +86,37 @@ class DjangoFieldsContext: return True return nullable - def get_field_set_type(self, api: TypeChecker, field: Field, method: str) -> MypyType: + def get_field_set_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType: + """ Get a type of __set__ for this specific Django field. """ target_field = field if isinstance(field, ForeignKey): target_field = field.target_field field_info = helpers.lookup_class_typeinfo(api, target_field.__class__) + if field_info is None: + return AnyType(TypeOfAny.from_error) + field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type', is_nullable=self.get_field_nullability(field, method)) if isinstance(target_field, ArrayField): - argument_field_type = self.get_field_set_type(api, target_field.base_field, method) + argument_field_type = self.get_field_set_type(api, target_field.base_field, method=method) field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type) return field_set_type - def get_field_get_type(self, api: TypeChecker, field: Field, method: str) -> MypyType: + def get_field_get_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType: + """ Get a type of __get__ for this specific Django field. """ field_info = helpers.lookup_class_typeinfo(api, field.__class__) + assert isinstance(field_info, TypeInfo) + is_nullable = self.get_field_nullability(field, method) if isinstance(field, RelatedField): if method == 'values': primary_key_field = self.django_context.get_primary_key_field(field.related_model) - return self.get_field_get_type(api, primary_key_field, method) + return self.get_field_get_type(api, primary_key_field, method=method) model_info = helpers.lookup_class_typeinfo(api, field.related_model) + assert isinstance(model_info, TypeInfo) + return Instance(model_info, []) else: return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', @@ -136,6 +147,8 @@ class DjangoLookupsContext: if isinstance(current_field, RelatedField): currently_observed_model = current_field.related_model + # if it is None, solve_lookup_type() will fail earlier + assert current_field is not None return current_field @@ -146,12 +159,9 @@ class DjangoContext: self.django_settings_module = django_settings_module - self.apps_registry: Optional[Dict[str, str]] = None - self.settings: LazySettings = None - if self.django_settings_module: - apps, settings = initialize_django(self.django_settings_module) - self.apps_registry = apps - self.settings = settings + apps, settings = initialize_django(self.django_settings_module) + self.apps_registry = apps + self.settings = settings @cached_property def model_modules(self) -> Dict[str, List[Type[Model]]]: @@ -170,6 +180,7 @@ class DjangoContext: for model_cls in self.model_modules.get(module, []): if model_cls.__name__ == model_cls_name: return model_cls + return None def get_model_fields(self, model_cls: Type[Model]) -> Iterator[Field]: for field in model_cls._meta.get_fields(): @@ -188,30 +199,33 @@ class DjangoContext: return field raise ValueError('No primary key defined') - def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], method: str) -> Dict[str, MypyType]: + def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], *, method: str) -> Dict[str, MypyType]: from django.contrib.contenttypes.fields import GenericForeignKey expected_types = {} # add pk primary_key_field = self.get_primary_key_field(model_cls) - field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method) + field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method) expected_types['pk'] = field_set_type for field in model_cls._meta.get_fields(): if isinstance(field, Field): field_name = field.attname - field_set_type = self.fields_context.get_field_set_type(api, field, method) + field_set_type = self.fields_context.get_field_set_type(api, field, method=method) expected_types[field_name] = field_set_type if isinstance(field, ForeignKey): field_name = field.name foreign_key_info = helpers.lookup_class_typeinfo(api, field.__class__) + assert isinstance(foreign_key_info, TypeInfo) related_model = field.related_model if related_model._meta.proxy_for_model: related_model = field.related_model._meta.proxy_for_model related_model_info = helpers.lookup_class_typeinfo(api, related_model) + assert isinstance(related_model_info, TypeInfo) + is_nullable = self.fields_context.get_field_nullability(field, method) foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info, '_pyi_private_set_type', diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 47e5838..9b9fa55 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -1,17 +1,13 @@ from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union, cast from mypy import checker from mypy.checker import TypeChecker from mypy.mro import calculate_mro -from mypy.nodes import ( - GDEF, MDEF, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, - SymbolTableNode, TypeInfo, Var, -) -from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext -from mypy.types import AnyType, Instance, NoneTyp, TupleType -from mypy.types import Type as MypyType -from mypy.types import TypedDictType, TypeOfAny, UnionType +from mypy.nodes import (Block, ClassDef, Expression, GDEF, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, + SymbolTable, SymbolTableNode, TypeInfo, Var) +from mypy.plugin import AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext +from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType if TYPE_CHECKING: from mypy_django_plugin.django.context import DjangoContext @@ -119,9 +115,17 @@ def has_any_of_bases(info: TypeInfo, bases: Set[str]) -> bool: def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType: - node = type_info.get(private_field_name).node + """ Return declared type of type_info's private_field_name (used for private Field attributes)""" + sym = type_info.get(private_field_name) + if sym is None: + return AnyType(TypeOfAny.unannotated) + + node = sym.node if isinstance(node, Var): descriptor_type = node.type + if descriptor_type is None: + return AnyType(TypeOfAny.unannotated) + if is_nullable: descriptor_type = make_optional(descriptor_type) return descriptor_type @@ -167,8 +171,18 @@ def add_new_class_for_module(module: MypyFile, name: str, bases: List[Instance], return new_typeinfo +def get_current_module(api: TypeChecker) -> MypyFile: + current_module = None + for item in reversed(api.scope.stack): + if isinstance(item, MypyFile): + current_module = item + break + assert current_module is not None + return current_module + + def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]') -> TupleType: - current_module = api.scope.stack[0] + current_module = get_current_module(api) namedtuple_info = add_new_class_for_module(current_module, name, bases=[api.named_generic_type('typing.NamedTuple', [])], fields=fields) @@ -225,3 +239,9 @@ def resolve_string_attribute_value(attr_expr: Expression, ctx: Union[FunctionCon ctx.api.fail(f'Expression of type {type(attr_expr).__name__!r} is not supported', ctx.context) return None + + +def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker: + if not isinstance(ctx.api, TypeChecker): + raise ValueError('Not a TypeChecker') + return cast(TypeChecker, ctx.api) diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index fe8938d..2c76bf1 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -54,7 +54,7 @@ def extract_django_settings_module(config_file_path: Optional[str]) -> str: errors.raise_error() parser = configparser.ConfigParser() - parser.read(config_file_path) + parser.read(config_file_path) # type: ignore if not parser.has_section('mypy.plugins.django-stubs'): errors.report(0, None, "'django_settings_module' is not set: no section [mypy.plugins.django-stubs]", @@ -174,6 +174,7 @@ class NewSemanalDjangoPlugin(Plugin): if info.has_base(fullnames.MODEL_CLASS_FULLNAME): return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context) + return None def get_method_hook(self, fullname: str ) -> Optional[Callable[[MethodContext], MypyType]]: @@ -206,6 +207,7 @@ class NewSemanalDjangoPlugin(Plugin): manager_classes = self._get_current_manager_bases() if class_fullname in manager_classes and method_name == 'create': return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context) + return None def get_base_class_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: @@ -217,6 +219,7 @@ class NewSemanalDjangoPlugin(Plugin): if fullname in self._get_current_form_bases(): return transform_form_class + return None def get_attribute_hook(self, fullname: str ) -> Optional[Callable[[AttributeContext], MypyType]]: @@ -228,12 +231,7 @@ class NewSemanalDjangoPlugin(Plugin): info = self._get_typeinfo_or_none(class_name) if info and info.has_base(fullnames.HTTPREQUEST_CLASS_FULLNAME) and attr_name == 'user': return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context) - - # def get_type_analyze_hook(self, fullname: str - # ( ): - # info = self._get_typeinfo_or_none(fullname) - # if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): - # return partial(querysets.set_first_generic_param_as_default_for_second, fullname=fullname) + return None def plugin(version): diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index 3e2504a..23d095e 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -2,7 +2,7 @@ from typing import Optional, Tuple, cast from django.db.models.fields import Field from django.db.models.fields.related import RelatedField -from mypy.nodes import AssignmentStmt, TypeInfo +from mypy.nodes import AssignmentStmt, TypeInfo, NameExpr from mypy.plugin import FunctionContext from mypy.types import AnyType, Instance from mypy.types import Type as MypyType @@ -13,15 +13,16 @@ from mypy_django_plugin.lib import fullnames, helpers def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]: - outer_model_info = ctx.api.scope.active_class() - assert isinstance(outer_model_info, TypeInfo) - if not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME): + outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() + if (outer_model_info is None + or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)): return None field_name = None for stmt in outer_model_info.defn.defs.body: if isinstance(stmt, AssignmentStmt): if stmt.rvalue == ctx.context: + assert isinstance(stmt.lvalues[0], NameExpr) field_name = stmt.lvalues[0].name break if field_name is None: @@ -46,8 +47,13 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context if related_model_to_set._meta.proxy_for_model: related_model_to_set = related_model._meta.proxy_for_model - related_model_info = helpers.lookup_class_typeinfo(ctx.api, related_model) - related_model_to_set_info = helpers.lookup_class_typeinfo(ctx.api, related_model_to_set) + typechecker_api = helpers.get_typechecker_api(ctx) + + related_model_info = helpers.lookup_class_typeinfo(typechecker_api, related_model) + assert isinstance(related_model_info, TypeInfo) + + related_model_to_set_info = helpers.lookup_class_typeinfo(typechecker_api, related_model_to_set) + assert isinstance(related_model_to_set_info, TypeInfo) default_related_field_type = set_descriptor_types_for_field(ctx) # replace Any with referred_to_type @@ -68,7 +74,12 @@ def get_field_descriptor_types(field_info: TypeInfo, is_nullable: bool) -> Tuple def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance: default_return_type = cast(Instance, ctx.default_return_type) - is_nullable = helpers.parse_bool(helpers.get_call_argument_by_name(ctx, 'null')) + + is_nullable = False + null_expr = helpers.get_call_argument_by_name(ctx, 'null') + if null_expr is not None: + is_nullable = helpers.parse_bool(null_expr) or False + set_type, get_type = get_field_descriptor_types(default_return_type.type, is_nullable) return helpers.reparametrize_instance(default_return_type, [set_type, get_type]) @@ -92,7 +103,7 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan default_return_type = ctx.default_return_type assert isinstance(default_return_type, Instance) - outer_model_info = ctx.api.scope.active_class() + outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() if not outer_model_info or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME): # not inside models.Model class return ctx.default_return_type diff --git a/mypy_django_plugin/transformers/init_create.py b/mypy_django_plugin/transformers/init_create.py index d02945c..e725948 100644 --- a/mypy_django_plugin/transformers/init_create.py +++ b/mypy_django_plugin/transformers/init_create.py @@ -2,10 +2,10 @@ from typing import List, Tuple, Type, Union from django.db.models.base import Model from mypy.plugin import FunctionContext, MethodContext -from mypy.types import Instance -from mypy.types import Type as MypyType +from mypy.types import Instance, Type as MypyType from mypy_django_plugin.django.context import DjangoContext +from mypy_django_plugin.lib import helpers def get_actual_types(ctx: Union[MethodContext, FunctionContext], @@ -31,7 +31,8 @@ def get_actual_types(ctx: Union[MethodContext, FunctionContext], def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext, model_cls: Type[Model], method: str) -> MypyType: - expected_types = django_context.get_expected_types(ctx.api, model_cls, method) + typechecker_api = helpers.get_typechecker_api(ctx) + expected_types = django_context.get_expected_types(typechecker_api, model_cls, method=method) expected_keys = [key for key in expected_types.keys() if key != 'pk'] for actual_name, actual_type in get_actual_types(ctx, expected_keys): @@ -40,11 +41,11 @@ def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_co model_cls.__name__), ctx.context) continue - ctx.api.check_subtype(actual_type, expected_types[actual_name], - ctx.context, - 'Incompatible type for "{}" of "{}"'.format(actual_name, - model_cls.__name__), - 'got', 'expected') + typechecker_api.check_subtype(actual_type, expected_types[actual_name], + ctx.context, + 'Incompatible type for "{}" of "{}"'.format(actual_name, + model_cls.__name__), + 'got', 'expected') return ctx.default_return_type diff --git a/mypy_django_plugin/transformers/meta.py b/mypy_django_plugin/transformers/meta.py index 1eb39d2..3046cbd 100644 --- a/mypy_django_plugin/transformers/meta.py +++ b/mypy_django_plugin/transformers/meta.py @@ -1,4 +1,5 @@ from django.core.exceptions import FieldDoesNotExist +from mypy.nodes import TypeInfo from mypy.plugin import MethodContext from mypy.types import AnyType, Instance from mypy.types import Type as MypyType @@ -9,11 +10,16 @@ from mypy_django_plugin.lib import fullnames, helpers def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType: - field_info = helpers.lookup_fully_qualified_typeinfo(ctx.api, field_fullname) + field_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx), + field_fullname) + assert isinstance(field_info, TypeInfo) return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)]) def return_proper_field_type_from_get_field(ctx: MethodContext, django_context: DjangoContext) -> MypyType: + # Options instance + assert isinstance(ctx.type, Instance) + model_type = ctx.type.args[0] if not isinstance(model_type, Instance): return _get_field_instance(ctx, fullnames.FIELD_FULLNAME) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index c5161e3..9318b0c 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -118,8 +118,8 @@ class AddManagers(ModelClassInitializer): manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname) if manager_name not in self.model_classdef.info.names: - manager = Instance(manager_info, [Instance(self.model_classdef.info, [])]) - self.add_new_node_to_model_class(manager_name, manager) + manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])]) + self.add_new_node_to_model_class(manager_name, manager_type) else: # create new MODELNAME_MANAGERCLASSNAME class that represents manager parametrized with current model has_manager_any_base = any(self._is_manager_any(base) for base in manager_info.bases) diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index 8f15283..5b6d1b9 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -1,59 +1,39 @@ from collections import OrderedDict -from typing import List, Optional, Sequence, Tuple, Type, Union +from typing import List, Optional, Sequence, Type, Union from django.core.exceptions import FieldError from django.db.models.base import Model from mypy.nodes import Expression, NameExpr -from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext -from mypy.types import AnyType, Instance -from mypy.types import Type as MypyType -from mypy.types import TypeOfAny +from mypy.plugin import FunctionContext, MethodContext +from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.lib import fullnames, helpers -def set_first_generic_param_as_default_for_second(ctx: AnalyzeTypeContext, fullname: str) -> MypyType: - info = helpers.lookup_fully_qualified_typeinfo(ctx.api.api, fullname) - if info is None: - if not ctx.api.api.final_iteration: - ctx.api.api.defer() - - if not ctx.type.args: - return Instance(info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)]) - - args = ctx.type.args - if len(args) == 1: - args = [args[0], args[0]] - - analyzed_args = [ctx.api.analyze_type(arg) for arg in args] - return Instance(info, analyzed_args) - - def determine_proper_manager_type(ctx: FunctionContext) -> MypyType: - ret = ctx.default_return_type - assert isinstance(ret, Instance) + default_return_type = ctx.default_return_type + assert isinstance(default_return_type, Instance) - if not ctx.api.tscope.classes: - # not in class - return ret - outer_model_info = ctx.api.tscope.classes[0] - if not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME): - return ret + outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() + if (outer_model_info is None + or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)): + return default_return_type - return helpers.reparametrize_instance(ret, [Instance(outer_model_info, [])]) + return helpers.reparametrize_instance(default_return_type, [Instance(outer_model_info, [])]) -def get_lookup_field_get_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], - lookup: str, method: str) -> Optional[Tuple[str, MypyType]]: +def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], + *, method: str, lookup: str) -> Optional[MypyType]: try: lookup_field = django_context.lookups_context.resolve_lookup(model_cls, lookup) except FieldError as exc: ctx.api.fail(exc.args[0], ctx.context) return None - field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, method) - return lookup, field_get_type + field_get_type = django_context.fields_context.get_field_get_type(helpers.get_typechecker_api(ctx), + lookup_field, method=method) + return field_get_type def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], @@ -62,18 +42,21 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, if field_lookups is None: return AnyType(TypeOfAny.from_error) + typechecker_api = helpers.get_typechecker_api(ctx) if len(field_lookups) == 0: if flat: primary_key_field = django_context.get_primary_key_field(model_cls) - _, column_type = get_lookup_field_get_type(ctx, django_context, model_cls, - primary_key_field.attname, 'values_list') - return column_type + lookup_type = get_field_type_from_lookup(ctx, django_context, model_cls, + lookup=primary_key_field.attname, method='values_list') + assert lookup_type is not None + return lookup_type elif named: - column_types = OrderedDict() + column_types: 'OrderedDict[str, MypyType]' = OrderedDict() for field in django_context.get_model_fields(model_cls): - column_type = django_context.fields_context.get_field_get_type(ctx.api, field, 'values_list') + column_type = django_context.fields_context.get_field_get_type(typechecker_api, field, + method='values_list') column_types[field.attname] = column_type - return helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types) + return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) else: # flat=False, named=False, all fields field_lookups = [] @@ -81,32 +64,32 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, field_lookups.append(field.attname) if len(field_lookups) > 1 and flat: - ctx.api.fail("'flat' is not valid when 'values_list' is called with more than one field", ctx.context) + typechecker_api.fail("'flat' is not valid when 'values_list' is called with more than one field", ctx.context) return AnyType(TypeOfAny.from_error) column_types = OrderedDict() for field_lookup in field_lookups: - result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values_list') - if result is None: + lookup_field_type = get_field_type_from_lookup(ctx, django_context, model_cls, + lookup=field_lookup, method='values_list') + if lookup_field_type is None: return AnyType(TypeOfAny.from_error) - - column_name, column_type = result - column_types[column_name] = column_type + column_types[field_lookup] = lookup_field_type if flat: assert len(column_types) == 1 row_type = next(iter(column_types.values())) elif named: - row_type = helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types) + row_type = helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) else: - row_type = helpers.make_tuple(ctx.api, list(column_types.values())) + row_type = helpers.make_tuple(typechecker_api, list(column_types.values())) return row_type def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: DjangoContext) -> MypyType: - # called on the Instance + # called on the Instance, returns QuerySet of something assert isinstance(ctx.type, Instance) + assert isinstance(ctx.default_return_type, Instance) # bail if queryset of Any or other non-instances if not isinstance(ctx.type.args[0], Instance): @@ -133,6 +116,10 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: ctx.api.fail("'flat' and 'named' can't be used together", ctx.context) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)]) + # account for possible None + flat = flat or False + named = named or False + row_type = get_values_list_row_type(ctx, django_context, model_cls, flat=flat, named=named) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) @@ -150,8 +137,10 @@ def resolve_field_lookups(lookup_exprs: Sequence[Expression], ctx: Union[Functio def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType: - # queryset method + # called on QuerySet, return QuerySet of something assert isinstance(ctx.type, Instance) + assert isinstance(ctx.default_return_type, Instance) + # if queryset of non-instance type if not isinstance(ctx.type.args[0], Instance): return AnyType(TypeOfAny.from_omitted_generics) @@ -169,14 +158,14 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan for field in django_context.get_model_fields(model_cls): field_lookups.append(field.attname) - column_types = OrderedDict() + column_types: 'OrderedDict[str, MypyType]' = OrderedDict() for field_lookup in field_lookups: - result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values') - if result is None: + field_lookup_type = get_field_type_from_lookup(ctx, django_context, model_cls, + lookup=field_lookup, method='values') + if field_lookup_type is None: return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)]) - column_name, column_type = result - column_types[column_name] = column_type + column_types[field_lookup] = field_lookup_type row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys())) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) diff --git a/mypy_django_plugin/transformers/request.py b/mypy_django_plugin/transformers/request.py index 4ef3b8a..be584ab 100644 --- a/mypy_django_plugin/transformers/request.py +++ b/mypy_django_plugin/transformers/request.py @@ -9,7 +9,7 @@ from mypy_django_plugin.lib import helpers def set_auth_user_model_as_type_for_request_user(ctx: AttributeContext, django_context: DjangoContext) -> MypyType: auth_user_model = django_context.settings.AUTH_USER_MODEL model_cls = django_context.apps_registry.get_model(auth_user_model) - model_info = helpers.lookup_class_typeinfo(ctx.api, model_cls) + model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls) if model_info is None: return ctx.default_attr_type diff --git a/mypy_django_plugin/transformers/settings.py b/mypy_django_plugin/transformers/settings.py index a672e90..f74f213 100644 --- a/mypy_django_plugin/transformers/settings.py +++ b/mypy_django_plugin/transformers/settings.py @@ -1,3 +1,6 @@ +from typing import cast + +from mypy.checker import TypeChecker from mypy.nodes import MemberExpr, TypeInfo from mypy.plugin import AttributeContext, FunctionContext from mypy.types import Instance @@ -13,7 +16,8 @@ def get_user_model_hook(ctx: FunctionContext, django_context: DjangoContext) -> model_cls = django_context.apps_registry.get_model(auth_user_model) model_cls_fullname = helpers.get_class_fullname(model_cls) - model_info = helpers.lookup_fully_qualified_generic(model_cls_fullname, ctx.api.modules) + model_info = helpers.lookup_fully_qualified_generic(model_cls_fullname, + helpers.get_typechecker_api(ctx).modules) assert isinstance(model_info, TypeInfo) return TypeType(Instance(model_info, [])) @@ -26,9 +30,11 @@ def get_type_of_settings_attribute(ctx: AttributeContext, django_context: Django ctx.api.fail(f"'Settings' object has no attribute {setting_name!r}", ctx.context) return ctx.default_attr_type + typechecker_api = helpers.get_typechecker_api(ctx) + # first look for the setting in the project settings file, then global settings - settings_module = ctx.api.modules.get(django_context.django_settings_module) - global_settings_module = ctx.api.modules.get('django.conf.global_settings') + settings_module = typechecker_api.modules.get(django_context.django_settings_module) + global_settings_module = typechecker_api.modules.get('django.conf.global_settings') for module in [settings_module, global_settings_module]: if module is not None: sym = module.names.get(setting_name) @@ -39,7 +45,7 @@ def get_type_of_settings_attribute(ctx: AttributeContext, django_context: Django value = getattr(django_context.settings, setting_name) value_fullname = helpers.get_class_fullname(value.__class__) - value_info = helpers.lookup_fully_qualified_typeinfo(ctx.api, value_fullname) + value_info = helpers.lookup_fully_qualified_typeinfo(typechecker_api, value_fullname) if value_info is None: return ctx.default_attr_type diff --git a/test-data/typecheck/managers/querysets/test_values_list.yml b/test-data/typecheck/managers/querysets/test_values_list.yml index 26cc3da..76d29ff 100644 --- a/test-data/typecheck/managers/querysets/test_values_list.yml +++ b/test-data/typecheck/managers/querysets/test_values_list.yml @@ -97,6 +97,11 @@ pk_values = MyUser.objects.values_list('pk', named=True).get() reveal_type(pk_values) # N: Revealed type is 'Tuple[builtins.int, fallback=main.Row2]' reveal_type(pk_values.pk) # N: # N: Revealed type is 'builtins.int' + + # values_list(named=True) inside function + def func() -> None: + from myapp.models import MyUser + reveal_type(MyUser.objects.values_list('name', named=True).get()) # N: Revealed type is 'Tuple[builtins.str, fallback=main.Row3]' installed_apps: - myapp files: