fix mypy errors

This commit is contained in:
Maxim Kurnikov
2019-07-25 18:52:51 +03:00
parent 6466c57c69
commit 4c21855641
12 changed files with 168 additions and 118 deletions

View File

@@ -45,9 +45,9 @@ class RelatedField(FieldCacheMixin, Field[_ST, _GT]):
one_to_one: bool = ... one_to_one: bool = ...
many_to_many: bool = ... many_to_many: bool = ...
many_to_one: bool = ... many_to_one: bool = ...
@property
def related_model(self) -> Union[Type[Model], str]: ...
opts: Any = ... opts: Any = ...
@property
def related_model(self) -> Type[Model]: ...
def get_forward_related_filter(self, obj: Model) -> Dict[str, Union[int, UUID]]: ... def get_forward_related_filter(self, obj: Model) -> Dict[str, Union[int, UUID]]: ...
def get_reverse_related_filter(self, obj: Model) -> Q: ... def get_reverse_related_filter(self, obj: Model) -> Q: ...
@property @property

View File

@@ -3,6 +3,8 @@ from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Type
from mypy.nodes import TypeInfo
from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields import ArrayField
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models.base import Model from django.db.models.base import Model
@@ -12,7 +14,7 @@ from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
from django.utils.functional import cached_property from django.utils.functional import cached_property
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.types import Instance from mypy.types import Instance, AnyType, TypeOfAny
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy_django_plugin.lib import helpers from mypy_django_plugin.lib import helpers
@@ -42,14 +44,14 @@ def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']:
from django.db import models from django.db import models
models.QuerySet.__class_getitem__ = classmethod(noop_class_getitem) models.QuerySet.__class_getitem__ = classmethod(noop_class_getitem) # type: ignore
models.Manager.__class_getitem__ = classmethod(noop_class_getitem) models.Manager.__class_getitem__ = classmethod(noop_class_getitem) # type: ignore
from django.conf import settings from django.conf import settings
from django.apps import apps from django.apps import apps
apps.get_models.cache_clear() apps.get_models.cache_clear() # type: ignore
apps.get_swappable_settings_name.cache_clear() apps.get_swappable_settings_name.cache_clear() # type: ignore
if not settings.configured: if not settings.configured:
settings._setup() settings._setup()
@@ -84,28 +86,37 @@ class DjangoFieldsContext:
return True return True
return nullable return nullable
def get_field_set_type(self, api: TypeChecker, field: Field, method: str) -> MypyType: def get_field_set_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType:
""" Get a type of __set__ for this specific Django field. """
target_field = field target_field = field
if isinstance(field, ForeignKey): if isinstance(field, ForeignKey):
target_field = field.target_field target_field = field.target_field
field_info = helpers.lookup_class_typeinfo(api, target_field.__class__) field_info = helpers.lookup_class_typeinfo(api, target_field.__class__)
if field_info is None:
return AnyType(TypeOfAny.from_error)
field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type', field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
is_nullable=self.get_field_nullability(field, method)) is_nullable=self.get_field_nullability(field, method))
if isinstance(target_field, ArrayField): if isinstance(target_field, ArrayField):
argument_field_type = self.get_field_set_type(api, target_field.base_field, method) argument_field_type = self.get_field_set_type(api, target_field.base_field, method=method)
field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type) field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type)
return field_set_type return field_set_type
def get_field_get_type(self, api: TypeChecker, field: Field, method: str) -> MypyType: def get_field_get_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType:
""" Get a type of __get__ for this specific Django field. """
field_info = helpers.lookup_class_typeinfo(api, field.__class__) field_info = helpers.lookup_class_typeinfo(api, field.__class__)
assert isinstance(field_info, TypeInfo)
is_nullable = self.get_field_nullability(field, method) is_nullable = self.get_field_nullability(field, method)
if isinstance(field, RelatedField): if isinstance(field, RelatedField):
if method == 'values': if method == 'values':
primary_key_field = self.django_context.get_primary_key_field(field.related_model) primary_key_field = self.django_context.get_primary_key_field(field.related_model)
return self.get_field_get_type(api, primary_key_field, method) return self.get_field_get_type(api, primary_key_field, method=method)
model_info = helpers.lookup_class_typeinfo(api, field.related_model) model_info = helpers.lookup_class_typeinfo(api, field.related_model)
assert isinstance(model_info, TypeInfo)
return Instance(model_info, []) return Instance(model_info, [])
else: else:
return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
@@ -136,6 +147,8 @@ class DjangoLookupsContext:
if isinstance(current_field, RelatedField): if isinstance(current_field, RelatedField):
currently_observed_model = current_field.related_model currently_observed_model = current_field.related_model
# if it is None, solve_lookup_type() will fail earlier
assert current_field is not None
return current_field return current_field
@@ -146,9 +159,6 @@ class DjangoContext:
self.django_settings_module = django_settings_module self.django_settings_module = django_settings_module
self.apps_registry: Optional[Dict[str, str]] = None
self.settings: LazySettings = None
if self.django_settings_module:
apps, settings = initialize_django(self.django_settings_module) apps, settings = initialize_django(self.django_settings_module)
self.apps_registry = apps self.apps_registry = apps
self.settings = settings self.settings = settings
@@ -170,6 +180,7 @@ class DjangoContext:
for model_cls in self.model_modules.get(module, []): for model_cls in self.model_modules.get(module, []):
if model_cls.__name__ == model_cls_name: if model_cls.__name__ == model_cls_name:
return model_cls return model_cls
return None
def get_model_fields(self, model_cls: Type[Model]) -> Iterator[Field]: def get_model_fields(self, model_cls: Type[Model]) -> Iterator[Field]:
for field in model_cls._meta.get_fields(): for field in model_cls._meta.get_fields():
@@ -188,30 +199,33 @@ class DjangoContext:
return field return field
raise ValueError('No primary key defined') raise ValueError('No primary key defined')
def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], method: str) -> Dict[str, MypyType]: def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], *, method: str) -> Dict[str, MypyType]:
from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.fields import GenericForeignKey
expected_types = {} expected_types = {}
# add pk # add pk
primary_key_field = self.get_primary_key_field(model_cls) primary_key_field = self.get_primary_key_field(model_cls)
field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method) field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method)
expected_types['pk'] = field_set_type expected_types['pk'] = field_set_type
for field in model_cls._meta.get_fields(): for field in model_cls._meta.get_fields():
if isinstance(field, Field): if isinstance(field, Field):
field_name = field.attname field_name = field.attname
field_set_type = self.fields_context.get_field_set_type(api, field, method) field_set_type = self.fields_context.get_field_set_type(api, field, method=method)
expected_types[field_name] = field_set_type expected_types[field_name] = field_set_type
if isinstance(field, ForeignKey): if isinstance(field, ForeignKey):
field_name = field.name field_name = field.name
foreign_key_info = helpers.lookup_class_typeinfo(api, field.__class__) foreign_key_info = helpers.lookup_class_typeinfo(api, field.__class__)
assert isinstance(foreign_key_info, TypeInfo)
related_model = field.related_model related_model = field.related_model
if related_model._meta.proxy_for_model: if related_model._meta.proxy_for_model:
related_model = field.related_model._meta.proxy_for_model related_model = field.related_model._meta.proxy_for_model
related_model_info = helpers.lookup_class_typeinfo(api, related_model) related_model_info = helpers.lookup_class_typeinfo(api, related_model)
assert isinstance(related_model_info, TypeInfo)
is_nullable = self.fields_context.get_field_nullability(field, method) is_nullable = self.fields_context.get_field_nullability(field, method)
foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info, foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info,
'_pyi_private_set_type', '_pyi_private_set_type',

