mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-15 08:17:08 +08:00
updated package setup (#485)
* updated package setup * updated to use python 3.9 * fixed test runner * fixed typecheck tests * fixed discrepencies * added override to runner * updated travis * updated pre-commit hooks * updated dep
This commit is contained in:
committed by
GitHub
parent
a3624dec36
commit
44151c485d
@@ -2,9 +2,7 @@ import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from typing import (
|
||||
TYPE_CHECKING, Dict, Iterable, Iterator, Optional, Set, Tuple, Type, Union,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import models
|
||||
@@ -26,9 +24,11 @@ from mypy_django_plugin.lib import fullnames, helpers
|
||||
try:
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
except ImportError:
|
||||
|
||||
class ArrayField: # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.apps.registry import Apps # noqa: F401
|
||||
from django.conf import LazySettings # noqa: F401
|
||||
@@ -45,9 +45,9 @@ def temp_environ():
|
||||
os.environ.update(environ)
|
||||
|
||||
|
||||
def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']:
|
||||
def initialize_django(settings_module: str) -> Tuple["Apps", "LazySettings"]:
|
||||
with temp_environ():
|
||||
os.environ['DJANGO_SETTINGS_MODULE'] = settings_module
|
||||
os.environ["DJANGO_SETTINGS_MODULE"] = settings_module
|
||||
|
||||
# add current directory to sys.path
|
||||
sys.path.append(os.getcwd())
|
||||
@@ -60,8 +60,8 @@ def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']:
|
||||
models.QuerySet.__class_getitem__ = classmethod(noop_class_getitem) # type: ignore
|
||||
models.Manager.__class_getitem__ = classmethod(noop_class_getitem) # type: ignore
|
||||
|
||||
from django.conf import settings
|
||||
from django.apps import apps
|
||||
from django.conf import settings
|
||||
|
||||
apps.get_models.cache_clear() # type: ignore
|
||||
apps.get_swappable_settings_name.cache_clear() # type: ignore
|
||||
@@ -100,15 +100,13 @@ class DjangoContext:
|
||||
modules[concrete_model_cls.__module__].add(concrete_model_cls)
|
||||
# collect abstract=True models
|
||||
for model_cls in concrete_model_cls.mro()[1:]:
|
||||
if (issubclass(model_cls, Model)
|
||||
and hasattr(model_cls, '_meta')
|
||||
and model_cls._meta.abstract):
|
||||
if issubclass(model_cls, Model) and hasattr(model_cls, "_meta") and model_cls._meta.abstract:
|
||||
modules[model_cls.__module__].add(model_cls)
|
||||
return modules
|
||||
|
||||
def get_model_class_by_fullname(self, fullname: str) -> Optional[Type[Model]]:
|
||||
# Returns None if Model is abstract
|
||||
module, _, model_cls_name = fullname.rpartition('.')
|
||||
module, _, model_cls_name = fullname.rpartition(".")
|
||||
for model_cls in self.model_modules.get(module, set()):
|
||||
if model_cls.__name__ == model_cls_name:
|
||||
return model_cls
|
||||
@@ -128,7 +126,7 @@ class DjangoContext:
|
||||
if isinstance(field, (RelatedField, ForeignObjectRel)):
|
||||
related_model_cls = field.related_model
|
||||
primary_key_field = self.get_primary_key_field(related_model_cls)
|
||||
primary_key_type = self.get_field_get_type(api, primary_key_field, method='init')
|
||||
primary_key_type = self.get_field_get_type(api, primary_key_field, method="init")
|
||||
|
||||
rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
|
||||
if rel_model_info is None:
|
||||
@@ -140,15 +138,14 @@ class DjangoContext:
|
||||
field_info = helpers.lookup_class_typeinfo(api, field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
return helpers.get_private_descriptor_type(field_info, '_pyi_lookup_exact_type',
|
||||
is_nullable=field.null)
|
||||
return helpers.get_private_descriptor_type(field_info, "_pyi_lookup_exact_type", is_nullable=field.null)
|
||||
|
||||
def get_primary_key_field(self, model_cls: Type[Model]) -> Field:
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, Field):
|
||||
if field.primary_key:
|
||||
return field
|
||||
raise ValueError('No primary key defined')
|
||||
raise ValueError("No primary key defined")
|
||||
|
||||
def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], *, method: str) -> Dict[str, MypyType]:
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey
|
||||
@@ -158,7 +155,7 @@ class DjangoContext:
|
||||
if not model_cls._meta.abstract:
|
||||
primary_key_field = self.get_primary_key_field(model_cls)
|
||||
field_set_type = self.get_field_set_type(api, primary_key_field, method=method)
|
||||
expected_types['pk'] = field_set_type
|
||||
expected_types["pk"] = field_set_type
|
||||
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, Field):
|
||||
@@ -188,11 +185,10 @@ class DjangoContext:
|
||||
continue
|
||||
|
||||
is_nullable = self.get_field_nullability(field, method)
|
||||
foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info,
|
||||
'_pyi_private_set_type',
|
||||
is_nullable=is_nullable)
|
||||
model_set_type = helpers.convert_any_to_type(foreign_key_set_type,
|
||||
Instance(related_model_info, []))
|
||||
foreign_key_set_type = helpers.get_private_descriptor_type(
|
||||
foreign_key_info, "_pyi_private_set_type", is_nullable=is_nullable
|
||||
)
|
||||
model_set_type = helpers.convert_any_to_type(foreign_key_set_type, Instance(related_model_info, []))
|
||||
|
||||
expected_types[field_name] = model_set_type
|
||||
|
||||
@@ -200,8 +196,7 @@ class DjangoContext:
|
||||
# it's generic, so cannot set specific model
|
||||
field_name = field.name
|
||||
gfk_info = helpers.lookup_class_typeinfo(api, field.__class__)
|
||||
gfk_set_type = helpers.get_private_descriptor_type(gfk_info, '_pyi_private_set_type',
|
||||
is_nullable=True)
|
||||
gfk_set_type = helpers.get_private_descriptor_type(gfk_info, "_pyi_private_set_type", is_nullable=True)
|
||||
expected_types[field_name] = gfk_set_type
|
||||
|
||||
return expected_types
|
||||
@@ -230,11 +225,10 @@ class DjangoContext:
|
||||
nullable = field.null
|
||||
if not nullable and isinstance(field, CharField) and field.blank:
|
||||
return True
|
||||
if method == '__init__':
|
||||
if ((isinstance(field, Field) and field.primary_key)
|
||||
or isinstance(field, ForeignKey)):
|
||||
if method == "__init__":
|
||||
if (isinstance(field, Field) and field.primary_key) or isinstance(field, ForeignKey):
|
||||
return True
|
||||
if method == 'create':
|
||||
if method == "create":
|
||||
if isinstance(field, AutoField):
|
||||
return True
|
||||
if isinstance(field, Field) and field.has_default():
|
||||
@@ -251,8 +245,9 @@ class DjangoContext:
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
|
||||
is_nullable=self.get_field_nullability(field, method))
|
||||
field_set_type = helpers.get_private_descriptor_type(
|
||||
field_info, "_pyi_private_set_type", is_nullable=self.get_field_nullability(field, method)
|
||||
)
|
||||
if isinstance(target_field, ArrayField):
|
||||
argument_field_type = self.get_field_set_type(api, target_field.base_field, method=method)
|
||||
field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type)
|
||||
@@ -270,7 +265,7 @@ class DjangoContext:
|
||||
if related_model_cls is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
if method == 'values':
|
||||
if method == "values":
|
||||
primary_key_field = self.get_primary_key_field(related_model_cls)
|
||||
return self.get_field_get_type(api, primary_key_field, method=method)
|
||||
|
||||
@@ -280,8 +275,7 @@ class DjangoContext:
|
||||
|
||||
return Instance(model_info, [])
|
||||
else:
|
||||
return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
||||
is_nullable=is_nullable)
|
||||
return helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_nullable)
|
||||
|
||||
def get_field_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) -> Optional[Type[Model]]:
|
||||
if isinstance(field, RelatedField):
|
||||
@@ -290,26 +284,25 @@ class DjangoContext:
|
||||
related_model_cls = field.field.model
|
||||
|
||||
if isinstance(related_model_cls, str):
|
||||
if related_model_cls == 'self':
|
||||
if related_model_cls == "self":
|
||||
# same model
|
||||
related_model_cls = field.model
|
||||
elif '.' not in related_model_cls:
|
||||
elif "." not in related_model_cls:
|
||||
# same file model
|
||||
related_model_fullname = field.model.__module__ + '.' + related_model_cls
|
||||
related_model_fullname = field.model.__module__ + "." + related_model_cls
|
||||
related_model_cls = self.get_model_class_by_fullname(related_model_fullname)
|
||||
else:
|
||||
related_model_cls = self.apps_registry.get_model(related_model_cls)
|
||||
|
||||
return related_model_cls
|
||||
|
||||
def _resolve_field_from_parts(self,
|
||||
field_parts: Iterable[str],
|
||||
model_cls: Type[Model]
|
||||
) -> Union[Field, ForeignObjectRel]:
|
||||
def _resolve_field_from_parts(
|
||||
self, field_parts: Iterable[str], model_cls: Type[Model]
|
||||
) -> Union[Field, ForeignObjectRel]:
|
||||
currently_observed_model = model_cls
|
||||
field = None
|
||||
for field_part in field_parts:
|
||||
if field_part == 'pk':
|
||||
if field_part == "pk":
|
||||
field = self.get_primary_key_field(currently_observed_model)
|
||||
continue
|
||||
|
||||
@@ -317,8 +310,7 @@ class DjangoContext:
|
||||
if isinstance(field, RelatedField):
|
||||
currently_observed_model = field.related_model
|
||||
model_name = currently_observed_model._meta.model_name
|
||||
if (model_name is not None
|
||||
and field_part == (model_name + '_id')):
|
||||
if model_name is not None and field_part == (model_name + "_id"):
|
||||
field = self.get_primary_key_field(currently_observed_model)
|
||||
|
||||
if isinstance(field, ForeignObjectRel):
|
||||
@@ -368,13 +360,13 @@ class DjangoContext:
|
||||
if lookup_base.args and isinstance(lookup_base.args[0], Instance):
|
||||
lookup_type: MypyType = lookup_base.args[0]
|
||||
# if it's Field, consider lookup_type a __get__ of current field
|
||||
if (isinstance(lookup_type, Instance)
|
||||
and lookup_type.type.fullname == fullnames.FIELD_FULLNAME):
|
||||
if isinstance(lookup_type, Instance) and lookup_type.type.fullname == fullnames.FIELD_FULLNAME:
|
||||
field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
||||
is_nullable=field.null)
|
||||
lookup_type = helpers.get_private_descriptor_type(
|
||||
field_info, "_pyi_private_get_type", is_nullable=field.null
|
||||
)
|
||||
return lookup_type
|
||||
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
@@ -1,39 +1,34 @@
|
||||
MODEL_CLASS_FULLNAME = "django.db.models.base.Model"
|
||||
FIELD_FULLNAME = "django.db.models.fields.Field"
|
||||
CHAR_FIELD_FULLNAME = "django.db.models.fields.CharField"
|
||||
ARRAY_FIELD_FULLNAME = "django.contrib.postgres.fields.array.ArrayField"
|
||||
AUTO_FIELD_FULLNAME = "django.db.models.fields.AutoField"
|
||||
GENERIC_FOREIGN_KEY_FULLNAME = "django.contrib.contenttypes.fields.GenericForeignKey"
|
||||
FOREIGN_KEY_FULLNAME = "django.db.models.fields.related.ForeignKey"
|
||||
ONETOONE_FIELD_FULLNAME = "django.db.models.fields.related.OneToOneField"
|
||||
MANYTOMANY_FIELD_FULLNAME = "django.db.models.fields.related.ManyToManyField"
|
||||
DUMMY_SETTINGS_BASE_CLASS = "django.conf._DjangoConfLazyObject"
|
||||
|
||||
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
|
||||
FIELD_FULLNAME = 'django.db.models.fields.Field'
|
||||
CHAR_FIELD_FULLNAME = 'django.db.models.fields.CharField'
|
||||
ARRAY_FIELD_FULLNAME = 'django.contrib.postgres.fields.array.ArrayField'
|
||||
AUTO_FIELD_FULLNAME = 'django.db.models.fields.AutoField'
|
||||
GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey'
|
||||
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
|
||||
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
|
||||
MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField'
|
||||
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
|
||||
QUERYSET_CLASS_FULLNAME = "django.db.models.query.QuerySet"
|
||||
BASE_MANAGER_CLASS_FULLNAME = "django.db.models.manager.BaseManager"
|
||||
MANAGER_CLASS_FULLNAME = "django.db.models.manager.Manager"
|
||||
RELATED_MANAGER_CLASS = "django.db.models.manager.RelatedManager"
|
||||
|
||||
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
|
||||
BASE_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.BaseManager'
|
||||
MANAGER_CLASS_FULLNAME = 'django.db.models.manager.Manager'
|
||||
RELATED_MANAGER_CLASS = 'django.db.models.manager.RelatedManager'
|
||||
BASEFORM_CLASS_FULLNAME = "django.forms.forms.BaseForm"
|
||||
FORM_CLASS_FULLNAME = "django.forms.forms.Form"
|
||||
MODELFORM_CLASS_FULLNAME = "django.forms.models.ModelForm"
|
||||
|
||||
BASEFORM_CLASS_FULLNAME = 'django.forms.forms.BaseForm'
|
||||
FORM_CLASS_FULLNAME = 'django.forms.forms.Form'
|
||||
MODELFORM_CLASS_FULLNAME = 'django.forms.models.ModelForm'
|
||||
|
||||
FORM_MIXIN_CLASS_FULLNAME = 'django.views.generic.edit.FormMixin'
|
||||
FORM_MIXIN_CLASS_FULLNAME = "django.views.generic.edit.FormMixin"
|
||||
|
||||
MANAGER_CLASSES = {
|
||||
MANAGER_CLASS_FULLNAME,
|
||||
BASE_MANAGER_CLASS_FULLNAME,
|
||||
}
|
||||
|
||||
RELATED_FIELDS_CLASSES = {
|
||||
FOREIGN_KEY_FULLNAME,
|
||||
ONETOONE_FIELD_FULLNAME,
|
||||
MANYTOMANY_FIELD_FULLNAME
|
||||
}
|
||||
RELATED_FIELDS_CLASSES = {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME, MANYTOMANY_FIELD_FULLNAME}
|
||||
|
||||
MIGRATION_CLASS_FULLNAME = 'django.db.migrations.migration.Migration'
|
||||
OPTIONS_CLASS_FULLNAME = 'django.db.models.options.Options'
|
||||
HTTPREQUEST_CLASS_FULLNAME = 'django.http.request.HttpRequest'
|
||||
MIGRATION_CLASS_FULLNAME = "django.db.migrations.migration.Migration"
|
||||
OPTIONS_CLASS_FULLNAME = "django.db.models.options.Options"
|
||||
HTTPREQUEST_CLASS_FULLNAME = "django.http.request.HttpRequest"
|
||||
|
||||
F_EXPRESSION_FULLNAME = 'django.db.models.expressions.F'
|
||||
F_EXPRESSION_FULLNAME = "django.db.models.expressions.F"
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from collections import OrderedDict
|
||||
from typing import (
|
||||
TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
|
||||
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.fields.related import RelatedField
|
||||
@@ -10,11 +8,31 @@ from mypy import checker
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.mro import calculate_mro
|
||||
from mypy.nodes import (
|
||||
GDEF, MDEF, Argument, Block, ClassDef, Expression, FuncDef, MemberExpr, MypyFile, NameExpr, PlaceholderNode,
|
||||
StrExpr, SymbolNode, SymbolTable, SymbolTableNode, TypeInfo, Var,
|
||||
GDEF,
|
||||
MDEF,
|
||||
Argument,
|
||||
Block,
|
||||
ClassDef,
|
||||
Expression,
|
||||
FuncDef,
|
||||
MemberExpr,
|
||||
MypyFile,
|
||||
NameExpr,
|
||||
PlaceholderNode,
|
||||
StrExpr,
|
||||
SymbolNode,
|
||||
SymbolTable,
|
||||
SymbolTableNode,
|
||||
TypeInfo,
|
||||
Var,
|
||||
)
|
||||
from mypy.plugin import (
|
||||
AttributeContext, CheckerPluginInterface, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext,
|
||||
AttributeContext,
|
||||
CheckerPluginInterface,
|
||||
ClassDefContext,
|
||||
DynamicClassDefContext,
|
||||
FunctionContext,
|
||||
MethodContext,
|
||||
)
|
||||
from mypy.plugins.common import add_method
|
||||
from mypy.semanal import SemanticAnalyzer
|
||||
@@ -29,7 +47,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
|
||||
return model_info.metadata.setdefault('django', {})
|
||||
return model_info.metadata.setdefault("django", {})
|
||||
|
||||
|
||||
class IncompleteDefnException(Exception):
|
||||
@@ -37,9 +55,9 @@ class IncompleteDefnException(Exception):
|
||||
|
||||
|
||||
def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]:
|
||||
if '.' not in fullname:
|
||||
if "." not in fullname:
|
||||
return None
|
||||
module, cls_name = fullname.rsplit('.', 1)
|
||||
module, cls_name = fullname.rsplit(".", 1)
|
||||
|
||||
module_file = all_modules.get(module)
|
||||
if module_file is None:
|
||||
@@ -71,12 +89,11 @@ def lookup_class_typeinfo(api: TypeChecker, klass: type) -> Optional[TypeInfo]:
|
||||
|
||||
|
||||
def reparametrize_instance(instance: Instance, new_args: List[MypyType]) -> Instance:
|
||||
return Instance(instance.type, args=new_args,
|
||||
line=instance.line, column=instance.column)
|
||||
return Instance(instance.type, args=new_args, line=instance.line, column=instance.column)
|
||||
|
||||
|
||||
def get_class_fullname(klass: type) -> str:
|
||||
return klass.__module__ + '.' + klass.__qualname__
|
||||
return klass.__module__ + "." + klass.__qualname__
|
||||
|
||||
|
||||
def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]:
|
||||
@@ -115,9 +132,9 @@ def make_optional(typ: MypyType) -> MypyType:
|
||||
|
||||
def parse_bool(expr: Expression) -> Optional[bool]:
|
||||
if isinstance(expr, NameExpr):
|
||||
if expr.fullname == 'builtins.True':
|
||||
if expr.fullname == "builtins.True":
|
||||
return True
|
||||
if expr.fullname == 'builtins.False':
|
||||
if expr.fullname == "builtins.False":
|
||||
return False
|
||||
return None
|
||||
|
||||
@@ -164,27 +181,24 @@ def get_field_lookup_exact_type(api: TypeChecker, field: Field) -> MypyType:
|
||||
field_info = lookup_class_typeinfo(api, field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
return get_private_descriptor_type(field_info, '_pyi_lookup_exact_type',
|
||||
is_nullable=field.null)
|
||||
return get_private_descriptor_type(field_info, "_pyi_lookup_exact_type", is_nullable=field.null)
|
||||
|
||||
|
||||
def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]:
|
||||
metaclass_sym = info.names.get('Meta')
|
||||
metaclass_sym = info.names.get("Meta")
|
||||
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
|
||||
return metaclass_sym.node
|
||||
return None
|
||||
|
||||
|
||||
def add_new_class_for_module(module: MypyFile,
|
||||
name: str,
|
||||
bases: List[Instance],
|
||||
fields: Optional[Dict[str, MypyType]] = None
|
||||
) -> TypeInfo:
|
||||
def add_new_class_for_module(
|
||||
module: MypyFile, name: str, bases: List[Instance], fields: Optional[Dict[str, MypyType]] = None
|
||||
) -> TypeInfo:
|
||||
new_class_unique_name = checker.gen_unique_name(name, module.names)
|
||||
|
||||
# make new class expression
|
||||
classdef = ClassDef(new_class_unique_name, Block([]))
|
||||
classdef.fullname = module.fullname + '.' + new_class_unique_name
|
||||
classdef.fullname = module.fullname + "." + new_class_unique_name
|
||||
|
||||
# make new TypeInfo
|
||||
new_typeinfo = TypeInfo(SymbolTable(), classdef, module.fullname)
|
||||
@@ -197,7 +211,7 @@ def add_new_class_for_module(module: MypyFile,
|
||||
for field_name, field_type in fields.items():
|
||||
var = Var(field_name, type=field_type)
|
||||
var.info = new_typeinfo
|
||||
var._fullname = new_typeinfo.fullname + '.' + field_name
|
||||
var._fullname = new_typeinfo.fullname + "." + field_name
|
||||
new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True)
|
||||
|
||||
classdef.info = new_typeinfo
|
||||
@@ -215,18 +229,17 @@ def get_current_module(api: TypeChecker) -> MypyFile:
|
||||
return current_module
|
||||
|
||||
|
||||
def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]') -> TupleType:
|
||||
def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: "OrderedDict[str, MypyType]") -> TupleType:
|
||||
current_module = get_current_module(api)
|
||||
namedtuple_info = add_new_class_for_module(current_module, name,
|
||||
bases=[api.named_generic_type('typing.NamedTuple', [])],
|
||||
fields=fields)
|
||||
namedtuple_info = add_new_class_for_module(
|
||||
current_module, name, bases=[api.named_generic_type("typing.NamedTuple", [])], fields=fields
|
||||
)
|
||||
return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, []))
|
||||
|
||||
|
||||
def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType:
|
||||
def make_tuple(api: "TypeChecker", fields: List[MypyType]) -> TupleType:
|
||||
# fallback for tuples is any builtins.tuple instance
|
||||
fallback = api.named_generic_type('builtins.tuple',
|
||||
[AnyType(TypeOfAny.special_form)])
|
||||
fallback = api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.special_form)])
|
||||
return TupleType(fields, fallback=fallback)
|
||||
|
||||
|
||||
@@ -235,8 +248,7 @@ def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
|
||||
converted_items = []
|
||||
for item in typ.items:
|
||||
converted_items.append(convert_any_to_type(item, referred_to_type))
|
||||
return UnionType.make_union(converted_items,
|
||||
line=typ.line, column=typ.column)
|
||||
return UnionType.make_union(converted_items, line=typ.line, column=typ.column)
|
||||
if isinstance(typ, Instance):
|
||||
args = []
|
||||
for default_arg in typ.args:
|
||||
@@ -252,21 +264,22 @@ def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
|
||||
return typ
|
||||
|
||||
|
||||
def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyType]',
|
||||
required_keys: Set[str]) -> TypedDictType:
|
||||
object_type = api.named_generic_type('mypy_extensions._TypedDict', [])
|
||||
def make_typeddict(
|
||||
api: CheckerPluginInterface, fields: "OrderedDict[str, MypyType]", required_keys: Set[str]
|
||||
) -> TypedDictType:
|
||||
object_type = api.named_generic_type("mypy_extensions._TypedDict", [])
|
||||
typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type)
|
||||
return typed_dict_type
|
||||
|
||||
|
||||
def resolve_string_attribute_value(attr_expr: Expression, django_context: 'DjangoContext') -> Optional[str]:
|
||||
def resolve_string_attribute_value(attr_expr: Expression, django_context: "DjangoContext") -> Optional[str]:
|
||||
if isinstance(attr_expr, StrExpr):
|
||||
return attr_expr.value
|
||||
|
||||
# support extracting from settings, in general case it's unresolvable yet
|
||||
if isinstance(attr_expr, MemberExpr):
|
||||
member_name = attr_expr.name
|
||||
if isinstance(attr_expr.expr, NameExpr) and attr_expr.expr.fullname == 'django.conf.settings':
|
||||
if isinstance(attr_expr.expr, NameExpr) and attr_expr.expr.fullname == "django.conf.settings":
|
||||
if hasattr(django_context.settings, member_name):
|
||||
return getattr(django_context.settings, member_name)
|
||||
return None
|
||||
@@ -274,27 +287,27 @@ def resolve_string_attribute_value(attr_expr: Expression, django_context: 'Djang
|
||||
|
||||
def get_semanal_api(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> SemanticAnalyzer:
|
||||
if not isinstance(ctx.api, SemanticAnalyzer):
|
||||
raise ValueError('Not a SemanticAnalyzer')
|
||||
raise ValueError("Not a SemanticAnalyzer")
|
||||
return ctx.api
|
||||
|
||||
|
||||
def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker:
|
||||
if not isinstance(ctx.api, TypeChecker):
|
||||
raise ValueError('Not a TypeChecker')
|
||||
raise ValueError("Not a TypeChecker")
|
||||
return ctx.api
|
||||
|
||||
|
||||
def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool:
|
||||
return (info.fullname in django_context.all_registered_model_class_fullnames
|
||||
or info.has_base(fullnames.MODEL_CLASS_FULLNAME))
|
||||
def is_model_subclass_info(info: TypeInfo, django_context: "DjangoContext") -> bool:
|
||||
return info.fullname in django_context.all_registered_model_class_fullnames or info.has_base(
|
||||
fullnames.MODEL_CLASS_FULLNAME
|
||||
)
|
||||
|
||||
|
||||
def check_types_compatible(ctx: Union[FunctionContext, MethodContext],
|
||||
*, expected_type: MypyType, actual_type: MypyType, error_message: str) -> None:
|
||||
def check_types_compatible(
|
||||
ctx: Union[FunctionContext, MethodContext], *, expected_type: MypyType, actual_type: MypyType, error_message: str
|
||||
) -> None:
|
||||
api = get_typechecker_api(ctx)
|
||||
api.check_subtype(actual_type, expected_type,
|
||||
ctx.context, error_message,
|
||||
'got', 'expected')
|
||||
api.check_subtype(actual_type, expected_type, ctx.context, error_message, "got", "expected")
|
||||
|
||||
|
||||
def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None:
|
||||
@@ -302,11 +315,10 @@ def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> No
|
||||
var = Var(name=name, type=sym_type)
|
||||
# var.info: type of the object variable is bound to
|
||||
var.info = info
|
||||
var._fullname = info.fullname + '.' + name
|
||||
var._fullname = info.fullname + "." + name
|
||||
var.is_initialized_in_class = True
|
||||
var.is_inferred = True
|
||||
info.names[name] = SymbolTableNode(MDEF, var,
|
||||
plugin_generated=True)
|
||||
info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True)
|
||||
|
||||
|
||||
def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument], MypyType]:
|
||||
@@ -322,8 +334,9 @@ def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument],
|
||||
return prepared_arguments, return_type
|
||||
|
||||
|
||||
def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance,
|
||||
new_method_name: str, method_node: FuncDef) -> None:
|
||||
def copy_method_to_another_class(
|
||||
ctx: ClassDefContext, self_type: Instance, new_method_name: str, method_node: FuncDef
|
||||
) -> None:
|
||||
semanal_api = get_semanal_api(ctx)
|
||||
if method_node.type is None:
|
||||
if not semanal_api.final_iteration:
|
||||
@@ -331,11 +344,7 @@ def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance,
|
||||
return
|
||||
|
||||
arguments, return_type = build_unannotated_method_args(method_node)
|
||||
add_method(ctx,
|
||||
new_method_name,
|
||||
args=arguments,
|
||||
return_type=return_type,
|
||||
self_type=self_type)
|
||||
add_method(ctx, new_method_name, args=arguments, return_type=return_type, self_type=self_type)
|
||||
return
|
||||
|
||||
method_type = method_node.type
|
||||
@@ -345,17 +354,16 @@ def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance,
|
||||
return
|
||||
|
||||
arguments = []
|
||||
bound_return_type = semanal_api.anal_type(method_type.ret_type,
|
||||
allow_placeholder=True)
|
||||
bound_return_type = semanal_api.anal_type(method_type.ret_type, allow_placeholder=True)
|
||||
|
||||
assert bound_return_type is not None
|
||||
|
||||
if isinstance(bound_return_type, PlaceholderNode):
|
||||
return
|
||||
|
||||
for arg_name, arg_type, original_argument in zip(method_type.arg_names[1:],
|
||||
method_type.arg_types[1:],
|
||||
method_node.arguments[1:]):
|
||||
for arg_name, arg_type, original_argument in zip(
|
||||
method_type.arg_names[1:], method_type.arg_types[1:], method_node.arguments[1:]
|
||||
):
|
||||
bound_arg_type = semanal_api.anal_type(arg_type, allow_placeholder=True)
|
||||
if bound_arg_type is None and not semanal_api.final_iteration:
|
||||
semanal_api.defer()
|
||||
@@ -366,19 +374,16 @@ def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance,
|
||||
if isinstance(bound_arg_type, PlaceholderNode):
|
||||
return
|
||||
|
||||
var = Var(name=original_argument.variable.name,
|
||||
type=arg_type)
|
||||
var = Var(name=original_argument.variable.name, type=arg_type)
|
||||
var.line = original_argument.variable.line
|
||||
var.column = original_argument.variable.column
|
||||
argument = Argument(variable=var,
|
||||
type_annotation=bound_arg_type,
|
||||
initializer=original_argument.initializer,
|
||||
kind=original_argument.kind)
|
||||
argument = Argument(
|
||||
variable=var,
|
||||
type_annotation=bound_arg_type,
|
||||
initializer=original_argument.initializer,
|
||||
kind=original_argument.kind,
|
||||
)
|
||||
argument.set_line(original_argument)
|
||||
arguments.append(argument)
|
||||
|
||||
add_method(ctx,
|
||||
new_method_name,
|
||||
args=arguments,
|
||||
return_type=bound_return_type,
|
||||
self_type=self_type)
|
||||
add_method(ctx, new_method_name, args=arguments, return_type=bound_return_type, self_type=self_type)
|
||||
|
||||
@@ -8,28 +8,28 @@ from mypy.modulefinder import mypy_path
|
||||
from mypy.nodes import MypyFile, TypeInfo
|
||||
from mypy.options import Options
|
||||
from mypy.plugin import (
|
||||
AttributeContext, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext, Plugin,
|
||||
AttributeContext,
|
||||
ClassDefContext,
|
||||
DynamicClassDefContext,
|
||||
FunctionContext,
|
||||
MethodContext,
|
||||
Plugin,
|
||||
)
|
||||
from mypy.types import Type as MypyType
|
||||
|
||||
import mypy_django_plugin.transformers.orm_lookups
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
from mypy_django_plugin.transformers import (
|
||||
fields, forms, init_create, meta, querysets, request, settings,
|
||||
)
|
||||
from mypy_django_plugin.transformers.managers import (
|
||||
create_new_manager_class_from_from_queryset_method,
|
||||
)
|
||||
from mypy_django_plugin.transformers import fields, forms, init_create, meta, querysets, request, settings
|
||||
from mypy_django_plugin.transformers.managers import create_new_manager_class_from_from_queryset_method
|
||||
from mypy_django_plugin.transformers.models import process_model_class
|
||||
|
||||
|
||||
def transform_model_class(ctx: ClassDefContext,
|
||||
django_context: DjangoContext) -> None:
|
||||
def transform_model_class(ctx: ClassDefContext, django_context: DjangoContext) -> None:
|
||||
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MODEL_CLASS_FULLNAME)
|
||||
|
||||
if sym is not None and isinstance(sym.node, TypeInfo):
|
||||
helpers.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1
|
||||
helpers.get_django_metadata(sym.node)["model_bases"][ctx.cls.fullname] = 1
|
||||
else:
|
||||
if not ctx.api.final_iteration:
|
||||
ctx.api.defer()
|
||||
@@ -41,7 +41,7 @@ def transform_model_class(ctx: ClassDefContext,
|
||||
def transform_form_class(ctx: ClassDefContext) -> None:
|
||||
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.BASEFORM_CLASS_FULLNAME)
|
||||
if sym is not None and isinstance(sym.node, TypeInfo):
|
||||
helpers.get_django_metadata(sym.node)['baseform_bases'][ctx.cls.fullname] = 1
|
||||
helpers.get_django_metadata(sym.node)["baseform_bases"][ctx.cls.fullname] = 1
|
||||
|
||||
forms.make_meta_nested_class_inherit_from_any(ctx)
|
||||
|
||||
@@ -49,11 +49,10 @@ def transform_form_class(ctx: ClassDefContext) -> None:
|
||||
def add_new_manager_base(ctx: ClassDefContext) -> None:
|
||||
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
|
||||
if sym is not None and isinstance(sym.node, TypeInfo):
|
||||
helpers.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1
|
||||
helpers.get_django_metadata(sym.node)["manager_bases"][ctx.cls.fullname] = 1
|
||||
|
||||
|
||||
def extract_django_settings_module(config_file_path: Optional[str]) -> str:
|
||||
|
||||
def exit(error_type: int) -> NoReturn:
|
||||
"""Using mypy's argument parser, raise `SystemExit` to fail hard if validation fails.
|
||||
|
||||
@@ -69,24 +68,28 @@ def extract_django_settings_module(config_file_path: Optional[str]) -> str:
|
||||
[mypy.plugins.django_stubs]
|
||||
django_settings_module: str (required)
|
||||
...
|
||||
""".replace("\n" + 8 * " ", "\n")
|
||||
handler = CapturableArgumentParser(prog='(django-stubs) mypy', usage=usage)
|
||||
messages = {1: 'mypy config file is not specified or found',
|
||||
2: 'no section [mypy.plugins.django-stubs]',
|
||||
3: 'the setting is not provided'}
|
||||
""".replace(
|
||||
"\n" + 8 * " ", "\n"
|
||||
)
|
||||
handler = CapturableArgumentParser(prog="(django-stubs) mypy", usage=usage)
|
||||
messages = {
|
||||
1: "mypy config file is not specified or found",
|
||||
2: "no section [mypy.plugins.django-stubs]",
|
||||
3: "the setting is not provided",
|
||||
}
|
||||
handler.error("'django_settings_module' is not set: " + messages[error_type])
|
||||
|
||||
parser = configparser.ConfigParser()
|
||||
try:
|
||||
parser.read_file(open(cast(str, config_file_path), 'r'), source=config_file_path)
|
||||
parser.read_file(open(cast(str, config_file_path)), source=config_file_path)
|
||||
except (IsADirectoryError, OSError):
|
||||
exit(1)
|
||||
|
||||
section = 'mypy.plugins.django-stubs'
|
||||
section = "mypy.plugins.django-stubs"
|
||||
if not parser.has_section(section):
|
||||
exit(2)
|
||||
settings = parser.get(section, 'django_settings_module', fallback=None) or exit(3)
|
||||
return cast(str, settings).strip('\'"')
|
||||
settings = parser.get(section, "django_settings_module", fallback=None) or exit(3)
|
||||
return cast(str, settings).strip("'\"")
|
||||
|
||||
|
||||
class NewSemanalDjangoPlugin(Plugin):
|
||||
@@ -102,34 +105,41 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
def _get_current_queryset_bases(self) -> Dict[str, int]:
|
||||
model_sym = self.lookup_fully_qualified(fullnames.QUERYSET_CLASS_FULLNAME)
|
||||
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
|
||||
return (helpers.get_django_metadata(model_sym.node)
|
||||
.setdefault('queryset_bases', {fullnames.QUERYSET_CLASS_FULLNAME: 1}))
|
||||
return helpers.get_django_metadata(model_sym.node).setdefault(
|
||||
"queryset_bases", {fullnames.QUERYSET_CLASS_FULLNAME: 1}
|
||||
)
|
||||
else:
|
||||
return {}
|
||||
|
||||
def _get_current_manager_bases(self) -> Dict[str, int]:
|
||||
model_sym = self.lookup_fully_qualified(fullnames.MANAGER_CLASS_FULLNAME)
|
||||
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
|
||||
return (helpers.get_django_metadata(model_sym.node)
|
||||
.setdefault('manager_bases', {fullnames.MANAGER_CLASS_FULLNAME: 1}))
|
||||
return helpers.get_django_metadata(model_sym.node).setdefault(
|
||||
"manager_bases", {fullnames.MANAGER_CLASS_FULLNAME: 1}
|
||||
)
|
||||
else:
|
||||
return {}
|
||||
|
||||
def _get_current_model_bases(self) -> Dict[str, int]:
|
||||
model_sym = self.lookup_fully_qualified(fullnames.MODEL_CLASS_FULLNAME)
|
||||
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
|
||||
return helpers.get_django_metadata(model_sym.node).setdefault('model_bases',
|
||||
{fullnames.MODEL_CLASS_FULLNAME: 1})
|
||||
return helpers.get_django_metadata(model_sym.node).setdefault(
|
||||
"model_bases", {fullnames.MODEL_CLASS_FULLNAME: 1}
|
||||
)
|
||||
else:
|
||||
return {}
|
||||
|
||||
def _get_current_form_bases(self) -> Dict[str, int]:
|
||||
model_sym = self.lookup_fully_qualified(fullnames.BASEFORM_CLASS_FULLNAME)
|
||||
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
|
||||
return (helpers.get_django_metadata(model_sym.node)
|
||||
.setdefault('baseform_bases', {fullnames.BASEFORM_CLASS_FULLNAME: 1,
|
||||
fullnames.FORM_CLASS_FULLNAME: 1,
|
||||
fullnames.MODELFORM_CLASS_FULLNAME: 1}))
|
||||
return helpers.get_django_metadata(model_sym.node).setdefault(
|
||||
"baseform_bases",
|
||||
{
|
||||
fullnames.BASEFORM_CLASS_FULLNAME: 1,
|
||||
fullnames.FORM_CLASS_FULLNAME: 1,
|
||||
fullnames.MODELFORM_CLASS_FULLNAME: 1,
|
||||
},
|
||||
)
|
||||
else:
|
||||
return {}
|
||||
|
||||
@@ -144,17 +154,16 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
|
||||
def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]:
|
||||
# for settings
|
||||
if file.fullname == 'django.conf' and self.django_context.django_settings_module:
|
||||
if file.fullname == "django.conf" and self.django_context.django_settings_module:
|
||||
return [self._new_dependency(self.django_context.django_settings_module)]
|
||||
|
||||
# for values / values_list
|
||||
if file.fullname == 'django.db.models':
|
||||
return [self._new_dependency('mypy_extensions'), self._new_dependency('typing')]
|
||||
if file.fullname == "django.db.models":
|
||||
return [self._new_dependency("mypy_extensions"), self._new_dependency("typing")]
|
||||
|
||||
# for `get_user_model()`
|
||||
if self.django_context.settings:
|
||||
if (file.fullname == 'django.contrib.auth'
|
||||
or file.fullname in {'django.http', 'django.http.request'}):
|
||||
if file.fullname == "django.contrib.auth" or file.fullname in {"django.http", "django.http.request"}:
|
||||
auth_user_model_name = self.django_context.settings.AUTH_USER_MODEL
|
||||
try:
|
||||
auth_user_module = self.django_context.apps_registry.get_model(auth_user_model_name).__module__
|
||||
@@ -186,9 +195,8 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
deps.add(self._new_dependency(related_model_module))
|
||||
return list(deps)
|
||||
|
||||
def get_function_hook(self, fullname: str
|
||||
) -> Optional[Callable[[FunctionContext], MypyType]]:
|
||||
if fullname == 'django.contrib.auth.get_user_model':
|
||||
def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], MypyType]]:
|
||||
if fullname == "django.contrib.auth.get_user_model":
|
||||
return partial(settings.get_user_model_hook, django_context=self.django_context)
|
||||
|
||||
manager_bases = self._get_current_manager_bases()
|
||||
@@ -204,46 +212,48 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)
|
||||
return None
|
||||
|
||||
def get_method_hook(self, fullname: str
|
||||
) -> Optional[Callable[[MethodContext], MypyType]]:
|
||||
class_fullname, _, method_name = fullname.rpartition('.')
|
||||
if method_name == 'get_form_class':
|
||||
def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], MypyType]]:
|
||||
class_fullname, _, method_name = fullname.rpartition(".")
|
||||
if method_name == "get_form_class":
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
|
||||
return forms.extract_proper_type_for_get_form_class
|
||||
|
||||
if method_name == 'get_form':
|
||||
if method_name == "get_form":
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
|
||||
return forms.extract_proper_type_for_get_form
|
||||
|
||||
if method_name == 'values':
|
||||
if method_name == "values":
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
|
||||
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
|
||||
|
||||
if method_name == 'values_list':
|
||||
if method_name == "values_list":
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
|
||||
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
|
||||
|
||||
if method_name == 'get_field':
|
||||
if method_name == "get_field":
|
||||
info = self._get_typeinfo_or_none(class_fullname)
|
||||
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
|
||||
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)
|
||||
|
||||
manager_classes = self._get_current_manager_bases()
|
||||
if class_fullname in manager_classes and method_name == 'create':
|
||||
if class_fullname in manager_classes and method_name == "create":
|
||||
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
|
||||
if class_fullname in manager_classes and method_name in {'filter', 'get', 'exclude'}:
|
||||
return partial(mypy_django_plugin.transformers.orm_lookups.typecheck_queryset_filter,
|
||||
django_context=self.django_context)
|
||||
if class_fullname in manager_classes and method_name in {"filter", "get", "exclude"}:
|
||||
return partial(
|
||||
mypy_django_plugin.transformers.orm_lookups.typecheck_queryset_filter,
|
||||
django_context=self.django_context,
|
||||
)
|
||||
return None
|
||||
|
||||
def get_base_class_hook(self, fullname: str
|
||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
if (fullname in self.django_context.all_registered_model_class_fullnames
|
||||
or fullname in self._get_current_model_bases()):
|
||||
def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
if (
|
||||
fullname in self.django_context.all_registered_model_class_fullnames
|
||||
or fullname in self._get_current_model_bases()
|
||||
):
|
||||
return partial(transform_model_class, django_context=self.django_context)
|
||||
|
||||
if fullname in self._get_current_manager_bases():
|
||||
@@ -253,22 +263,19 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
return transform_form_class
|
||||
return None
|
||||
|
||||
def get_attribute_hook(self, fullname: str
|
||||
) -> Optional[Callable[[AttributeContext], MypyType]]:
|
||||
class_name, _, attr_name = fullname.rpartition('.')
|
||||
def get_attribute_hook(self, fullname: str) -> Optional[Callable[[AttributeContext], MypyType]]:
|
||||
class_name, _, attr_name = fullname.rpartition(".")
|
||||
if class_name == fullnames.DUMMY_SETTINGS_BASE_CLASS:
|
||||
return partial(settings.get_type_of_settings_attribute,
|
||||
django_context=self.django_context)
|
||||
return partial(settings.get_type_of_settings_attribute, django_context=self.django_context)
|
||||
|
||||
info = self._get_typeinfo_or_none(class_name)
|
||||
if info and info.has_base(fullnames.HTTPREQUEST_CLASS_FULLNAME) and attr_name == 'user':
|
||||
if info and info.has_base(fullnames.HTTPREQUEST_CLASS_FULLNAME) and attr_name == "user":
|
||||
return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context)
|
||||
return None
|
||||
|
||||
def get_dynamic_class_hook(self, fullname: str
|
||||
) -> Optional[Callable[[DynamicClassDefContext], None]]:
|
||||
if fullname.endswith('from_queryset'):
|
||||
class_name, _, _ = fullname.rpartition('.')
|
||||
def get_dynamic_class_hook(self, fullname: str) -> Optional[Callable[[DynamicClassDefContext], None]]:
|
||||
if fullname.endswith("from_queryset"):
|
||||
class_name, _, _ = fullname.rpartition(".")
|
||||
info = self._get_typeinfo_or_none(class_name)
|
||||
if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
|
||||
return create_new_manager_class_from_from_queryset_method
|
||||
|
||||
@@ -14,8 +14,7 @@ from mypy_django_plugin.lib import fullnames, helpers
|
||||
|
||||
def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]:
|
||||
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
|
||||
if (outer_model_info is None
|
||||
or not helpers.is_model_subclass_info(outer_model_info, django_context)):
|
||||
if outer_model_info is None or not helpers.is_model_subclass_info(outer_model_info, django_context):
|
||||
return None
|
||||
|
||||
field_name = None
|
||||
@@ -60,8 +59,7 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
|
||||
|
||||
# self reference with abstract=True on the model where ForeignKey is defined
|
||||
current_model_cls = current_field.model
|
||||
if (current_model_cls._meta.abstract
|
||||
and current_model_cls == related_model_cls):
|
||||
if current_model_cls._meta.abstract and current_model_cls == related_model_cls:
|
||||
# for all derived non-abstract classes, set variable with this name to
|
||||
# __get__/__set__ of ForeignKey of derived model
|
||||
for model_cls in django_context.all_registered_model_classes:
|
||||
@@ -69,11 +67,10 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
|
||||
derived_model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls)
|
||||
if derived_model_info is not None:
|
||||
fk_ref_type = Instance(derived_model_info, [])
|
||||
derived_fk_type = reparametrize_related_field_type(default_related_field_type,
|
||||
set_type=fk_ref_type, get_type=fk_ref_type)
|
||||
helpers.add_new_sym_for_info(derived_model_info,
|
||||
name=current_field.name,
|
||||
sym_type=derived_fk_type)
|
||||
derived_fk_type = reparametrize_related_field_type(
|
||||
default_related_field_type, set_type=fk_ref_type, get_type=fk_ref_type
|
||||
)
|
||||
helpers.add_new_sym_for_info(derived_model_info, name=current_field.name, sym_type=derived_fk_type)
|
||||
|
||||
related_model = related_model_cls
|
||||
related_model_to_set = related_model_cls
|
||||
@@ -97,16 +94,14 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
|
||||
related_model_to_set_type = Instance(related_model_to_set_info, []) # type: ignore
|
||||
|
||||
# replace Any with referred_to_type
|
||||
return reparametrize_related_field_type(default_related_field_type,
|
||||
set_type=related_model_to_set_type,
|
||||
get_type=related_model_type)
|
||||
return reparametrize_related_field_type(
|
||||
default_related_field_type, set_type=related_model_to_set_type, get_type=related_model_type
|
||||
)
|
||||
|
||||
|
||||
def get_field_descriptor_types(field_info: TypeInfo, is_nullable: bool) -> Tuple[MypyType, MypyType]:
|
||||
set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
|
||||
is_nullable=is_nullable)
|
||||
get_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
||||
is_nullable=is_nullable)
|
||||
set_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_set_type", is_nullable=is_nullable)
|
||||
get_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_nullable)
|
||||
return set_type, get_type
|
||||
|
||||
|
||||
@@ -114,7 +109,7 @@ def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
|
||||
default_return_type = cast(Instance, ctx.default_return_type)
|
||||
|
||||
is_nullable = False
|
||||
null_expr = helpers.get_call_argument_by_name(ctx, 'null')
|
||||
null_expr = helpers.get_call_argument_by_name(ctx, "null")
|
||||
if null_expr is not None:
|
||||
is_nullable = helpers.parse_bool(null_expr) or False
|
||||
|
||||
@@ -125,7 +120,7 @@ def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
|
||||
def determine_type_of_array_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
|
||||
default_return_type = set_descriptor_types_for_field(ctx)
|
||||
|
||||
base_field_arg_type = helpers.get_call_argument_type_by_name(ctx, 'base_field')
|
||||
base_field_arg_type = helpers.get_call_argument_type_by_name(ctx, "base_field")
|
||||
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
|
||||
return default_return_type
|
||||
|
||||
@@ -142,8 +137,7 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan
|
||||
assert isinstance(default_return_type, Instance)
|
||||
|
||||
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
|
||||
if (outer_model_info is None
|
||||
or not helpers.is_model_subclass_info(outer_model_info, django_context)):
|
||||
if outer_model_info is None or not helpers.is_model_subclass_info(outer_model_info, django_context):
|
||||
return ctx.default_return_type
|
||||
|
||||
assert isinstance(outer_model_info, TypeInfo)
|
||||
|
||||
@@ -18,7 +18,7 @@ def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None:
|
||||
|
||||
|
||||
def get_specified_form_class(object_type: Instance) -> Optional[TypeType]:
|
||||
form_class_sym = object_type.type.get('form_class')
|
||||
form_class_sym = object_type.type.get("form_class")
|
||||
if form_class_sym and isinstance(form_class_sym.type, CallableType):
|
||||
return TypeType(form_class_sym.type.ret_type)
|
||||
return None
|
||||
@@ -28,7 +28,7 @@ def extract_proper_type_for_get_form(ctx: MethodContext) -> MypyType:
|
||||
object_type = ctx.type
|
||||
assert isinstance(object_type, Instance)
|
||||
|
||||
form_class_type = helpers.get_call_argument_type_by_name(ctx, 'form_class')
|
||||
form_class_type = helpers.get_call_argument_type_by_name(ctx, "form_class")
|
||||
if form_class_type is None or isinstance(form_class_type, NoneTyp):
|
||||
form_class_type = get_specified_form_class(object_type)
|
||||
|
||||
|
||||
@@ -9,13 +9,14 @@ from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import helpers
|
||||
|
||||
|
||||
def get_actual_types(ctx: Union[MethodContext, FunctionContext],
|
||||
expected_keys: List[str]) -> List[Tuple[str, MypyType]]:
|
||||
def get_actual_types(
|
||||
ctx: Union[MethodContext, FunctionContext], expected_keys: List[str]
|
||||
) -> List[Tuple[str, MypyType]]:
|
||||
actual_types = []
|
||||
# positionals
|
||||
for pos, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
|
||||
if actual_name is None:
|
||||
if ctx.callee_arg_names[0] == 'kwargs':
|
||||
if ctx.callee_arg_names[0] == "kwargs":
|
||||
# unpacked dict as kwargs is not supported
|
||||
continue
|
||||
actual_name = expected_keys[pos]
|
||||
@@ -30,23 +31,23 @@ def get_actual_types(ctx: Union[MethodContext, FunctionContext],
|
||||
return actual_types
|
||||
|
||||
|
||||
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
|
||||
model_cls: Type[Model], method: str) -> MypyType:
|
||||
def typecheck_model_method(
|
||||
ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext, model_cls: Type[Model], method: str
|
||||
) -> MypyType:
|
||||
typechecker_api = helpers.get_typechecker_api(ctx)
|
||||
expected_types = django_context.get_expected_types(typechecker_api, model_cls, method=method)
|
||||
expected_keys = [key for key in expected_types.keys() if key != 'pk']
|
||||
expected_keys = [key for key in expected_types.keys() if key != "pk"]
|
||||
|
||||
for actual_name, actual_type in get_actual_types(ctx, expected_keys):
|
||||
if actual_name not in expected_types:
|
||||
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
|
||||
model_cls.__name__),
|
||||
ctx.context)
|
||||
ctx.api.fail(f'Unexpected attribute "{actual_name}" for model "{model_cls.__name__}"', ctx.context)
|
||||
continue
|
||||
helpers.check_types_compatible(ctx,
|
||||
expected_type=expected_types[actual_name],
|
||||
actual_type=actual_type,
|
||||
error_message='Incompatible type for "{}" of "{}"'.format(actual_name,
|
||||
model_cls.__name__))
|
||||
helpers.check_types_compatible(
|
||||
ctx,
|
||||
expected_type=expected_types[actual_name],
|
||||
actual_type=actual_type,
|
||||
error_message=f'Incompatible type for "{actual_name}" of "{model_cls.__name__}"',
|
||||
)
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
@@ -59,7 +60,7 @@ def redefine_and_typecheck_model_init(ctx: FunctionContext, django_context: Djan
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
return typecheck_model_method(ctx, django_context, model_cls, '__init__')
|
||||
return typecheck_model_method(ctx, django_context, model_cls, "__init__")
|
||||
|
||||
|
||||
def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||
@@ -72,4 +73,4 @@ def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: Djan
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
return typecheck_model_method(ctx, django_context, model_cls, 'create')
|
||||
return typecheck_model_method(ctx, django_context, model_cls, "create")
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from mypy.nodes import (
|
||||
GDEF, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo,
|
||||
)
|
||||
from mypy.nodes import GDEF, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo
|
||||
from mypy.plugin import ClassDefContext, DynamicClassDefContext
|
||||
from mypy.types import AnyType, Instance, TypeOfAny
|
||||
|
||||
@@ -21,16 +19,15 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
|
||||
return
|
||||
|
||||
assert isinstance(base_manager_info, TypeInfo)
|
||||
new_manager_info = semanal_api.basic_new_typeinfo(ctx.name,
|
||||
basetype_or_fallback=Instance(base_manager_info,
|
||||
[AnyType(TypeOfAny.unannotated)]))
|
||||
new_manager_info = semanal_api.basic_new_typeinfo(
|
||||
ctx.name, basetype_or_fallback=Instance(base_manager_info, [AnyType(TypeOfAny.unannotated)])
|
||||
)
|
||||
new_manager_info.line = ctx.call.line
|
||||
new_manager_info.defn.line = ctx.call.line
|
||||
new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type()
|
||||
|
||||
current_module = semanal_api.cur_mod_node
|
||||
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info,
|
||||
plugin_generated=True)
|
||||
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
|
||||
passed_queryset = ctx.call.args[0]
|
||||
assert isinstance(passed_queryset, NameExpr)
|
||||
|
||||
@@ -55,15 +52,14 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
|
||||
assert isinstance(expr, StrExpr)
|
||||
custom_manager_generated_name = expr.value
|
||||
else:
|
||||
custom_manager_generated_name = base_manager_info.name + 'From' + derived_queryset_info.name
|
||||
custom_manager_generated_name = base_manager_info.name + "From" + derived_queryset_info.name
|
||||
|
||||
custom_manager_generated_fullname = '.'.join(['django.db.models.manager', custom_manager_generated_name])
|
||||
if 'from_queryset_managers' not in base_manager_info.metadata:
|
||||
base_manager_info.metadata['from_queryset_managers'] = {}
|
||||
base_manager_info.metadata['from_queryset_managers'][custom_manager_generated_fullname] = new_manager_info.fullname
|
||||
custom_manager_generated_fullname = ".".join(["django.db.models.manager", custom_manager_generated_name])
|
||||
if "from_queryset_managers" not in base_manager_info.metadata:
|
||||
base_manager_info.metadata["from_queryset_managers"] = {}
|
||||
base_manager_info.metadata["from_queryset_managers"][custom_manager_generated_fullname] = new_manager_info.fullname
|
||||
|
||||
class_def_context = ClassDefContext(cls=new_manager_info.defn,
|
||||
reason=ctx.call, api=semanal_api)
|
||||
class_def_context = ClassDefContext(cls=new_manager_info.defn, reason=ctx.call, api=semanal_api)
|
||||
self_type = Instance(new_manager_info, [])
|
||||
# we need to copy all methods in MRO before django.db.models.query.QuerySet
|
||||
for class_mro_info in derived_queryset_info.mro:
|
||||
@@ -71,7 +67,6 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
|
||||
break
|
||||
for name, sym in class_mro_info.names.items():
|
||||
if isinstance(sym.node, FuncDef):
|
||||
helpers.copy_method_to_another_class(class_def_context,
|
||||
self_type,
|
||||
new_method_name=name,
|
||||
method_node=sym.node)
|
||||
helpers.copy_method_to_another_class(
|
||||
class_def_context, self_type, new_method_name=name, method_node=sym.node
|
||||
)
|
||||
|
||||
@@ -9,8 +9,7 @@ from mypy_django_plugin.lib import helpers
|
||||
|
||||
|
||||
def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType:
|
||||
field_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx),
|
||||
field_fullname)
|
||||
field_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx), field_fullname)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)])
|
||||
@@ -32,7 +31,7 @@ def return_proper_field_type_from_get_field(ctx: MethodContext, django_context:
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name')
|
||||
field_name_expr = helpers.get_call_argument_by_name(ctx, "field_name")
|
||||
if field_name_expr is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
@@ -3,9 +3,7 @@ from typing import Dict, List, Optional, Type, cast
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.fields import DateField, DateTimeField
|
||||
from django.db.models.fields.related import ForeignKey
|
||||
from django.db.models.fields.reverse_related import (
|
||||
ManyToManyRel, ManyToOneRel, OneToOneRel,
|
||||
)
|
||||
from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel
|
||||
from mypy.nodes import ARG_STAR2, Argument, Context, FuncDef, TypeInfo, Var
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.plugins import common
|
||||
@@ -35,7 +33,7 @@ class ModelClassInitializer:
|
||||
def lookup_typeinfo_or_incomplete_defn_error(self, fullname: str) -> TypeInfo:
|
||||
info = self.lookup_typeinfo(fullname)
|
||||
if info is None:
|
||||
raise helpers.IncompleteDefnException(f'No {fullname!r} found')
|
||||
raise helpers.IncompleteDefnException(f"No {fullname!r} found")
|
||||
return info
|
||||
|
||||
def lookup_class_typeinfo_or_incomplete_defn_error(self, klass: type) -> TypeInfo:
|
||||
@@ -48,20 +46,17 @@ class ModelClassInitializer:
|
||||
var = Var(name=name, type=typ)
|
||||
# var.info: type of the object variable is bound to
|
||||
var.info = self.model_classdef.info
|
||||
var._fullname = self.model_classdef.info.fullname + '.' + name
|
||||
var._fullname = self.model_classdef.info.fullname + "." + name
|
||||
var.is_initialized_in_class = True
|
||||
var.is_inferred = True
|
||||
return var
|
||||
|
||||
def add_new_node_to_model_class(self, name: str, typ: MypyType) -> None:
|
||||
helpers.add_new_sym_for_info(self.model_classdef.info,
|
||||
name=name,
|
||||
sym_type=typ)
|
||||
helpers.add_new_sym_for_info(self.model_classdef.info, name=name, sym_type=typ)
|
||||
|
||||
def add_new_class_for_current_module(self, name: str, bases: List[Instance]) -> TypeInfo:
|
||||
current_module = self.api.modules[self.model_classdef.info.module_name]
|
||||
new_class_info = helpers.add_new_class_for_module(current_module,
|
||||
name=name, bases=bases)
|
||||
new_class_info = helpers.add_new_class_for_module(current_module, name=name, bases=bases)
|
||||
return new_class_info
|
||||
|
||||
def run(self) -> None:
|
||||
@@ -103,8 +98,7 @@ class AddDefaultPrimaryKey(ModelClassInitializer):
|
||||
auto_field_info = self.lookup_typeinfo_or_incomplete_defn_error(auto_field_fullname)
|
||||
|
||||
set_type, get_type = fields.get_field_descriptor_types(auto_field_info, is_nullable=False)
|
||||
self.add_new_node_to_model_class(auto_field.attname, Instance(auto_field_info,
|
||||
[set_type, get_type]))
|
||||
self.add_new_node_to_model_class(auto_field.attname, Instance(auto_field_info, [set_type, get_type]))
|
||||
|
||||
|
||||
class AddRelatedModelsId(ModelClassInitializer):
|
||||
@@ -117,11 +111,11 @@ class AddRelatedModelsId(ModelClassInitializer):
|
||||
field_sym = self.ctx.cls.info.get(field.name)
|
||||
if field_sym is not None and field_sym.node is not None:
|
||||
error_context = field_sym.node
|
||||
self.api.fail(f'Cannot find model {field.related_model!r} '
|
||||
f'referenced in field {field.name!r} ',
|
||||
ctx=error_context)
|
||||
self.add_new_node_to_model_class(field.attname,
|
||||
AnyType(TypeOfAny.explicit))
|
||||
self.api.fail(
|
||||
f"Cannot find model {field.related_model!r} " f"referenced in field {field.name!r} ",
|
||||
ctx=error_context,
|
||||
)
|
||||
self.add_new_node_to_model_class(field.attname, AnyType(TypeOfAny.explicit))
|
||||
continue
|
||||
|
||||
if related_model_cls._meta.abstract:
|
||||
@@ -138,8 +132,7 @@ class AddRelatedModelsId(ModelClassInitializer):
|
||||
|
||||
is_nullable = self.django_context.get_field_nullability(field, None)
|
||||
set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
|
||||
self.add_new_node_to_model_class(field.attname,
|
||||
Instance(field_info, [set_type, get_type]))
|
||||
self.add_new_node_to_model_class(field.attname, Instance(field_info, [set_type, get_type]))
|
||||
|
||||
|
||||
class AddManagers(ModelClassInitializer):
|
||||
@@ -154,10 +147,9 @@ class AddManagers(ModelClassInitializer):
|
||||
|
||||
def get_generated_manager_mappings(self, base_manager_fullname: str) -> Dict[str, str]:
|
||||
base_manager_info = self.lookup_typeinfo(base_manager_fullname)
|
||||
if (base_manager_info is None
|
||||
or 'from_queryset_managers' not in base_manager_info.metadata):
|
||||
if base_manager_info is None or "from_queryset_managers" not in base_manager_info.metadata:
|
||||
return {}
|
||||
return base_manager_info.metadata['from_queryset_managers']
|
||||
return base_manager_info.metadata["from_queryset_managers"]
|
||||
|
||||
def create_new_model_parametrized_manager(self, name: str, base_manager_info: TypeInfo) -> Instance:
|
||||
bases = []
|
||||
@@ -166,31 +158,27 @@ class AddManagers(ModelClassInitializer):
|
||||
if original_base.type is None:
|
||||
raise helpers.IncompleteDefnException()
|
||||
|
||||
original_base = helpers.reparametrize_instance(original_base,
|
||||
[Instance(self.model_classdef.info, [])])
|
||||
original_base = helpers.reparametrize_instance(original_base, [Instance(self.model_classdef.info, [])])
|
||||
bases.append(original_base)
|
||||
|
||||
new_manager_info = self.add_new_class_for_current_module(name, bases)
|
||||
# copy fields to a new manager
|
||||
new_cls_def_context = ClassDefContext(cls=new_manager_info.defn,
|
||||
reason=self.ctx.reason,
|
||||
api=self.api)
|
||||
new_cls_def_context = ClassDefContext(cls=new_manager_info.defn, reason=self.ctx.reason, api=self.api)
|
||||
custom_manager_type = Instance(new_manager_info, [Instance(self.model_classdef.info, [])])
|
||||
|
||||
for name, sym in base_manager_info.names.items():
|
||||
# replace self type with new class, if copying method
|
||||
if isinstance(sym.node, FuncDef):
|
||||
helpers.copy_method_to_another_class(new_cls_def_context,
|
||||
self_type=custom_manager_type,
|
||||
new_method_name=name,
|
||||
method_node=sym.node)
|
||||
helpers.copy_method_to_another_class(
|
||||
new_cls_def_context, self_type=custom_manager_type, new_method_name=name, method_node=sym.node
|
||||
)
|
||||
continue
|
||||
|
||||
new_sym = sym.copy()
|
||||
if isinstance(new_sym.node, Var):
|
||||
new_var = Var(name, type=sym.type)
|
||||
new_var.info = new_manager_info
|
||||
new_var._fullname = new_manager_info.fullname + '.' + name
|
||||
new_var._fullname = new_manager_info.fullname + "." + name
|
||||
new_sym.node = new_var
|
||||
new_manager_info.names[name] = new_sym
|
||||
|
||||
@@ -215,7 +203,7 @@ class AddManagers(ModelClassInitializer):
|
||||
manager_info = self.lookup_typeinfo(real_manager_fullname) # type: ignore
|
||||
if manager_info is None:
|
||||
continue
|
||||
manager_class_name = real_manager_fullname.rsplit('.', maxsplit=1)[1]
|
||||
manager_class_name = real_manager_fullname.rsplit(".", maxsplit=1)[1]
|
||||
|
||||
if manager_name not in self.model_classdef.info.names:
|
||||
manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])])
|
||||
@@ -225,10 +213,11 @@ class AddManagers(ModelClassInitializer):
|
||||
if not self.has_any_parametrized_manager_as_base(manager_info):
|
||||
continue
|
||||
|
||||
custom_model_manager_name = manager.model.__name__ + '_' + manager_class_name
|
||||
custom_model_manager_name = manager.model.__name__ + "_" + manager_class_name
|
||||
try:
|
||||
custom_manager_type = self.create_new_model_parametrized_manager(custom_model_manager_name,
|
||||
base_manager_info=manager_info)
|
||||
custom_manager_type = self.create_new_model_parametrized_manager(
|
||||
custom_model_manager_name, base_manager_info=manager_info
|
||||
)
|
||||
except helpers.IncompleteDefnException:
|
||||
continue
|
||||
|
||||
@@ -238,11 +227,11 @@ class AddManagers(ModelClassInitializer):
|
||||
class AddDefaultManagerAttribute(ModelClassInitializer):
|
||||
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
||||
# add _default_manager
|
||||
if '_default_manager' not in self.model_classdef.info.names:
|
||||
if "_default_manager" not in self.model_classdef.info.names:
|
||||
default_manager_fullname = helpers.get_class_fullname(model_cls._meta.default_manager.__class__)
|
||||
default_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(default_manager_fullname)
|
||||
default_manager = Instance(default_manager_info, [Instance(self.model_classdef.info, [])])
|
||||
self.add_new_node_to_model_class('_default_manager', default_manager)
|
||||
self.add_new_node_to_model_class("_default_manager", default_manager)
|
||||
|
||||
|
||||
class AddRelatedManagers(ModelClassInitializer):
|
||||
@@ -272,8 +261,10 @@ class AddRelatedManagers(ModelClassInitializer):
|
||||
|
||||
if isinstance(relation, (ManyToOneRel, ManyToManyRel)):
|
||||
try:
|
||||
related_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.RELATED_MANAGER_CLASS) # noqa: E501
|
||||
if 'objects' not in related_model_info.names:
|
||||
related_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(
|
||||
fullnames.RELATED_MANAGER_CLASS
|
||||
) # noqa: E501
|
||||
if "objects" not in related_model_info.names:
|
||||
raise helpers.IncompleteDefnException()
|
||||
except helpers.IncompleteDefnException as exc:
|
||||
if not self.api.final_iteration:
|
||||
@@ -282,16 +273,17 @@ class AddRelatedManagers(ModelClassInitializer):
|
||||
continue
|
||||
|
||||
# create new RelatedManager subclass
|
||||
parametrized_related_manager_type = Instance(related_manager_info,
|
||||
[Instance(related_model_info, [])])
|
||||
default_manager_type = related_model_info.names['objects'].type
|
||||
if (default_manager_type is None
|
||||
or not isinstance(default_manager_type, Instance)
|
||||
or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME):
|
||||
parametrized_related_manager_type = Instance(related_manager_info, [Instance(related_model_info, [])])
|
||||
default_manager_type = related_model_info.names["objects"].type
|
||||
if (
|
||||
default_manager_type is None
|
||||
or not isinstance(default_manager_type, Instance)
|
||||
or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME
|
||||
):
|
||||
self.add_new_node_to_model_class(attname, parametrized_related_manager_type)
|
||||
continue
|
||||
|
||||
name = related_model_cls.__name__ + '_' + 'RelatedManager'
|
||||
name = related_model_cls.__name__ + "_" + "RelatedManager"
|
||||
bases = [parametrized_related_manager_type, default_manager_type]
|
||||
new_related_manager_info = self.add_new_class_for_current_module(name, bases)
|
||||
|
||||
@@ -303,45 +295,50 @@ class AddExtraFieldMethods(ModelClassInitializer):
|
||||
# get_FOO_display for choices
|
||||
for field in self.django_context.get_model_fields(model_cls):
|
||||
if field.choices:
|
||||
info = self.lookup_typeinfo_or_incomplete_defn_error('builtins.str')
|
||||
info = self.lookup_typeinfo_or_incomplete_defn_error("builtins.str")
|
||||
return_type = Instance(info, [])
|
||||
common.add_method(self.ctx,
|
||||
name='get_{}_display'.format(field.attname),
|
||||
args=[],
|
||||
return_type=return_type)
|
||||
common.add_method(self.ctx, name=f"get_{field.attname}_display", args=[], return_type=return_type)
|
||||
|
||||
# get_next_by, get_previous_by for Date, DateTime
|
||||
for field in self.django_context.get_model_fields(model_cls):
|
||||
if isinstance(field, (DateField, DateTimeField)) and not field.null:
|
||||
return_type = Instance(self.model_classdef.info, [])
|
||||
common.add_method(self.ctx,
|
||||
name='get_next_by_{}'.format(field.attname),
|
||||
args=[Argument(Var('kwargs', AnyType(TypeOfAny.explicit)),
|
||||
AnyType(TypeOfAny.explicit),
|
||||
initializer=None,
|
||||
kind=ARG_STAR2)],
|
||||
return_type=return_type)
|
||||
common.add_method(self.ctx,
|
||||
name='get_previous_by_{}'.format(field.attname),
|
||||
args=[Argument(Var('kwargs', AnyType(TypeOfAny.explicit)),
|
||||
AnyType(TypeOfAny.explicit),
|
||||
initializer=None,
|
||||
kind=ARG_STAR2)],
|
||||
return_type=return_type)
|
||||
common.add_method(
|
||||
self.ctx,
|
||||
name=f"get_next_by_{field.attname}",
|
||||
args=[
|
||||
Argument(
|
||||
Var("kwargs", AnyType(TypeOfAny.explicit)),
|
||||
AnyType(TypeOfAny.explicit),
|
||||
initializer=None,
|
||||
kind=ARG_STAR2,
|
||||
)
|
||||
],
|
||||
return_type=return_type,
|
||||
)
|
||||
common.add_method(
|
||||
self.ctx,
|
||||
name=f"get_previous_by_{field.attname}",
|
||||
args=[
|
||||
Argument(
|
||||
Var("kwargs", AnyType(TypeOfAny.explicit)),
|
||||
AnyType(TypeOfAny.explicit),
|
||||
initializer=None,
|
||||
kind=ARG_STAR2,
|
||||
)
|
||||
],
|
||||
return_type=return_type,
|
||||
)
|
||||
|
||||
|
||||
class AddMetaOptionsAttribute(ModelClassInitializer):
|
||||
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
||||
if '_meta' not in self.model_classdef.info.names:
|
||||
if "_meta" not in self.model_classdef.info.names:
|
||||
options_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.OPTIONS_CLASS_FULLNAME)
|
||||
self.add_new_node_to_model_class('_meta',
|
||||
Instance(options_info, [
|
||||
Instance(self.model_classdef.info, [])
|
||||
]))
|
||||
self.add_new_node_to_model_class("_meta", Instance(options_info, [Instance(self.model_classdef.info, [])]))
|
||||
|
||||
|
||||
def process_model_class(ctx: ClassDefContext,
|
||||
django_context: DjangoContext) -> None:
|
||||
def process_model_class(ctx: ClassDefContext, django_context: DjangoContext) -> None:
|
||||
initializers = [
|
||||
InjectAnyAsBaseForNestedMeta,
|
||||
AddDefaultPrimaryKey,
|
||||
|
||||
@@ -24,21 +24,24 @@ def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext)
|
||||
for lookup_kwarg, provided_type in zip(lookup_kwargs, provided_lookup_types):
|
||||
if lookup_kwarg is None:
|
||||
continue
|
||||
if (isinstance(provided_type, Instance)
|
||||
and provided_type.type.has_base('django.db.models.expressions.Combinable')):
|
||||
if isinstance(provided_type, Instance) and provided_type.type.has_base(
|
||||
"django.db.models.expressions.Combinable"
|
||||
):
|
||||
provided_type = resolve_combinable_type(provided_type, django_context)
|
||||
|
||||
lookup_type = django_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg)
|
||||
# Managers as provided_type is not supported yet
|
||||
if (isinstance(provided_type, Instance)
|
||||
and helpers.has_any_of_bases(provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME,
|
||||
fullnames.QUERYSET_CLASS_FULLNAME))):
|
||||
if isinstance(provided_type, Instance) and helpers.has_any_of_bases(
|
||||
provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME, fullnames.QUERYSET_CLASS_FULLNAME)
|
||||
):
|
||||
return ctx.default_return_type
|
||||
|
||||
helpers.check_types_compatible(ctx,
|
||||
expected_type=lookup_type,
|
||||
actual_type=provided_type,
|
||||
error_message=f'Incompatible type for lookup {lookup_kwarg!r}:')
|
||||
helpers.check_types_compatible(
|
||||
ctx,
|
||||
expected_type=lookup_type,
|
||||
actual_type=provided_type,
|
||||
error_message=f"Incompatible type for lookup {lookup_kwarg!r}:",
|
||||
)
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
@@ -11,17 +11,17 @@ from mypy.types import AnyType, Instance
|
||||
from mypy.types import Type as MypyType
|
||||
from mypy.types import TypeOfAny
|
||||
|
||||
from mypy_django_plugin.django.context import (
|
||||
DjangoContext, LookupsAreUnsupported,
|
||||
)
|
||||
from mypy_django_plugin.django.context import DjangoContext, LookupsAreUnsupported
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
|
||||
|
||||
def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]:
|
||||
for base_type in [queryset_type, *queryset_type.type.bases]:
|
||||
if (len(base_type.args)
|
||||
and isinstance(base_type.args[0], Instance)
|
||||
and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME)):
|
||||
if (
|
||||
len(base_type.args)
|
||||
and isinstance(base_type.args[0], Instance)
|
||||
and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME)
|
||||
):
|
||||
return base_type.args[0]
|
||||
return None
|
||||
|
||||
@@ -31,15 +31,15 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
|
||||
assert isinstance(default_return_type, Instance)
|
||||
|
||||
outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
|
||||
if (outer_model_info is None
|
||||
or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)):
|
||||
if outer_model_info is None or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
|
||||
return default_return_type
|
||||
|
||||
return helpers.reparametrize_instance(default_return_type, [Instance(outer_model_info, [])])
|
||||
|
||||
|
||||
def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
|
||||
*, method: str, lookup: str) -> Optional[MypyType]:
|
||||
def get_field_type_from_lookup(
|
||||
ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], *, method: str, lookup: str
|
||||
) -> Optional[MypyType]:
|
||||
try:
|
||||
lookup_field = django_context.resolve_lookup_into_field(model_cls, lookup)
|
||||
except FieldError as exc:
|
||||
@@ -48,20 +48,21 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext
|
||||
except LookupsAreUnsupported:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
if ((isinstance(lookup_field, RelatedField) and lookup_field.column == lookup)
|
||||
or isinstance(lookup_field, ForeignObjectRel)):
|
||||
if (isinstance(lookup_field, RelatedField) and lookup_field.column == lookup) or isinstance(
|
||||
lookup_field, ForeignObjectRel
|
||||
):
|
||||
related_model_cls = django_context.get_field_related_model_cls(lookup_field)
|
||||
if related_model_cls is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
lookup_field = django_context.get_primary_key_field(related_model_cls)
|
||||
|
||||
field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx),
|
||||
lookup_field, method=method)
|
||||
field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx), lookup_field, method=method)
|
||||
return field_get_type
|
||||
|
||||
|
||||
def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
|
||||
flat: bool, named: bool) -> MypyType:
|
||||
def get_values_list_row_type(
|
||||
ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], flat: bool, named: bool
|
||||
) -> MypyType:
|
||||
field_lookups = resolve_field_lookups(ctx.args[0], django_context)
|
||||
if field_lookups is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
@@ -70,17 +71,17 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
|
||||
if len(field_lookups) == 0:
|
||||
if flat:
|
||||
primary_key_field = django_context.get_primary_key_field(model_cls)
|
||||
lookup_type = get_field_type_from_lookup(ctx, django_context, model_cls,
|
||||
lookup=primary_key_field.attname, method='values_list')
|
||||
lookup_type = get_field_type_from_lookup(
|
||||
ctx, django_context, model_cls, lookup=primary_key_field.attname, method="values_list"
|
||||
)
|
||||
assert lookup_type is not None
|
||||
return lookup_type
|
||||
elif named:
|
||||
column_types: 'OrderedDict[str, MypyType]' = OrderedDict()
|
||||
column_types: "OrderedDict[str, MypyType]" = OrderedDict()
|
||||
for field in django_context.get_model_fields(model_cls):
|
||||
column_type = django_context.get_field_get_type(typechecker_api, field,
|
||||
method='values_list')
|
||||
column_type = django_context.get_field_get_type(typechecker_api, field, method="values_list")
|
||||
column_types[field.attname] = column_type
|
||||
return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
|
||||
return helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types)
|
||||
else:
|
||||
# flat=False, named=False, all fields
|
||||
field_lookups = []
|
||||
@@ -93,8 +94,9 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
|
||||
|
||||
column_types = OrderedDict()
|
||||
for field_lookup in field_lookups:
|
||||
lookup_field_type = get_field_type_from_lookup(ctx, django_context, model_cls,
|
||||
lookup=field_lookup, method='values_list')
|
||||
lookup_field_type = get_field_type_from_lookup(
|
||||
ctx, django_context, model_cls, lookup=field_lookup, method="values_list"
|
||||
)
|
||||
if lookup_field_type is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
column_types[field_lookup] = lookup_field_type
|
||||
@@ -103,7 +105,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
|
||||
assert len(column_types) == 1
|
||||
row_type = next(iter(column_types.values()))
|
||||
elif named:
|
||||
row_type = helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
|
||||
row_type = helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types)
|
||||
else:
|
||||
row_type = helpers.make_tuple(typechecker_api, list(column_types.values()))
|
||||
|
||||
@@ -123,13 +125,13 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
flat_expr = helpers.get_call_argument_by_name(ctx, 'flat')
|
||||
flat_expr = helpers.get_call_argument_by_name(ctx, "flat")
|
||||
if flat_expr is not None and isinstance(flat_expr, NameExpr):
|
||||
flat = helpers.parse_bool(flat_expr)
|
||||
else:
|
||||
flat = False
|
||||
|
||||
named_expr = helpers.get_call_argument_by_name(ctx, 'named')
|
||||
named_expr = helpers.get_call_argument_by_name(ctx, "named")
|
||||
if named_expr is not None and isinstance(named_expr, NameExpr):
|
||||
named = helpers.parse_bool(named_expr)
|
||||
else:
|
||||
@@ -143,8 +145,7 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
|
||||
flat = flat or False
|
||||
named = named or False
|
||||
|
||||
row_type = get_values_list_row_type(ctx, django_context, model_cls,
|
||||
flat=flat, named=named)
|
||||
row_type = get_values_list_row_type(ctx, django_context, model_cls, flat=flat, named=named)
|
||||
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
|
||||
|
||||
|
||||
@@ -179,10 +180,11 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
|
||||
for field in django_context.get_model_fields(model_cls):
|
||||
field_lookups.append(field.attname)
|
||||
|
||||
column_types: 'OrderedDict[str, MypyType]' = OrderedDict()
|
||||
column_types: "OrderedDict[str, MypyType]" = OrderedDict()
|
||||
for field_lookup in field_lookups:
|
||||
field_lookup_type = get_field_type_from_lookup(ctx, django_context, model_cls,
|
||||
lookup=field_lookup, method='values')
|
||||
field_lookup_type = get_field_type_from_lookup(
|
||||
ctx, django_context, model_cls, lookup=field_lookup, method="values"
|
||||
)
|
||||
if field_lookup_type is None:
|
||||
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
|
||||
|
||||
|
||||
@@ -13,8 +13,7 @@ def get_user_model_hook(ctx: FunctionContext, django_context: DjangoContext) ->
|
||||
model_cls = django_context.apps_registry.get_model(auth_user_model)
|
||||
model_cls_fullname = helpers.get_class_fullname(model_cls)
|
||||
|
||||
model_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx),
|
||||
model_cls_fullname)
|
||||
model_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx), model_cls_fullname)
|
||||
if model_info is None:
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
|
||||
@@ -32,7 +31,7 @@ def get_type_of_settings_attribute(ctx: AttributeContext, django_context: Django
|
||||
|
||||
# first look for the setting in the project settings file, then global settings
|
||||
settings_module = typechecker_api.modules.get(django_context.django_settings_module)
|
||||
global_settings_module = typechecker_api.modules.get('django.conf.global_settings')
|
||||
global_settings_module = typechecker_api.modules.get("django.conf.global_settings")
|
||||
for module in [settings_module, global_settings_module]:
|
||||
if module is not None:
|
||||
sym = module.names.get(setting_name)
|
||||
|
||||
Reference in New Issue
Block a user