mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-10 14:01:56 +08:00
Compare commits
1 Commits
v1.2.0
...
readme-fix
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
584d4e8911 |
@@ -52,6 +52,12 @@ django_settings_module = mysettings
|
||||
|
||||
Where `mysettings` is a value of `DJANGO_SETTINGS_MODULE` (with or without quotes)
|
||||
|
||||
You might also need to explicitly tweak your `PYTHONPATH` the very same way `django` does it internally in case you have troubles with mypy / django plugin not finding your settings module. Try adding the root path of your project to your `PYTHONPATH` environment variable like so:
|
||||
|
||||
```bash
|
||||
PYTHONPATH=${PYTHONPATH}:${PWD}
|
||||
```
|
||||
|
||||
Current implementation uses Django runtime to extract models information, so it will crash, if your installed apps `models.py` is not correct. For this same reason, you cannot use `reveal_type` inside global scope of any Python file that will be executed for `django.setup()`.
|
||||
|
||||
In other words, if your `manage.py runserver` crashes, mypy will crash too.
|
||||
|
||||
@@ -63,7 +63,7 @@ class BaseCommand:
|
||||
fail_level: int = ...,
|
||||
) -> None: ...
|
||||
def check_migrations(self) -> None: ...
|
||||
def handle(self, *args: Any, **options: Any) -> Optional[str]: ...
|
||||
def handle(self, *args: Any, **options: Any) -> None: ...
|
||||
|
||||
class AppCommand(BaseCommand):
|
||||
missing_args_message: str = ...
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from typing import (
|
||||
@@ -49,9 +48,6 @@ def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']:
|
||||
with temp_environ():
|
||||
os.environ['DJANGO_SETTINGS_MODULE'] = settings_module
|
||||
|
||||
# add current directory to sys.path
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
def noop_class_getitem(cls, key):
|
||||
return cls
|
||||
|
||||
@@ -77,12 +73,176 @@ def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']:
|
||||
return apps, settings
|
||||
|
||||
|
||||
class DjangoFieldsContext:
|
||||
def __init__(self, django_context: 'DjangoContext') -> None:
|
||||
self.django_context = django_context
|
||||
|
||||
def get_attname(self, field: Field) -> str:
|
||||
attname = field.attname
|
||||
return attname
|
||||
|
||||
def get_field_nullability(self, field: Field, method: Optional[str]) -> bool:
|
||||
nullable = field.null
|
||||
if not nullable and isinstance(field, CharField) and field.blank:
|
||||
return True
|
||||
if method == '__init__':
|
||||
if field.primary_key or isinstance(field, ForeignKey):
|
||||
return True
|
||||
if method == 'create':
|
||||
if isinstance(field, AutoField):
|
||||
return True
|
||||
if field.has_default():
|
||||
return True
|
||||
return nullable
|
||||
|
||||
def get_field_set_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType:
|
||||
""" Get a type of __set__ for this specific Django field. """
|
||||
target_field = field
|
||||
if isinstance(field, ForeignKey):
|
||||
target_field = field.target_field
|
||||
|
||||
field_info = helpers.lookup_class_typeinfo(api, target_field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
|
||||
is_nullable=self.get_field_nullability(field, method))
|
||||
if isinstance(target_field, ArrayField):
|
||||
argument_field_type = self.get_field_set_type(api, target_field.base_field, method=method)
|
||||
field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type)
|
||||
return field_set_type
|
||||
|
||||
def get_field_get_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType:
|
||||
""" Get a type of __get__ for this specific Django field. """
|
||||
field_info = helpers.lookup_class_typeinfo(api, field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
|
||||
is_nullable = self.get_field_nullability(field, method)
|
||||
if isinstance(field, RelatedField):
|
||||
related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
|
||||
|
||||
if method == 'values':
|
||||
primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
|
||||
return self.get_field_get_type(api, primary_key_field, method=method)
|
||||
|
||||
model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
|
||||
if model_info is None:
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
|
||||
return Instance(model_info, [])
|
||||
else:
|
||||
return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
||||
is_nullable=is_nullable)
|
||||
|
||||
def get_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) -> Type[Model]:
|
||||
if isinstance(field, RelatedField):
|
||||
related_model_cls = field.remote_field.model
|
||||
else:
|
||||
related_model_cls = field.field.model
|
||||
|
||||
if isinstance(related_model_cls, str):
|
||||
if related_model_cls == 'self':
|
||||
# same model
|
||||
related_model_cls = field.model
|
||||
elif '.' not in related_model_cls:
|
||||
# same file model
|
||||
related_model_fullname = field.model.__module__ + '.' + related_model_cls
|
||||
related_model_cls = self.django_context.get_model_class_by_fullname(related_model_fullname)
|
||||
else:
|
||||
related_model_cls = self.django_context.apps_registry.get_model(related_model_cls)
|
||||
|
||||
return related_model_cls
|
||||
|
||||
|
||||
class LookupsAreUnsupported(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DjangoLookupsContext:
|
||||
def __init__(self, django_context: 'DjangoContext'):
|
||||
self.django_context = django_context
|
||||
|
||||
def _resolve_field_from_parts(self, field_parts: Iterable[str], model_cls: Type[Model]) -> Field:
|
||||
currently_observed_model = model_cls
|
||||
field = None
|
||||
for field_part in field_parts:
|
||||
if field_part == 'pk':
|
||||
field = self.django_context.get_primary_key_field(currently_observed_model)
|
||||
continue
|
||||
|
||||
field = currently_observed_model._meta.get_field(field_part)
|
||||
if isinstance(field, RelatedField):
|
||||
currently_observed_model = field.related_model
|
||||
model_name = currently_observed_model._meta.model_name
|
||||
if (model_name is not None
|
||||
and field_part == (model_name + '_id')):
|
||||
field = self.django_context.get_primary_key_field(currently_observed_model)
|
||||
|
||||
if isinstance(field, ForeignObjectRel):
|
||||
currently_observed_model = field.related_model
|
||||
|
||||
assert field is not None
|
||||
return field
|
||||
|
||||
def resolve_lookup_info_field(self, model_cls: Type[Model], lookup: str) -> Field:
|
||||
query = Query(model_cls)
|
||||
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
|
||||
if lookup_parts:
|
||||
raise LookupsAreUnsupported()
|
||||
|
||||
return self._resolve_field_from_parts(field_parts, model_cls)
|
||||
|
||||
def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType:
|
||||
query = Query(model_cls)
|
||||
try:
|
||||
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
|
||||
if is_expression:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
except FieldError as exc:
|
||||
ctx.api.fail(exc.args[0], ctx.context)
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
field = self._resolve_field_from_parts(field_parts, model_cls)
|
||||
|
||||
lookup_cls = None
|
||||
if lookup_parts:
|
||||
lookup = lookup_parts[-1]
|
||||
lookup_cls = field.get_lookup(lookup)
|
||||
if lookup_cls is None:
|
||||
# unknown lookup
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
if lookup_cls is None or isinstance(lookup_cls, Exact):
|
||||
return self.django_context.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field)
|
||||
|
||||
assert lookup_cls is not None
|
||||
|
||||
lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls)
|
||||
if lookup_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
for lookup_base in helpers.iter_bases(lookup_info):
|
||||
if lookup_base.args and isinstance(lookup_base.args[0], Instance):
|
||||
lookup_type: MypyType = lookup_base.args[0]
|
||||
# if it's Field, consider lookup_type a __get__ of current field
|
||||
if (isinstance(lookup_type, Instance)
|
||||
and lookup_type.type.fullname() == fullnames.FIELD_FULLNAME):
|
||||
field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
||||
is_nullable=field.null)
|
||||
return lookup_type
|
||||
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
|
||||
class DjangoContext:
|
||||
def __init__(self, django_settings_module: str) -> None:
|
||||
self.fields_context = DjangoFieldsContext(self)
|
||||
self.lookups_context = DjangoLookupsContext(self)
|
||||
|
||||
self.django_settings_module = django_settings_module
|
||||
|
||||
apps, settings = initialize_django(self.django_settings_module)
|
||||
@@ -128,7 +288,7 @@ class DjangoContext:
|
||||
if isinstance(field, (RelatedField, ForeignObjectRel)):
|
||||
related_model_cls = field.related_model
|
||||
primary_key_field = self.get_primary_key_field(related_model_cls)
|
||||
primary_key_type = self.get_field_get_type(api, primary_key_field, method='init')
|
||||
primary_key_type = self.fields_context.get_field_get_type(api, primary_key_field, method='init')
|
||||
|
||||
rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
|
||||
if rel_model_info is None:
|
||||
@@ -159,13 +319,13 @@ class DjangoContext:
|
||||
# add pk if not abstract=True
|
||||
if not model_cls._meta.abstract:
|
||||
primary_key_field = self.get_primary_key_field(model_cls)
|
||||
field_set_type = self.get_field_set_type(api, primary_key_field, method=method)
|
||||
field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method)
|
||||
expected_types['pk'] = field_set_type
|
||||
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, Field):
|
||||
field_name = field.attname
|
||||
field_set_type = self.get_field_set_type(api, field, method=method)
|
||||
field_set_type = self.fields_context.get_field_set_type(api, field, method=method)
|
||||
expected_types[field_name] = field_set_type
|
||||
|
||||
if isinstance(field, ForeignKey):
|
||||
@@ -176,11 +336,7 @@ class DjangoContext:
|
||||
expected_types[field_name] = AnyType(TypeOfAny.unannotated)
|
||||
continue
|
||||
|
||||
related_model = self.get_field_related_model_cls(field)
|
||||
if related_model is None:
|
||||
expected_types[field_name] = AnyType(TypeOfAny.from_error)
|
||||
continue
|
||||
|
||||
related_model = self.fields_context.get_related_model_cls(field)
|
||||
if related_model._meta.proxy_for_model is not None:
|
||||
related_model = related_model._meta.proxy_for_model
|
||||
|
||||
@@ -189,7 +345,7 @@ class DjangoContext:
|
||||
expected_types[field_name] = AnyType(TypeOfAny.unannotated)
|
||||
continue
|
||||
|
||||
is_nullable = self.get_field_nullability(field, method)
|
||||
is_nullable = self.fields_context.get_field_nullability(field, method)
|
||||
foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info,
|
||||
'_pyi_private_set_type',
|
||||
is_nullable=is_nullable)
|
||||
@@ -209,173 +365,13 @@ class DjangoContext:
|
||||
return expected_types
|
||||
|
||||
@cached_property
|
||||
def all_registered_model_classes(self) -> Set[Type[models.Model]]:
|
||||
def model_base_classes(self) -> Set[str]:
|
||||
model_classes = self.apps_registry.get_models()
|
||||
|
||||
all_model_bases = set()
|
||||
for model_cls in model_classes:
|
||||
for base_cls in model_cls.mro():
|
||||
if issubclass(base_cls, models.Model):
|
||||
all_model_bases.add(base_cls)
|
||||
all_model_bases.add(helpers.get_class_fullname(base_cls))
|
||||
|
||||
return all_model_bases
|
||||
|
||||
@cached_property
|
||||
def all_registered_model_class_fullnames(self) -> Set[str]:
|
||||
return {helpers.get_class_fullname(cls) for cls in self.all_registered_model_classes}
|
||||
|
||||
def get_attname(self, field: Field) -> str:
|
||||
attname = field.attname
|
||||
return attname
|
||||
|
||||
def get_field_nullability(self, field: Field, method: Optional[str]) -> bool:
|
||||
nullable = field.null
|
||||
if not nullable and isinstance(field, CharField) and field.blank:
|
||||
return True
|
||||
if method == '__init__':
|
||||
if field.primary_key or isinstance(field, ForeignKey):
|
||||
return True
|
||||
if method == 'create':
|
||||
if isinstance(field, AutoField):
|
||||
return True
|
||||
if field.has_default():
|
||||
return True
|
||||
return nullable
|
||||
|
||||
def get_field_set_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType:
|
||||
""" Get a type of __set__ for this specific Django field. """
|
||||
target_field = field
|
||||
if isinstance(field, ForeignKey):
|
||||
target_field = field.target_field
|
||||
|
||||
field_info = helpers.lookup_class_typeinfo(api, target_field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
|
||||
is_nullable=self.get_field_nullability(field, method))
|
||||
if isinstance(target_field, ArrayField):
|
||||
argument_field_type = self.get_field_set_type(api, target_field.base_field, method=method)
|
||||
field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type)
|
||||
return field_set_type
|
||||
|
||||
def get_field_get_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType:
|
||||
""" Get a type of __get__ for this specific Django field. """
|
||||
field_info = helpers.lookup_class_typeinfo(api, field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
|
||||
is_nullable = self.get_field_nullability(field, method)
|
||||
if isinstance(field, RelatedField):
|
||||
related_model_cls = self.get_field_related_model_cls(field)
|
||||
if related_model_cls is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
if method == 'values':
|
||||
primary_key_field = self.get_primary_key_field(related_model_cls)
|
||||
return self.get_field_get_type(api, primary_key_field, method=method)
|
||||
|
||||
model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
|
||||
if model_info is None:
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
|
||||
return Instance(model_info, [])
|
||||
else:
|
||||
return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
||||
is_nullable=is_nullable)
|
||||
|
||||
def get_field_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) -> Optional[Type[Model]]:
|
||||
if isinstance(field, RelatedField):
|
||||
related_model_cls = field.remote_field.model
|
||||
else:
|
||||
related_model_cls = field.field.model
|
||||
|
||||
if isinstance(related_model_cls, str):
|
||||
if related_model_cls == 'self':
|
||||
# same model
|
||||
related_model_cls = field.model
|
||||
elif '.' not in related_model_cls:
|
||||
# same file model
|
||||
related_model_fullname = field.model.__module__ + '.' + related_model_cls
|
||||
related_model_cls = self.get_model_class_by_fullname(related_model_fullname)
|
||||
else:
|
||||
related_model_cls = self.apps_registry.get_model(related_model_cls)
|
||||
|
||||
return related_model_cls
|
||||
|
||||
def _resolve_field_from_parts(self, field_parts: Iterable[str], model_cls: Type[Model]) -> Field:
|
||||
currently_observed_model = model_cls
|
||||
field = None
|
||||
for field_part in field_parts:
|
||||
if field_part == 'pk':
|
||||
field = self.get_primary_key_field(currently_observed_model)
|
||||
continue
|
||||
|
||||
field = currently_observed_model._meta.get_field(field_part)
|
||||
if isinstance(field, RelatedField):
|
||||
currently_observed_model = field.related_model
|
||||
model_name = currently_observed_model._meta.model_name
|
||||
if (model_name is not None
|
||||
and field_part == (model_name + '_id')):
|
||||
field = self.get_primary_key_field(currently_observed_model)
|
||||
|
||||
if isinstance(field, ForeignObjectRel):
|
||||
currently_observed_model = field.related_model
|
||||
|
||||
assert field is not None
|
||||
return field
|
||||
|
||||
def resolve_lookup_into_field(self, model_cls: Type[Model], lookup: str) -> Field:
|
||||
query = Query(model_cls)
|
||||
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
|
||||
if lookup_parts:
|
||||
raise LookupsAreUnsupported()
|
||||
|
||||
return self._resolve_field_from_parts(field_parts, model_cls)
|
||||
|
||||
def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType:
|
||||
query = Query(model_cls)
|
||||
try:
|
||||
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
|
||||
if is_expression:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
except FieldError as exc:
|
||||
ctx.api.fail(exc.args[0], ctx.context)
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
field = self._resolve_field_from_parts(field_parts, model_cls)
|
||||
|
||||
lookup_cls = None
|
||||
if lookup_parts:
|
||||
lookup = lookup_parts[-1]
|
||||
lookup_cls = field.get_lookup(lookup)
|
||||
if lookup_cls is None:
|
||||
# unknown lookup
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
if lookup_cls is None or isinstance(lookup_cls, Exact):
|
||||
return self.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field)
|
||||
|
||||
assert lookup_cls is not None
|
||||
|
||||
lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls)
|
||||
if lookup_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
for lookup_base in helpers.iter_bases(lookup_info):
|
||||
if lookup_base.args and isinstance(lookup_base.args[0], Instance):
|
||||
lookup_type: MypyType = lookup_base.args[0]
|
||||
# if it's Field, consider lookup_type a __get__ of current field
|
||||
if (isinstance(lookup_type, Instance)
|
||||
and lookup_type.type.fullname() == fullnames.FIELD_FULLNAME):
|
||||
field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
||||
is_nullable=field.null)
|
||||
return lookup_type
|
||||
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
def resolve_f_expression_type(self, f_expression_type: Instance) -> MypyType:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
@@ -37,5 +37,3 @@ RELATED_FIELDS_CLASSES = {
|
||||
MIGRATION_CLASS_FULLNAME = 'django.db.migrations.migration.Migration'
|
||||
OPTIONS_CLASS_FULLNAME = 'django.db.models.options.Options'
|
||||
HTTPREQUEST_CLASS_FULLNAME = 'django.http.request.HttpRequest'
|
||||
|
||||
F_EXPRESSION_FULLNAME = 'django.db.models.expressions.F'
|
||||
|
||||
@@ -282,25 +282,5 @@ def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionCont
|
||||
|
||||
|
||||
def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool:
|
||||
return (info.fullname() in django_context.all_registered_model_class_fullnames
|
||||
return (info.fullname() in django_context.model_base_classes
|
||||
or info.has_base(fullnames.MODEL_CLASS_FULLNAME))
|
||||
|
||||
|
||||
def check_types_compatible(ctx: Union[FunctionContext, MethodContext],
|
||||
*, expected_type: MypyType, actual_type: MypyType, error_message: str) -> None:
|
||||
api = get_typechecker_api(ctx)
|
||||
api.check_subtype(actual_type, expected_type,
|
||||
ctx.context, error_message,
|
||||
'got', 'expected')
|
||||
|
||||
|
||||
def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None:
|
||||
# type=: type of the variable itself
|
||||
var = Var(name=name, type=sym_type)
|
||||
# var.info: type of the object variable is bound to
|
||||
var.info = info
|
||||
var._fullname = info.fullname() + '.' + name
|
||||
var.is_initialized_in_class = True
|
||||
var.is_inferred = True
|
||||
info.names[name] = SymbolTableNode(MDEF, var,
|
||||
plugin_generated=True)
|
||||
|
||||
@@ -11,7 +11,6 @@ from mypy.plugin import (
|
||||
)
|
||||
from mypy.types import Type as MypyType
|
||||
|
||||
import mypy_django_plugin.transformers.orm_lookups
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
from mypy_django_plugin.transformers import (
|
||||
@@ -149,15 +148,13 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
# forward relations
|
||||
for field in self.django_context.get_model_fields(model_class):
|
||||
if isinstance(field, RelatedField):
|
||||
related_model_cls = self.django_context.get_field_related_model_cls(field)
|
||||
if related_model_cls is None:
|
||||
continue
|
||||
related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
|
||||
related_model_module = related_model_cls.__module__
|
||||
if related_model_module != file.fullname():
|
||||
deps.add(self._new_dependency(related_model_module))
|
||||
# reverse relations
|
||||
for relation in model_class._meta.related_objects:
|
||||
related_model_cls = self.django_context.get_field_related_model_cls(relation)
|
||||
related_model_cls = self.django_context.fields_context.get_related_model_cls(relation)
|
||||
related_model_module = related_model_cls.__module__
|
||||
if related_model_module != file.fullname():
|
||||
deps.add(self._new_dependency(related_model_module))
|
||||
@@ -213,13 +210,12 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
if class_fullname in manager_classes and method_name == 'create':
|
||||
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
|
||||
if class_fullname in manager_classes and method_name in {'filter', 'get', 'exclude'}:
|
||||
return partial(mypy_django_plugin.transformers.orm_lookups.typecheck_queryset_filter,
|
||||
django_context=self.django_context)
|
||||
return partial(init_create.typecheck_queryset_filter, django_context=self.django_context)
|
||||
return None
|
||||
|
||||
def get_base_class_hook(self, fullname: str
|
||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
if (fullname in self.django_context.all_registered_model_class_fullnames
|
||||
if (fullname in self.django_context.model_base_classes
|
||||
or fullname in self._get_current_model_bases()):
|
||||
return partial(transform_model_class, django_context=self.django_context)
|
||||
|
||||
|
||||
@@ -37,14 +37,6 @@ def _get_current_field_from_assignment(ctx: FunctionContext, django_context: Dja
|
||||
return current_field
|
||||
|
||||
|
||||
def reparametrize_related_field_type(related_field_type: Instance, set_type, get_type) -> Instance:
|
||||
args = [
|
||||
helpers.convert_any_to_type(related_field_type.args[0], set_type),
|
||||
helpers.convert_any_to_type(related_field_type.args[1], get_type),
|
||||
]
|
||||
return helpers.reparametrize_instance(related_field_type, new_args=args)
|
||||
|
||||
|
||||
def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
|
||||
current_field = _get_current_field_from_assignment(ctx, django_context)
|
||||
if current_field is None:
|
||||
@@ -52,28 +44,7 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
|
||||
|
||||
assert isinstance(current_field, RelatedField)
|
||||
|
||||
related_model_cls = django_context.get_field_related_model_cls(current_field)
|
||||
if related_model_cls is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
default_related_field_type = set_descriptor_types_for_field(ctx)
|
||||
|
||||
# self reference with abstract=True on the model where ForeignKey is defined
|
||||
current_model_cls = current_field.model
|
||||
if (current_model_cls._meta.abstract
|
||||
and current_model_cls == related_model_cls):
|
||||
# for all derived non-abstract classes, set variable with this name to
|
||||
# __get__/__set__ of ForeignKey of derived model
|
||||
for model_cls in django_context.all_registered_model_classes:
|
||||
if issubclass(model_cls, current_model_cls) and not model_cls._meta.abstract:
|
||||
derived_model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls)
|
||||
if derived_model_info is not None:
|
||||
fk_ref_type = Instance(derived_model_info, [])
|
||||
derived_fk_type = reparametrize_related_field_type(default_related_field_type,
|
||||
set_type=fk_ref_type, get_type=fk_ref_type)
|
||||
helpers.add_new_sym_for_info(derived_model_info,
|
||||
name=current_field.name,
|
||||
sym_type=derived_fk_type)
|
||||
related_model_cls = django_context.fields_context.get_related_model_cls(current_field)
|
||||
|
||||
related_model = related_model_cls
|
||||
related_model_to_set = related_model_cls
|
||||
@@ -96,10 +67,13 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
|
||||
else:
|
||||
related_model_to_set_type = Instance(related_model_to_set_info, []) # type: ignore
|
||||
|
||||
default_related_field_type = set_descriptor_types_for_field(ctx)
|
||||
# replace Any with referred_to_type
|
||||
return reparametrize_related_field_type(default_related_field_type,
|
||||
set_type=related_model_to_set_type,
|
||||
get_type=related_model_type)
|
||||
args = [
|
||||
helpers.convert_any_to_type(default_related_field_type.args[0], related_model_to_set_type),
|
||||
helpers.convert_any_to_type(default_related_field_type.args[1], related_model_type),
|
||||
]
|
||||
return helpers.reparametrize_instance(default_related_field_type, new_args=args)
|
||||
|
||||
|
||||
def get_field_descriptor_types(field_info: TypeInfo, is_nullable: bool) -> Tuple[MypyType, MypyType]:
|
||||
|
||||
@@ -6,7 +6,7 @@ from mypy.types import Instance
|
||||
from mypy.types import Type as MypyType
|
||||
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import helpers
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
|
||||
|
||||
def get_actual_types(ctx: Union[MethodContext, FunctionContext],
|
||||
@@ -30,6 +30,12 @@ def get_actual_types(ctx: Union[MethodContext, FunctionContext],
|
||||
return actual_types
|
||||
|
||||
|
||||
def check_types_compatible(ctx, *, expected_type, actual_type, error_message):
|
||||
ctx.api.check_subtype(actual_type, expected_type,
|
||||
ctx.context, error_message,
|
||||
'got', 'expected')
|
||||
|
||||
|
||||
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
|
||||
model_cls: Type[Model], method: str) -> MypyType:
|
||||
typechecker_api = helpers.get_typechecker_api(ctx)
|
||||
@@ -42,11 +48,11 @@ def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_co
|
||||
model_cls.__name__),
|
||||
ctx.context)
|
||||
continue
|
||||
helpers.check_types_compatible(ctx,
|
||||
expected_type=expected_types[actual_name],
|
||||
actual_type=actual_type,
|
||||
error_message='Incompatible type for "{}" of "{}"'.format(actual_name,
|
||||
model_cls.__name__))
|
||||
check_types_compatible(ctx,
|
||||
expected_type=expected_types[actual_name],
|
||||
actual_type=actual_type,
|
||||
error_message='Incompatible type for "{}" of "{}"'.format(actual_name,
|
||||
model_cls.__name__))
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
@@ -73,3 +79,40 @@ def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: Djan
|
||||
return ctx.default_return_type
|
||||
|
||||
return typecheck_model_method(ctx, django_context, model_cls, 'create')
|
||||
|
||||
|
||||
def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||
lookup_kwargs = ctx.arg_names[1]
|
||||
provided_lookup_types = ctx.arg_types[1]
|
||||
|
||||
assert isinstance(ctx.type, Instance)
|
||||
|
||||
if not ctx.type.args or not isinstance(ctx.type.args[0], Instance):
|
||||
return ctx.default_return_type
|
||||
|
||||
model_cls_fullname = ctx.type.args[0].type.fullname()
|
||||
model_cls = django_context.get_model_class_by_fullname(model_cls_fullname)
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
for lookup_kwarg, provided_type in zip(lookup_kwargs, provided_lookup_types):
|
||||
if lookup_kwarg is None:
|
||||
continue
|
||||
# Combinables are not supported yet
|
||||
if (isinstance(provided_type, Instance)
|
||||
and provided_type.type.has_base('django.db.models.expressions.Combinable')):
|
||||
continue
|
||||
|
||||
lookup_type = django_context.lookups_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg)
|
||||
# Managers as provided_type is not supported yet
|
||||
if (isinstance(provided_type, Instance)
|
||||
and helpers.has_any_of_bases(provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME,
|
||||
fullnames.QUERYSET_CLASS_FULLNAME))):
|
||||
return ctx.default_return_type
|
||||
|
||||
check_types_compatible(ctx,
|
||||
expected_type=lookup_type,
|
||||
actual_type=provided_type,
|
||||
error_message=f'Incompatible type for lookup {lookup_kwarg!r}:')
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
@@ -7,12 +7,12 @@ from django.db.models.fields.related import ForeignKey
|
||||
from django.db.models.fields.reverse_related import (
|
||||
ManyToManyRel, ManyToOneRel, OneToOneRel,
|
||||
)
|
||||
from mypy.nodes import ARG_STAR2, Argument, Context, TypeInfo, Var
|
||||
from mypy.nodes import (
|
||||
ARG_STAR2, MDEF, Argument, SymbolTableNode, TypeInfo, Var,
|
||||
)
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.plugins import common
|
||||
from mypy.types import AnyType, Instance
|
||||
from mypy.types import Type as MypyType
|
||||
from mypy.types import TypeOfAny
|
||||
from mypy.types import AnyType, Instance, TypeOfAny
|
||||
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
@@ -38,7 +38,7 @@ class ModelClassInitializer:
|
||||
field_info = self.lookup_typeinfo_or_incomplete_defn_error(fullname)
|
||||
return field_info
|
||||
|
||||
def create_new_var(self, name: str, typ: MypyType) -> Var:
|
||||
def create_new_var(self, name: str, typ: Instance) -> Var:
|
||||
# type=: type of the variable itself
|
||||
var = Var(name=name, type=typ)
|
||||
# var.info: type of the object variable is bound to
|
||||
@@ -48,10 +48,9 @@ class ModelClassInitializer:
|
||||
var.is_inferred = True
|
||||
return var
|
||||
|
||||
def add_new_node_to_model_class(self, name: str, typ: MypyType) -> None:
|
||||
helpers.add_new_sym_for_info(self.model_classdef.info,
|
||||
name=name,
|
||||
sym_type=typ)
|
||||
def add_new_node_to_model_class(self, name: str, typ: Instance) -> None:
|
||||
var = self.create_new_var(name, typ)
|
||||
self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True)
|
||||
|
||||
def run(self) -> None:
|
||||
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname)
|
||||
@@ -100,25 +99,10 @@ class AddRelatedModelsId(ModelClassInitializer):
|
||||
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, ForeignKey):
|
||||
related_model_cls = self.django_context.get_field_related_model_cls(field)
|
||||
if related_model_cls is None:
|
||||
error_context: Context = self.ctx.cls
|
||||
field_sym = self.ctx.cls.info.get(field.name)
|
||||
if field_sym is not None and field_sym.node is not None:
|
||||
error_context = field_sym.node
|
||||
self.api.fail(f'Cannot find model {field.related_model!r} '
|
||||
f'referenced in field {field.name!r} ',
|
||||
ctx=error_context)
|
||||
self.add_new_node_to_model_class(field.attname,
|
||||
AnyType(TypeOfAny.explicit))
|
||||
continue
|
||||
|
||||
if related_model_cls._meta.abstract:
|
||||
continue
|
||||
|
||||
related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
|
||||
rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
|
||||
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__)
|
||||
is_nullable = self.django_context.get_field_nullability(field, None)
|
||||
is_nullable = self.django_context.fields_context.get_field_nullability(field, None)
|
||||
set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
|
||||
self.add_new_node_to_model_class(field.attname,
|
||||
Instance(field_info, [set_type, get_type]))
|
||||
@@ -178,11 +162,9 @@ class AddRelatedManagers(ModelClassInitializer):
|
||||
# no reverse accessor
|
||||
continue
|
||||
|
||||
related_model_cls = self.django_context.get_field_related_model_cls(relation)
|
||||
if related_model_cls is None:
|
||||
continue
|
||||
|
||||
related_model_cls = self.django_context.fields_context.get_related_model_cls(relation)
|
||||
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls)
|
||||
|
||||
if isinstance(relation, OneToOneRel):
|
||||
self.add_new_node_to_model_class(attname, Instance(related_model_info, []))
|
||||
continue
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
from mypy.plugin import MethodContext
|
||||
from mypy.types import AnyType, Instance
|
||||
from mypy.types import Type as MypyType
|
||||
from mypy.types import TypeOfAny
|
||||
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
|
||||
|
||||
def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||
lookup_kwargs = ctx.arg_names[1]
|
||||
provided_lookup_types = ctx.arg_types[1]
|
||||
|
||||
assert isinstance(ctx.type, Instance)
|
||||
|
||||
if not ctx.type.args or not isinstance(ctx.type.args[0], Instance):
|
||||
return ctx.default_return_type
|
||||
|
||||
model_cls_fullname = ctx.type.args[0].type.fullname()
|
||||
model_cls = django_context.get_model_class_by_fullname(model_cls_fullname)
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
for lookup_kwarg, provided_type in zip(lookup_kwargs, provided_lookup_types):
|
||||
if lookup_kwarg is None:
|
||||
continue
|
||||
if (isinstance(provided_type, Instance)
|
||||
and provided_type.type.has_base('django.db.models.expressions.Combinable')):
|
||||
provided_type = resolve_combinable_type(provided_type, django_context)
|
||||
|
||||
lookup_type = django_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg)
|
||||
# Managers as provided_type is not supported yet
|
||||
if (isinstance(provided_type, Instance)
|
||||
and helpers.has_any_of_bases(provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME,
|
||||
fullnames.QUERYSET_CLASS_FULLNAME))):
|
||||
return ctx.default_return_type
|
||||
|
||||
helpers.check_types_compatible(ctx,
|
||||
expected_type=lookup_type,
|
||||
actual_type=provided_type,
|
||||
error_message=f'Incompatible type for lookup {lookup_kwarg!r}:')
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def resolve_combinable_type(combinable_type: Instance, django_context: DjangoContext) -> MypyType:
|
||||
if combinable_type.type.fullname() != fullnames.F_EXPRESSION_FULLNAME:
|
||||
# Combinables aside from F expressions are unsupported
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
return django_context.resolve_f_expression_type(combinable_type)
|
||||
@@ -40,7 +40,7 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
|
||||
def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
|
||||
*, method: str, lookup: str) -> Optional[MypyType]:
|
||||
try:
|
||||
lookup_field = django_context.resolve_lookup_into_field(model_cls, lookup)
|
||||
lookup_field = django_context.lookups_context.resolve_lookup_info_field(model_cls, lookup)
|
||||
except FieldError as exc:
|
||||
ctx.api.fail(exc.args[0], ctx.context)
|
||||
return None
|
||||
@@ -48,13 +48,11 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
if isinstance(lookup_field, RelatedField) and lookup_field.column == lookup:
|
||||
related_model_cls = django_context.get_field_related_model_cls(lookup_field)
|
||||
if related_model_cls is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
related_model_cls = django_context.fields_context.get_related_model_cls(lookup_field)
|
||||
lookup_field = django_context.get_primary_key_field(related_model_cls)
|
||||
|
||||
field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx),
|
||||
lookup_field, method=method)
|
||||
field_get_type = django_context.fields_context.get_field_get_type(helpers.get_typechecker_api(ctx),
|
||||
lookup_field, method=method)
|
||||
return field_get_type
|
||||
|
||||
|
||||
@@ -75,8 +73,8 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
|
||||
elif named:
|
||||
column_types: 'OrderedDict[str, MypyType]' = OrderedDict()
|
||||
for field in django_context.get_model_fields(model_cls):
|
||||
column_type = django_context.get_field_get_type(typechecker_api, field,
|
||||
method='values_list')
|
||||
column_type = django_context.fields_context.get_field_get_type(typechecker_api, field,
|
||||
method='values_list')
|
||||
column_types[field.attname] = column_type
|
||||
return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
|
||||
else:
|
||||
|
||||
4
setup.py
4
setup.py
@@ -21,14 +21,14 @@ with open('README.md', 'r') as f:
|
||||
readme = f.read()
|
||||
|
||||
dependencies = [
|
||||
'mypy>=0.730',
|
||||
'mypy>=0.730,<0.740',
|
||||
'typing-extensions',
|
||||
'django',
|
||||
]
|
||||
|
||||
setup(
|
||||
name="django-stubs",
|
||||
version="1.2.0",
|
||||
version="1.1.0",
|
||||
description='Mypy stubs for Django',
|
||||
long_description=readme,
|
||||
long_description_content_type='text/markdown',
|
||||
|
||||
@@ -387,52 +387,22 @@
|
||||
class Book2(models.Model):
|
||||
publisher = models.ForeignKey(to=Publisher2, on_delete=models.CASCADE)
|
||||
|
||||
- case: if_model_is_defined_as_name_of_the_class_look_for_it_in_the_same_app
|
||||
- case: if_model_is_defined_as_name_of_the_class_look_for_it_in_the_same_file
|
||||
main: |
|
||||
from myapp.models import Book
|
||||
reveal_type(Book().publisher) # N: Revealed type is 'myapp.models.publisher.Publisher*'
|
||||
reveal_type(Book().publisher) # N: Revealed type is 'myapp.models.Publisher*'
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models/__init__.py
|
||||
content: |
|
||||
from .publisher import Publisher
|
||||
from .book import Book
|
||||
- path: myapp/models/publisher.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
class Publisher(models.Model):
|
||||
pass
|
||||
- path: myapp/models/book.py
|
||||
content: |
|
||||
from django.db import models
|
||||
class Book(models.Model):
|
||||
publisher = models.ForeignKey(to='Publisher', on_delete=models.CASCADE)
|
||||
|
||||
|
||||
- case: fail_if_no_model_in_the_same_app_models_init_py
|
||||
main: |
|
||||
from myapp.models import Book
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models/__init__.py
|
||||
content: |
|
||||
from .book import Book
|
||||
- path: myapp/models/publisher.py
|
||||
content: |
|
||||
from django.db import models
|
||||
class Publisher(models.Model):
|
||||
pass
|
||||
- path: myapp/models/book.py
|
||||
content: |
|
||||
from django.db import models
|
||||
class Book(models.Model):
|
||||
publisher = models.ForeignKey(to='Publisher', on_delete=models.CASCADE) # E: Cannot find model 'Publisher' referenced in field 'publisher'
|
||||
|
||||
|
||||
- case: test_foreign_key_field_without_backwards_relation
|
||||
main: |
|
||||
from myapp.models import Book, Publisher
|
||||
@@ -623,28 +593,4 @@
|
||||
class TransactionLog(models.Model):
|
||||
transaction = models.ForeignKey(Transaction, on_delete=models.CASCADE)
|
||||
|
||||
Transaction().test()
|
||||
|
||||
|
||||
- case: resolve_primary_keys_for_foreign_keys_with_abstract_self_model
|
||||
main: |
|
||||
from myapp.models import User
|
||||
reveal_type(User().parent) # N: Revealed type is 'myapp.models.User*'
|
||||
reveal_type(User().parent_id) # N: Revealed type is 'builtins.int*'
|
||||
|
||||
reveal_type(User().parent2) # N: Revealed type is 'Union[myapp.models.User, None]'
|
||||
reveal_type(User().parent2_id) # N: Revealed type is 'Union[builtins.int, None]'
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
class AbstractUser(models.Model):
|
||||
parent = models.ForeignKey('self', on_delete=models.CASCADE)
|
||||
parent2 = models.ForeignKey('self', null=True, on_delete=models.CASCADE)
|
||||
class Meta:
|
||||
abstract = True
|
||||
class User(AbstractUser):
|
||||
pass
|
||||
Transaction().test()
|
||||
@@ -212,45 +212,4 @@
|
||||
class User(models.Model):
|
||||
pass
|
||||
class Profile(models.Model):
|
||||
user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile')
|
||||
|
||||
|
||||
# TODO
|
||||
- case: f_expression_simple_case
|
||||
main: |
|
||||
from myapp.models import User
|
||||
from django.db import models
|
||||
User.objects.filter(username=models.F('username2'))
|
||||
User.objects.filter(username=models.F('age'))
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
|
||||
class User(models.Model):
|
||||
username = models.TextField()
|
||||
username2 = models.TextField()
|
||||
|
||||
age = models.IntegerField()
|
||||
|
||||
|
||||
# TODO
|
||||
- case: f_expression_with_expression_math_is_not_supported
|
||||
main: |
|
||||
from myapp.models import User
|
||||
from django.db import models
|
||||
User.objects.filter(username=models.F('username2') + 'hello')
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
class User(models.Model):
|
||||
username = models.TextField()
|
||||
username2 = models.TextField()
|
||||
age = models.IntegerField()
|
||||
user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile')
|
||||
Reference in New Issue
Block a user