View File

@@ -1,17 +1,13 @@
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union, cast
from mypy import checker from mypy import checker
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.mro import calculate_mro from mypy.mro import calculate_mro
from mypy.nodes import ( from mypy.nodes import (Block, ClassDef, Expression, GDEF, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode,
GDEF, MDEF, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, SymbolTable, SymbolTableNode, TypeInfo, Var)
SymbolTableNode, TypeInfo, Var, from mypy.plugin import AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext
) from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext
from mypy.types import AnyType, Instance, NoneTyp, TupleType
from mypy.types import Type as MypyType
from mypy.types import TypedDictType, TypeOfAny, UnionType
if TYPE_CHECKING: if TYPE_CHECKING:
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
@@ -119,9 +115,17 @@ def has_any_of_bases(info: TypeInfo, bases: Set[str]) -> bool:
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType: def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType:
node = type_info.get(private_field_name).node """ Return declared type of type_info's private_field_name (used for private Field attributes)"""
sym = type_info.get(private_field_name)
if sym is None:
return AnyType(TypeOfAny.unannotated)
node = sym.node
if isinstance(node, Var): if isinstance(node, Var):
descriptor_type = node.type descriptor_type = node.type
if descriptor_type is None:
return AnyType(TypeOfAny.unannotated)
if is_nullable: if is_nullable:
descriptor_type = make_optional(descriptor_type) descriptor_type = make_optional(descriptor_type)
return descriptor_type return descriptor_type
@@ -167,8 +171,18 @@ def add_new_class_for_module(module: MypyFile, name: str, bases: List[Instance],
return new_typeinfo return new_typeinfo
def get_current_module(api: TypeChecker) -> MypyFile:
current_module = None
for item in reversed(api.scope.stack):
if isinstance(item, MypyFile):
current_module = item
break
assert current_module is not None
return current_module
def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]') -> TupleType: def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]') -> TupleType:
current_module = api.scope.stack[0] current_module = get_current_module(api)
namedtuple_info = add_new_class_for_module(current_module, name, namedtuple_info = add_new_class_for_module(current_module, name,
bases=[api.named_generic_type('typing.NamedTuple', [])], bases=[api.named_generic_type('typing.NamedTuple', [])],
fields=fields) fields=fields)
@@ -225,3 +239,9 @@ def resolve_string_attribute_value(attr_expr: Expression, ctx: Union[FunctionCon
ctx.api.fail(f'Expression of type {type(attr_expr).__name__!r} is not supported', ctx.context) ctx.api.fail(f'Expression of type {type(attr_expr).__name__!r} is not supported', ctx.context)
return None return None
def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker:
if not isinstance(ctx.api, TypeChecker):
raise ValueError('Not a TypeChecker')
return cast(TypeChecker, ctx.api)

View File

@@ -54,7 +54,7 @@ def extract_django_settings_module(config_file_path: Optional[str]) -> str:
errors.raise_error() errors.raise_error()
parser = configparser.ConfigParser() parser = configparser.ConfigParser()
parser.read(config_file_path) parser.read(config_file_path) # type: ignore
if not parser.has_section('mypy.plugins.django-stubs'): if not parser.has_section('mypy.plugins.django-stubs'):
errors.report(0, None, "'django_settings_module' is not set: no section [mypy.plugins.django-stubs]", errors.report(0, None, "'django_settings_module' is not set: no section [mypy.plugins.django-stubs]",
@@ -174,6 +174,7 @@ class NewSemanalDjangoPlugin(Plugin):
if info.has_base(fullnames.MODEL_CLASS_FULLNAME): if info.has_base(fullnames.MODEL_CLASS_FULLNAME):
return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context) return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)
return None
def get_method_hook(self, fullname: str def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], MypyType]]: ) -> Optional[Callable[[MethodContext], MypyType]]:
@@ -206,6 +207,7 @@ class NewSemanalDjangoPlugin(Plugin):
manager_classes = self._get_current_manager_bases() manager_classes = self._get_current_manager_bases()
if class_fullname in manager_classes and method_name == 'create': if class_fullname in manager_classes and method_name == 'create':
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context) return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
return None
def get_base_class_hook(self, fullname: str def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]: ) -> Optional[Callable[[ClassDefContext], None]]:
@@ -217,6 +219,7 @@ class NewSemanalDjangoPlugin(Plugin):
if fullname in self._get_current_form_bases(): if fullname in self._get_current_form_bases():
return transform_form_class return transform_form_class
return None
def get_attribute_hook(self, fullname: str def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], MypyType]]: ) -> Optional[Callable[[AttributeContext], MypyType]]:
@@ -228,12 +231,7 @@ class NewSemanalDjangoPlugin(Plugin):
info = self._get_typeinfo_or_none(class_name) info = self._get_typeinfo_or_none(class_name)
if info and info.has_base(fullnames.HTTPREQUEST_CLASS_FULLNAME) and attr_name == 'user': if info and info.has_base(fullnames.HTTPREQUEST_CLASS_FULLNAME) and attr_name == 'user':
return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context) return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context)
return None
# def get_type_analyze_hook(self, fullname: str
# ( ):
# info = self._get_typeinfo_or_none(fullname)
# if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
# return partial(querysets.set_first_generic_param_as_default_for_second, fullname=fullname)
def plugin(version): def plugin(version):

View File

@@ -2,7 +2,7 @@ from typing import Optional, Tuple, cast
from django.db.models.fields import Field from django.db.models.fields import Field
from django.db.models.fields.related import RelatedField from django.db.models.fields.related import RelatedField
from mypy.nodes import AssignmentStmt, TypeInfo from mypy.nodes import AssignmentStmt, TypeInfo, NameExpr
from mypy.plugin import FunctionContext from mypy.plugin import FunctionContext
from mypy.types import AnyType, Instance from mypy.types import AnyType, Instance
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
@@ -13,15 +13,16 @@ from mypy_django_plugin.lib import fullnames, helpers
def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]: def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]:
outer_model_info = ctx.api.scope.active_class() outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
assert isinstance(outer_model_info, TypeInfo) if (outer_model_info is None
if not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME): or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)):
return None return None
field_name = None field_name = None
for stmt in outer_model_info.defn.defs.body: for stmt in outer_model_info.defn.defs.body:
if isinstance(stmt, AssignmentStmt): if isinstance(stmt, AssignmentStmt):
if stmt.rvalue == ctx.context: if stmt.rvalue == ctx.context:
assert isinstance(stmt.lvalues[0], NameExpr)
field_name = stmt.lvalues[0].name field_name = stmt.lvalues[0].name
break break
if field_name is None: if field_name is None:
@@ -46,8 +47,13 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
if related_model_to_set._meta.proxy_for_model: if related_model_to_set._meta.proxy_for_model:
related_model_to_set = related_model._meta.proxy_for_model related_model_to_set = related_model._meta.proxy_for_model
related_model_info = helpers.lookup_class_typeinfo(ctx.api, related_model) typechecker_api = helpers.get_typechecker_api(ctx)
related_model_to_set_info = helpers.lookup_class_typeinfo(ctx.api, related_model_to_set)
related_model_info = helpers.lookup_class_typeinfo(typechecker_api, related_model)
assert isinstance(related_model_info, TypeInfo)
related_model_to_set_info = helpers.lookup_class_typeinfo(typechecker_api, related_model_to_set)
assert isinstance(related_model_to_set_info, TypeInfo)
default_related_field_type = set_descriptor_types_for_field(ctx) default_related_field_type = set_descriptor_types_for_field(ctx)
# replace Any with referred_to_type # replace Any with referred_to_type
@@ -68,7 +74,12 @@ def get_field_descriptor_types(field_info: TypeInfo, is_nullable: bool) -> Tuple
def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance: def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
default_return_type = cast(Instance, ctx.default_return_type) default_return_type = cast(Instance, ctx.default_return_type)
is_nullable = helpers.parse_bool(helpers.get_call_argument_by_name(ctx, 'null'))
is_nullable = False
null_expr = helpers.get_call_argument_by_name(ctx, 'null')
if null_expr is not None:
is_nullable = helpers.parse_bool(null_expr) or False
set_type, get_type = get_field_descriptor_types(default_return_type.type, is_nullable) set_type, get_type = get_field_descriptor_types(default_return_type.type, is_nullable)
return helpers.reparametrize_instance(default_return_type, [set_type, get_type]) return helpers.reparametrize_instance(default_return_type, [set_type, get_type])
@@ -92,7 +103,7 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan
default_return_type = ctx.default_return_type default_return_type = ctx.default_return_type
assert isinstance(default_return_type, Instance) assert isinstance(default_return_type, Instance)
outer_model_info = ctx.api.scope.active_class() outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
if not outer_model_info or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME): if not outer_model_info or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
# not inside models.Model class # not inside models.Model class
return ctx.default_return_type return ctx.default_return_type

View File

@@ -2,10 +2,10 @@ from typing import List, Tuple, Type, Union
from django.db.models.base import Model from django.db.models.base import Model
from mypy.plugin import FunctionContext, MethodContext from mypy.plugin import FunctionContext, MethodContext
from mypy.types import Instance from mypy.types import Instance, Type as MypyType
from mypy.types import Type as MypyType
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import helpers
def get_actual_types(ctx: Union[MethodContext, FunctionContext], def get_actual_types(ctx: Union[MethodContext, FunctionContext],
@@ -31,7 +31,8 @@ def get_actual_types(ctx: Union[MethodContext, FunctionContext],
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext, def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
model_cls: Type[Model], method: str) -> MypyType: model_cls: Type[Model], method: str) -> MypyType:
expected_types = django_context.get_expected_types(ctx.api, model_cls, method) typechecker_api = helpers.get_typechecker_api(ctx)
expected_types = django_context.get_expected_types(typechecker_api, model_cls, method=method)
expected_keys = [key for key in expected_types.keys() if key != 'pk'] expected_keys = [key for key in expected_types.keys() if key != 'pk']
for actual_name, actual_type in get_actual_types(ctx, expected_keys): for actual_name, actual_type in get_actual_types(ctx, expected_keys):
@@ -40,7 +41,7 @@ def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_co
model_cls.__name__), model_cls.__name__),
ctx.context) ctx.context)
continue continue
ctx.api.check_subtype(actual_type, expected_types[actual_name], typechecker_api.check_subtype(actual_type, expected_types[actual_name],
ctx.context, ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name, 'Incompatible type for "{}" of "{}"'.format(actual_name,
model_cls.__name__), model_cls.__name__),

View File

@@ -1,4 +1,5 @@
from django.core.exceptions import FieldDoesNotExist from django.core.exceptions import FieldDoesNotExist
from mypy.nodes import TypeInfo
from mypy.plugin import MethodContext from mypy.plugin import MethodContext
from mypy.types import AnyType, Instance from mypy.types import AnyType, Instance
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
@@ -9,11 +10,16 @@ from mypy_django_plugin.lib import fullnames, helpers
def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType: def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType:
field_info = helpers.lookup_fully_qualified_typeinfo(ctx.api, field_fullname) field_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx),
field_fullname)
assert isinstance(field_info, TypeInfo)
return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)]) return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)])
def return_proper_field_type_from_get_field(ctx: MethodContext, django_context: DjangoContext) -> MypyType: def return_proper_field_type_from_get_field(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
# Options instance
assert isinstance(ctx.type, Instance)
model_type = ctx.type.args[0] model_type = ctx.type.args[0]
if not isinstance(model_type, Instance): if not isinstance(model_type, Instance):
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME) return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)

View File

@@ -118,8 +118,8 @@ class AddManagers(ModelClassInitializer):
manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname) manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname)
if manager_name not in self.model_classdef.info.names: if manager_name not in self.model_classdef.info.names:
manager = Instance(manager_info, [Instance(self.model_classdef.info, [])]) manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])])
self.add_new_node_to_model_class(manager_name, manager) self.add_new_node_to_model_class(manager_name, manager_type)
else: else:
# create new MODELNAME_MANAGERCLASSNAME class that represents manager parametrized with current model # create new MODELNAME_MANAGERCLASSNAME class that represents manager parametrized with current model
has_manager_any_base = any(self._is_manager_any(base) for base in manager_info.bases) has_manager_any_base = any(self._is_manager_any(base) for base in manager_info.bases)

View File

@@ -1,59 +1,39 @@
from collections import OrderedDict from collections import OrderedDict
from typing import List, Optional, Sequence, Tuple, Type, Union from typing import List, Optional, Sequence, Type, Union
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models.base import Model from django.db.models.base import Model
from mypy.nodes import Expression, NameExpr from mypy.nodes import Expression, NameExpr
from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib import fullnames, helpers
def set_first_generic_param_as_default_for_second(ctx: AnalyzeTypeContext, fullname: str) -> MypyType:
info = helpers.lookup_fully_qualified_typeinfo(ctx.api.api, fullname)
if info is None:
if not ctx.api.api.final_iteration:
ctx.api.api.defer()
if not ctx.type.args:
return Instance(info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)])
args = ctx.type.args
if len(args) == 1:
args = [args[0], args[0]]
analyzed_args = [ctx.api.analyze_type(arg) for arg in args]
return Instance(info, analyzed_args)
def determine_proper_manager_type(ctx: FunctionContext) -> MypyType: def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
ret = ctx.default_return_type default_return_type = ctx.default_return_type
assert isinstance(ret, Instance) assert isinstance(default_return_type, Instance)
if not ctx.api.tscope.classes: outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
# not in class if (outer_model_info is None
return ret or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)):
outer_model_info = ctx.api.tscope.classes[0] return default_return_type
if not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
return ret
return helpers.reparametrize_instance(ret, [Instance(outer_model_info, [])]) return helpers.reparametrize_instance(default_return_type, [Instance(outer_model_info, [])])
def get_lookup_field_get_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
lookup: str, method: str) -> Optional[Tuple[str, MypyType]]: *, method: str, lookup: str) -> Optional[MypyType]:
try: try:
lookup_field = django_context.lookups_context.resolve_lookup(model_cls, lookup) lookup_field = django_context.lookups_context.resolve_lookup(model_cls, lookup)
except FieldError as exc: except FieldError as exc:
ctx.api.fail(exc.args[0], ctx.context) ctx.api.fail(exc.args[0], ctx.context)
return None return None
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, method) field_get_type = django_context.fields_context.get_field_get_type(helpers.get_typechecker_api(ctx),
return lookup, field_get_type lookup_field, method=method)
return field_get_type
def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
@@ -62,18 +42,21 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
if field_lookups is None: if field_lookups is None:
return AnyType(TypeOfAny.from_error) return AnyType(TypeOfAny.from_error)
typechecker_api = helpers.get_typechecker_api(ctx)
if len(field_lookups) == 0: if len(field_lookups) == 0:
if flat: if flat:
primary_key_field = django_context.get_primary_key_field(model_cls) primary_key_field = django_context.get_primary_key_field(model_cls)
_, column_type = get_lookup_field_get_type(ctx, django_context, model_cls, lookup_type = get_field_type_from_lookup(ctx, django_context, model_cls,
primary_key_field.attname, 'values_list') lookup=primary_key_field.attname, method='values_list')
return column_type assert lookup_type is not None
return lookup_type
elif named: elif named:
column_types = OrderedDict() column_types: 'OrderedDict[str, MypyType]' = OrderedDict()
for field in django_context.get_model_fields(model_cls): for field in django_context.get_model_fields(model_cls):
column_type = django_context.fields_context.get_field_get_type(ctx.api, field, 'values_list') column_type = django_context.fields_context.get_field_get_type(typechecker_api, field,
method='values_list')
column_types[field.attname] = column_type column_types[field.attname] = column_type
return helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types) return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
else: else:
# flat=False, named=False, all fields # flat=False, named=False, all fields
field_lookups = [] field_lookups = []
@@ -81,32 +64,32 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
field_lookups.append(field.attname) field_lookups.append(field.attname)
if len(field_lookups) > 1 and flat: if len(field_lookups) > 1 and flat:
ctx.api.fail("'flat' is not valid when 'values_list' is called with more than one field", ctx.context) typechecker_api.fail("'flat' is not valid when 'values_list' is called with more than one field", ctx.context)
return AnyType(TypeOfAny.from_error) return AnyType(TypeOfAny.from_error)
column_types = OrderedDict() column_types = OrderedDict()
for field_lookup in field_lookups: for field_lookup in field_lookups:
result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values_list') lookup_field_type = get_field_type_from_lookup(ctx, django_context, model_cls,
if result is None: lookup=field_lookup, method='values_list')
if lookup_field_type is None:
return AnyType(TypeOfAny.from_error) return AnyType(TypeOfAny.from_error)
column_types[field_lookup] = lookup_field_type
column_name, column_type = result
column_types[column_name] = column_type
if flat: if flat:
assert len(column_types) == 1 assert len(column_types) == 1
row_type = next(iter(column_types.values())) row_type = next(iter(column_types.values()))
elif named: elif named:
row_type = helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types) row_type = helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
else: else:
row_type = helpers.make_tuple(ctx.api, list(column_types.values())) row_type = helpers.make_tuple(typechecker_api, list(column_types.values()))
return row_type return row_type
def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: DjangoContext) -> MypyType: def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
# called on the Instance # called on the Instance, returns QuerySet of something
assert isinstance(ctx.type, Instance) assert isinstance(ctx.type, Instance)
assert isinstance(ctx.default_return_type, Instance)
# bail if queryset of Any or other non-instances # bail if queryset of Any or other non-instances
if not isinstance(ctx.type.args[0], Instance): if not isinstance(ctx.type.args[0], Instance):
@@ -133,6 +116,10 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
ctx.api.fail("'flat' and 'named' can't be used together", ctx.context) ctx.api.fail("'flat' and 'named' can't be used together", ctx.context)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)]) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
# account for possible None
flat = flat or False
named = named or False
row_type = get_values_list_row_type(ctx, django_context, model_cls, row_type = get_values_list_row_type(ctx, django_context, model_cls,
flat=flat, named=named) flat=flat, named=named)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
@@ -150,8 +137,10 @@ def resolve_field_lookups(lookup_exprs: Sequence[Expression], ctx: Union[Functio
def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType: def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
# queryset method # called on QuerySet, return QuerySet of something
assert isinstance(ctx.type, Instance) assert isinstance(ctx.type, Instance)
assert isinstance(ctx.default_return_type, Instance)
# if queryset of non-instance type # if queryset of non-instance type
if not isinstance(ctx.type.args[0], Instance): if not isinstance(ctx.type.args[0], Instance):
return AnyType(TypeOfAny.from_omitted_generics) return AnyType(TypeOfAny.from_omitted_generics)
@@ -169,14 +158,14 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
for field in django_context.get_model_fields(model_cls): for field in django_context.get_model_fields(model_cls):
field_lookups.append(field.attname) field_lookups.append(field.attname)
column_types = OrderedDict() column_types: 'OrderedDict[str, MypyType]' = OrderedDict()
for field_lookup in field_lookups: for field_lookup in field_lookups:
result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values') field_lookup_type = get_field_type_from_lookup(ctx, django_context, model_cls,
if result is None: lookup=field_lookup, method='values')
if field_lookup_type is None:
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)]) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
column_name, column_type = result column_types[field_lookup] = field_lookup_type
column_types[column_name] = column_type
row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys())) row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys()))
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])

View File

@@ -9,7 +9,7 @@ from mypy_django_plugin.lib import helpers
def set_auth_user_model_as_type_for_request_user(ctx: AttributeContext, django_context: DjangoContext) -> MypyType: def set_auth_user_model_as_type_for_request_user(ctx: AttributeContext, django_context: DjangoContext) -> MypyType:
auth_user_model = django_context.settings.AUTH_USER_MODEL auth_user_model = django_context.settings.AUTH_USER_MODEL
model_cls = django_context.apps_registry.get_model(auth_user_model) model_cls = django_context.apps_registry.get_model(auth_user_model)
model_info = helpers.lookup_class_typeinfo(ctx.api, model_cls) model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls)
if model_info is None: if model_info is None:
return ctx.default_attr_type return ctx.default_attr_type

View File

@@ -1,3 +1,6 @@
from typing import cast
from mypy.checker import TypeChecker
from mypy.nodes import MemberExpr, TypeInfo from mypy.nodes import MemberExpr, TypeInfo
from mypy.plugin import AttributeContext, FunctionContext from mypy.plugin import AttributeContext, FunctionContext
from mypy.types import Instance from mypy.types import Instance
@@ -13,7 +16,8 @@ def get_user_model_hook(ctx: FunctionContext, django_context: DjangoContext) ->
model_cls = django_context.apps_registry.get_model(auth_user_model) model_cls = django_context.apps_registry.get_model(auth_user_model)
model_cls_fullname = helpers.get_class_fullname(model_cls) model_cls_fullname = helpers.get_class_fullname(model_cls)
model_info = helpers.lookup_fully_qualified_generic(model_cls_fullname, ctx.api.modules) model_info = helpers.lookup_fully_qualified_generic(model_cls_fullname,
helpers.get_typechecker_api(ctx).modules)
assert isinstance(model_info, TypeInfo) assert isinstance(model_info, TypeInfo)
return TypeType(Instance(model_info, [])) return TypeType(Instance(model_info, []))
@@ -26,9 +30,11 @@ def get_type_of_settings_attribute(ctx: AttributeContext, django_context: Django
ctx.api.fail(f"'Settings' object has no attribute {setting_name!r}", ctx.context) ctx.api.fail(f"'Settings' object has no attribute {setting_name!r}", ctx.context)
return ctx.default_attr_type return ctx.default_attr_type
typechecker_api = helpers.get_typechecker_api(ctx)
# first look for the setting in the project settings file, then global settings # first look for the setting in the project settings file, then global settings
settings_module = ctx.api.modules.get(django_context.django_settings_module) settings_module = typechecker_api.modules.get(django_context.django_settings_module)
global_settings_module = ctx.api.modules.get('django.conf.global_settings') global_settings_module = typechecker_api.modules.get('django.conf.global_settings')
for module in [settings_module, global_settings_module]: for module in [settings_module, global_settings_module]:
if module is not None: if module is not None:
sym = module.names.get(setting_name) sym = module.names.get(setting_name)
@@ -39,7 +45,7 @@ def get_type_of_settings_attribute(ctx: AttributeContext, django_context: Django
value = getattr(django_context.settings, setting_name) value = getattr(django_context.settings, setting_name)
value_fullname = helpers.get_class_fullname(value.__class__) value_fullname = helpers.get_class_fullname(value.__class__)
value_info = helpers.lookup_fully_qualified_typeinfo(ctx.api, value_fullname) value_info = helpers.lookup_fully_qualified_typeinfo(typechecker_api, value_fullname)
if value_info is None: if value_info is None:
return ctx.default_attr_type return ctx.default_attr_type

View File

@@ -97,6 +97,11 @@
pk_values = MyUser.objects.values_list('pk', named=True).get() pk_values = MyUser.objects.values_list('pk', named=True).get()
reveal_type(pk_values) # N: Revealed type is 'Tuple[builtins.int, fallback=main.Row2]' reveal_type(pk_values) # N: Revealed type is 'Tuple[builtins.int, fallback=main.Row2]'
reveal_type(pk_values.pk) # N: # N: Revealed type is 'builtins.int' reveal_type(pk_values.pk) # N: # N: Revealed type is 'builtins.int'
# values_list(named=True) inside function
def func() -> None:
from myapp.models import MyUser
reveal_type(MyUser.objects.values_list('name', named=True).get()) # N: Revealed type is 'Tuple[builtins.str, fallback=main.Row3]'
installed_apps: installed_apps:
- myapp - myapp
files: files: