mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-07 04:34:29 +08:00
Reorganize code a bit, add current directory to sys.path (#198)
* reorganize code a bit * add current directory to sys.path * remove PYTHONPATH mention from the docs * linting
This commit is contained in:
@@ -52,12 +52,6 @@ django_settings_module = mysettings
|
||||
|
||||
Where `mysettings` is a value of `DJANGO_SETTINGS_MODULE` (with or without quotes)
|
||||
|
||||
You might also need to explicitly tweak your `PYTHONPATH` the very same way `django` does it internally in case you have troubles with mypy / django plugin not finding your settings module. Try adding the root path of your project to your `PYTHONPATH` environment variable like so:
|
||||
|
||||
```bash
|
||||
PYTHONPATH=${PYTHONPATH}:${PWD}
|
||||
```
|
||||
|
||||
Current implementation uses Django runtime to extract models information, so it will crash, if your installed apps `models.py` is not correct. For this same reason, you cannot use `reveal_type` inside global scope of any Python file that will be executed for `django.setup()`.
|
||||
|
||||
In other words, if your `manage.py runserver` crashes, mypy will crash too.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from typing import (
|
||||
@@ -48,6 +49,9 @@ def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']:
|
||||
with temp_environ():
|
||||
os.environ['DJANGO_SETTINGS_MODULE'] = settings_module
|
||||
|
||||
# add current directory to sys.path
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
def noop_class_getitem(cls, key):
|
||||
return cls
|
||||
|
||||
@@ -73,176 +77,12 @@ def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']:
|
||||
return apps, settings
|
||||
|
||||
|
||||
class DjangoFieldsContext:
|
||||
def __init__(self, django_context: 'DjangoContext') -> None:
|
||||
self.django_context = django_context
|
||||
|
||||
def get_attname(self, field: Field) -> str:
|
||||
attname = field.attname
|
||||
return attname
|
||||
|
||||
def get_field_nullability(self, field: Field, method: Optional[str]) -> bool:
|
||||
nullable = field.null
|
||||
if not nullable and isinstance(field, CharField) and field.blank:
|
||||
return True
|
||||
if method == '__init__':
|
||||
if field.primary_key or isinstance(field, ForeignKey):
|
||||
return True
|
||||
if method == 'create':
|
||||
if isinstance(field, AutoField):
|
||||
return True
|
||||
if field.has_default():
|
||||
return True
|
||||
return nullable
|
||||
|
||||
def get_field_set_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType:
|
||||
""" Get a type of __set__ for this specific Django field. """
|
||||
target_field = field
|
||||
if isinstance(field, ForeignKey):
|
||||
target_field = field.target_field
|
||||
|
||||
field_info = helpers.lookup_class_typeinfo(api, target_field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
|
||||
is_nullable=self.get_field_nullability(field, method))
|
||||
if isinstance(target_field, ArrayField):
|
||||
argument_field_type = self.get_field_set_type(api, target_field.base_field, method=method)
|
||||
field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type)
|
||||
return field_set_type
|
||||
|
||||
def get_field_get_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType:
|
||||
""" Get a type of __get__ for this specific Django field. """
|
||||
field_info = helpers.lookup_class_typeinfo(api, field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
|
||||
is_nullable = self.get_field_nullability(field, method)
|
||||
if isinstance(field, RelatedField):
|
||||
related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
|
||||
|
||||
if method == 'values':
|
||||
primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
|
||||
return self.get_field_get_type(api, primary_key_field, method=method)
|
||||
|
||||
model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
|
||||
if model_info is None:
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
|
||||
return Instance(model_info, [])
|
||||
else:
|
||||
return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
||||
is_nullable=is_nullable)
|
||||
|
||||
def get_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) -> Type[Model]:
|
||||
if isinstance(field, RelatedField):
|
||||
related_model_cls = field.remote_field.model
|
||||
else:
|
||||
related_model_cls = field.field.model
|
||||
|
||||
if isinstance(related_model_cls, str):
|
||||
if related_model_cls == 'self':
|
||||
# same model
|
||||
related_model_cls = field.model
|
||||
elif '.' not in related_model_cls:
|
||||
# same file model
|
||||
related_model_fullname = field.model.__module__ + '.' + related_model_cls
|
||||
related_model_cls = self.django_context.get_model_class_by_fullname(related_model_fullname)
|
||||
else:
|
||||
related_model_cls = self.django_context.apps_registry.get_model(related_model_cls)
|
||||
|
||||
return related_model_cls
|
||||
|
||||
|
||||
class LookupsAreUnsupported(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DjangoLookupsContext:
|
||||
def __init__(self, django_context: 'DjangoContext'):
|
||||
self.django_context = django_context
|
||||
|
||||
def _resolve_field_from_parts(self, field_parts: Iterable[str], model_cls: Type[Model]) -> Field:
|
||||
currently_observed_model = model_cls
|
||||
field = None
|
||||
for field_part in field_parts:
|
||||
if field_part == 'pk':
|
||||
field = self.django_context.get_primary_key_field(currently_observed_model)
|
||||
continue
|
||||
|
||||
field = currently_observed_model._meta.get_field(field_part)
|
||||
if isinstance(field, RelatedField):
|
||||
currently_observed_model = field.related_model
|
||||
model_name = currently_observed_model._meta.model_name
|
||||
if (model_name is not None
|
||||
and field_part == (model_name + '_id')):
|
||||
field = self.django_context.get_primary_key_field(currently_observed_model)
|
||||
|
||||
if isinstance(field, ForeignObjectRel):
|
||||
currently_observed_model = field.related_model
|
||||
|
||||
assert field is not None
|
||||
return field
|
||||
|
||||
def resolve_lookup_info_field(self, model_cls: Type[Model], lookup: str) -> Field:
|
||||
query = Query(model_cls)
|
||||
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
|
||||
if lookup_parts:
|
||||
raise LookupsAreUnsupported()
|
||||
|
||||
return self._resolve_field_from_parts(field_parts, model_cls)
|
||||
|
||||
def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType:
|
||||
query = Query(model_cls)
|
||||
try:
|
||||
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
|
||||
if is_expression:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
except FieldError as exc:
|
||||
ctx.api.fail(exc.args[0], ctx.context)
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
field = self._resolve_field_from_parts(field_parts, model_cls)
|
||||
|
||||
lookup_cls = None
|
||||
if lookup_parts:
|
||||
lookup = lookup_parts[-1]
|
||||
lookup_cls = field.get_lookup(lookup)
|
||||
if lookup_cls is None:
|
||||
# unknown lookup
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
if lookup_cls is None or isinstance(lookup_cls, Exact):
|
||||
return self.django_context.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field)
|
||||
|
||||
assert lookup_cls is not None
|
||||
|
||||
lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls)
|
||||
if lookup_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
for lookup_base in helpers.iter_bases(lookup_info):
|
||||
if lookup_base.args and isinstance(lookup_base.args[0], Instance):
|
||||
lookup_type: MypyType = lookup_base.args[0]
|
||||
# if it's Field, consider lookup_type a __get__ of current field
|
||||
if (isinstance(lookup_type, Instance)
|
||||
and lookup_type.type.fullname() == fullnames.FIELD_FULLNAME):
|
||||
field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
||||
is_nullable=field.null)
|
||||
return lookup_type
|
||||
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
|
||||
class DjangoContext:
|
||||
def __init__(self, django_settings_module: str) -> None:
|
||||
self.fields_context = DjangoFieldsContext(self)
|
||||
self.lookups_context = DjangoLookupsContext(self)
|
||||
|
||||
self.django_settings_module = django_settings_module
|
||||
|
||||
apps, settings = initialize_django(self.django_settings_module)
|
||||
@@ -288,7 +128,7 @@ class DjangoContext:
|
||||
if isinstance(field, (RelatedField, ForeignObjectRel)):
|
||||
related_model_cls = field.related_model
|
||||
primary_key_field = self.get_primary_key_field(related_model_cls)
|
||||
primary_key_type = self.fields_context.get_field_get_type(api, primary_key_field, method='init')
|
||||
primary_key_type = self.get_field_get_type(api, primary_key_field, method='init')
|
||||
|
||||
rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
|
||||
if rel_model_info is None:
|
||||
@@ -319,13 +159,13 @@ class DjangoContext:
|
||||
# add pk if not abstract=True
|
||||
if not model_cls._meta.abstract:
|
||||
primary_key_field = self.get_primary_key_field(model_cls)
|
||||
field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method)
|
||||
field_set_type = self.get_field_set_type(api, primary_key_field, method=method)
|
||||
expected_types['pk'] = field_set_type
|
||||
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, Field):
|
||||
field_name = field.attname
|
||||
field_set_type = self.fields_context.get_field_set_type(api, field, method=method)
|
||||
field_set_type = self.get_field_set_type(api, field, method=method)
|
||||
expected_types[field_name] = field_set_type
|
||||
|
||||
if isinstance(field, ForeignKey):
|
||||
@@ -336,7 +176,7 @@ class DjangoContext:
|
||||
expected_types[field_name] = AnyType(TypeOfAny.unannotated)
|
||||
continue
|
||||
|
||||
related_model = self.fields_context.get_related_model_cls(field)
|
||||
related_model = self.get_field_related_model_cls(field)
|
||||
if related_model._meta.proxy_for_model is not None:
|
||||
related_model = related_model._meta.proxy_for_model
|
||||
|
||||
@@ -345,7 +185,7 @@ class DjangoContext:
|
||||
expected_types[field_name] = AnyType(TypeOfAny.unannotated)
|
||||
continue
|
||||
|
||||
is_nullable = self.fields_context.get_field_nullability(field, method)
|
||||
is_nullable = self.get_field_nullability(field, method)
|
||||
foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info,
|
||||
'_pyi_private_set_type',
|
||||
is_nullable=is_nullable)
|
||||
@@ -375,3 +215,157 @@ class DjangoContext:
|
||||
all_model_bases.add(helpers.get_class_fullname(base_cls))
|
||||
|
||||
return all_model_bases
|
||||
|
||||
def get_attname(self, field: Field) -> str:
|
||||
attname = field.attname
|
||||
return attname
|
||||
|
||||
def get_field_nullability(self, field: Field, method: Optional[str]) -> bool:
|
||||
nullable = field.null
|
||||
if not nullable and isinstance(field, CharField) and field.blank:
|
||||
return True
|
||||
if method == '__init__':
|
||||
if field.primary_key or isinstance(field, ForeignKey):
|
||||
return True
|
||||
if method == 'create':
|
||||
if isinstance(field, AutoField):
|
||||
return True
|
||||
if field.has_default():
|
||||
return True
|
||||
return nullable
|
||||
|
||||
def get_field_set_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType:
|
||||
""" Get a type of __set__ for this specific Django field. """
|
||||
target_field = field
|
||||
if isinstance(field, ForeignKey):
|
||||
target_field = field.target_field
|
||||
|
||||
field_info = helpers.lookup_class_typeinfo(api, target_field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
|
||||
is_nullable=self.get_field_nullability(field, method))
|
||||
if isinstance(target_field, ArrayField):
|
||||
argument_field_type = self.get_field_set_type(api, target_field.base_field, method=method)
|
||||
field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type)
|
||||
return field_set_type
|
||||
|
||||
def get_field_get_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType:
|
||||
""" Get a type of __get__ for this specific Django field. """
|
||||
field_info = helpers.lookup_class_typeinfo(api, field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
|
||||
is_nullable = self.get_field_nullability(field, method)
|
||||
if isinstance(field, RelatedField):
|
||||
related_model_cls = self.get_field_related_model_cls(field)
|
||||
|
||||
if method == 'values':
|
||||
primary_key_field = self.get_primary_key_field(related_model_cls)
|
||||
return self.get_field_get_type(api, primary_key_field, method=method)
|
||||
|
||||
model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
|
||||
if model_info is None:
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
|
||||
return Instance(model_info, [])
|
||||
else:
|
||||
return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
||||
is_nullable=is_nullable)
|
||||
|
||||
def get_field_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) -> Type[Model]:
|
||||
if isinstance(field, RelatedField):
|
||||
related_model_cls = field.remote_field.model
|
||||
else:
|
||||
related_model_cls = field.field.model
|
||||
|
||||
if isinstance(related_model_cls, str):
|
||||
if related_model_cls == 'self':
|
||||
# same model
|
||||
related_model_cls = field.model
|
||||
elif '.' not in related_model_cls:
|
||||
# same file model
|
||||
related_model_fullname = field.model.__module__ + '.' + related_model_cls
|
||||
related_model_cls = self.get_model_class_by_fullname(related_model_fullname)
|
||||
else:
|
||||
related_model_cls = self.apps_registry.get_model(related_model_cls)
|
||||
|
||||
return related_model_cls
|
||||
|
||||
def _resolve_field_from_parts(self, field_parts: Iterable[str], model_cls: Type[Model]) -> Field:
|
||||
currently_observed_model = model_cls
|
||||
field = None
|
||||
for field_part in field_parts:
|
||||
if field_part == 'pk':
|
||||
field = self.get_primary_key_field(currently_observed_model)
|
||||
continue
|
||||
|
||||
field = currently_observed_model._meta.get_field(field_part)
|
||||
if isinstance(field, RelatedField):
|
||||
currently_observed_model = field.related_model
|
||||
model_name = currently_observed_model._meta.model_name
|
||||
if (model_name is not None
|
||||
and field_part == (model_name + '_id')):
|
||||
field = self.get_primary_key_field(currently_observed_model)
|
||||
|
||||
if isinstance(field, ForeignObjectRel):
|
||||
currently_observed_model = field.related_model
|
||||
|
||||
assert field is not None
|
||||
return field
|
||||
|
||||
def resolve_lookup_into_field(self, model_cls: Type[Model], lookup: str) -> Field:
|
||||
query = Query(model_cls)
|
||||
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
|
||||
if lookup_parts:
|
||||
raise LookupsAreUnsupported()
|
||||
|
||||
return self._resolve_field_from_parts(field_parts, model_cls)
|
||||
|
||||
def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model], lookup: str) -> MypyType:
|
||||
query = Query(model_cls)
|
||||
try:
|
||||
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
|
||||
if is_expression:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
except FieldError as exc:
|
||||
ctx.api.fail(exc.args[0], ctx.context)
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
field = self._resolve_field_from_parts(field_parts, model_cls)
|
||||
|
||||
lookup_cls = None
|
||||
if lookup_parts:
|
||||
lookup = lookup_parts[-1]
|
||||
lookup_cls = field.get_lookup(lookup)
|
||||
if lookup_cls is None:
|
||||
# unknown lookup
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
if lookup_cls is None or isinstance(lookup_cls, Exact):
|
||||
return self.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field)
|
||||
|
||||
assert lookup_cls is not None
|
||||
|
||||
lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls)
|
||||
if lookup_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
for lookup_base in helpers.iter_bases(lookup_info):
|
||||
if lookup_base.args and isinstance(lookup_base.args[0], Instance):
|
||||
lookup_type: MypyType = lookup_base.args[0]
|
||||
# if it's Field, consider lookup_type a __get__ of current field
|
||||
if (isinstance(lookup_type, Instance)
|
||||
and lookup_type.type.fullname() == fullnames.FIELD_FULLNAME):
|
||||
field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__)
|
||||
if field_info is None:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
|
||||
is_nullable=field.null)
|
||||
return lookup_type
|
||||
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
def resolve_f_expression_type(self, f_expression_type: Instance) -> MypyType:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
@@ -37,3 +37,5 @@ RELATED_FIELDS_CLASSES = {
|
||||
MIGRATION_CLASS_FULLNAME = 'django.db.migrations.migration.Migration'
|
||||
OPTIONS_CLASS_FULLNAME = 'django.db.models.options.Options'
|
||||
HTTPREQUEST_CLASS_FULLNAME = 'django.http.request.HttpRequest'
|
||||
|
||||
F_EXPRESSION_FULLNAME = 'django.db.models.expressions.F'
|
||||
|
||||
@@ -284,3 +284,11 @@ def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionCont
|
||||
def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool:
|
||||
return (info.fullname() in django_context.model_base_classes
|
||||
or info.has_base(fullnames.MODEL_CLASS_FULLNAME))
|
||||
|
||||
|
||||
def check_types_compatible(ctx: Union[FunctionContext, MethodContext],
|
||||
*, expected_type: MypyType, actual_type: MypyType, error_message: str) -> None:
|
||||
api = get_typechecker_api(ctx)
|
||||
api.check_subtype(actual_type, expected_type,
|
||||
ctx.context, error_message,
|
||||
'got', 'expected')
|
||||
|
||||
@@ -11,6 +11,7 @@ from mypy.plugin import (
|
||||
)
|
||||
from mypy.types import Type as MypyType
|
||||
|
||||
import mypy_django_plugin.transformers.orm_lookups
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
from mypy_django_plugin.transformers import (
|
||||
@@ -148,13 +149,13 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
# forward relations
|
||||
for field in self.django_context.get_model_fields(model_class):
|
||||
if isinstance(field, RelatedField):
|
||||
related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
|
||||
related_model_cls = self.django_context.get_field_related_model_cls(field)
|
||||
related_model_module = related_model_cls.__module__
|
||||
if related_model_module != file.fullname():
|
||||
deps.add(self._new_dependency(related_model_module))
|
||||
# reverse relations
|
||||
for relation in model_class._meta.related_objects:
|
||||
related_model_cls = self.django_context.fields_context.get_related_model_cls(relation)
|
||||
related_model_cls = self.django_context.get_field_related_model_cls(relation)
|
||||
related_model_module = related_model_cls.__module__
|
||||
if related_model_module != file.fullname():
|
||||
deps.add(self._new_dependency(related_model_module))
|
||||
@@ -210,7 +211,8 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
if class_fullname in manager_classes and method_name == 'create':
|
||||
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
|
||||
if class_fullname in manager_classes and method_name in {'filter', 'get', 'exclude'}:
|
||||
return partial(init_create.typecheck_queryset_filter, django_context=self.django_context)
|
||||
return partial(mypy_django_plugin.transformers.orm_lookups.typecheck_queryset_filter,
|
||||
django_context=self.django_context)
|
||||
return None
|
||||
|
||||
def get_base_class_hook(self, fullname: str
|
||||
|
||||
@@ -44,7 +44,7 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
|
||||
|
||||
assert isinstance(current_field, RelatedField)
|
||||
|
||||
related_model_cls = django_context.fields_context.get_related_model_cls(current_field)
|
||||
related_model_cls = django_context.get_field_related_model_cls(current_field)
|
||||
|
||||
related_model = related_model_cls
|
||||
related_model_to_set = related_model_cls
|
||||
|
||||
@@ -6,7 +6,7 @@ from mypy.types import Instance
|
||||
from mypy.types import Type as MypyType
|
||||
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
from mypy_django_plugin.lib import helpers
|
||||
|
||||
|
||||
def get_actual_types(ctx: Union[MethodContext, FunctionContext],
|
||||
@@ -30,12 +30,6 @@ def get_actual_types(ctx: Union[MethodContext, FunctionContext],
|
||||
return actual_types
|
||||
|
||||
|
||||
def check_types_compatible(ctx, *, expected_type, actual_type, error_message):
|
||||
ctx.api.check_subtype(actual_type, expected_type,
|
||||
ctx.context, error_message,
|
||||
'got', 'expected')
|
||||
|
||||
|
||||
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
|
||||
model_cls: Type[Model], method: str) -> MypyType:
|
||||
typechecker_api = helpers.get_typechecker_api(ctx)
|
||||
@@ -48,7 +42,7 @@ def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_co
|
||||
model_cls.__name__),
|
||||
ctx.context)
|
||||
continue
|
||||
check_types_compatible(ctx,
|
||||
helpers.check_types_compatible(ctx,
|
||||
expected_type=expected_types[actual_name],
|
||||
actual_type=actual_type,
|
||||
error_message='Incompatible type for "{}" of "{}"'.format(actual_name,
|
||||
@@ -79,40 +73,3 @@ def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: Djan
|
||||
return ctx.default_return_type
|
||||
|
||||
return typecheck_model_method(ctx, django_context, model_cls, 'create')
|
||||
|
||||
|
||||
def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||
lookup_kwargs = ctx.arg_names[1]
|
||||
provided_lookup_types = ctx.arg_types[1]
|
||||
|
||||
assert isinstance(ctx.type, Instance)
|
||||
|
||||
if not ctx.type.args or not isinstance(ctx.type.args[0], Instance):
|
||||
return ctx.default_return_type
|
||||
|
||||
model_cls_fullname = ctx.type.args[0].type.fullname()
|
||||
model_cls = django_context.get_model_class_by_fullname(model_cls_fullname)
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
for lookup_kwarg, provided_type in zip(lookup_kwargs, provided_lookup_types):
|
||||
if lookup_kwarg is None:
|
||||
continue
|
||||
# Combinables are not supported yet
|
||||
if (isinstance(provided_type, Instance)
|
||||
and provided_type.type.has_base('django.db.models.expressions.Combinable')):
|
||||
continue
|
||||
|
||||
lookup_type = django_context.lookups_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg)
|
||||
# Managers as provided_type is not supported yet
|
||||
if (isinstance(provided_type, Instance)
|
||||
and helpers.has_any_of_bases(provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME,
|
||||
fullnames.QUERYSET_CLASS_FULLNAME))):
|
||||
return ctx.default_return_type
|
||||
|
||||
check_types_compatible(ctx,
|
||||
expected_type=lookup_type,
|
||||
actual_type=provided_type,
|
||||
error_message=f'Incompatible type for lookup {lookup_kwarg!r}:')
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
@@ -99,10 +99,10 @@ class AddRelatedModelsId(ModelClassInitializer):
|
||||
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, ForeignKey):
|
||||
related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
|
||||
related_model_cls = self.django_context.get_field_related_model_cls(field)
|
||||
rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
|
||||
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__)
|
||||
is_nullable = self.django_context.fields_context.get_field_nullability(field, None)
|
||||
is_nullable = self.django_context.get_field_nullability(field, None)
|
||||
set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
|
||||
self.add_new_node_to_model_class(field.attname,
|
||||
Instance(field_info, [set_type, get_type]))
|
||||
@@ -162,7 +162,7 @@ class AddRelatedManagers(ModelClassInitializer):
|
||||
# no reverse accessor
|
||||
continue
|
||||
|
||||
related_model_cls = self.django_context.fields_context.get_related_model_cls(relation)
|
||||
related_model_cls = self.django_context.get_field_related_model_cls(relation)
|
||||
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls)
|
||||
|
||||
if isinstance(relation, OneToOneRel):
|
||||
|
||||
51
mypy_django_plugin/transformers/orm_lookups.py
Normal file
51
mypy_django_plugin/transformers/orm_lookups.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from mypy.plugin import MethodContext
|
||||
from mypy.types import AnyType, Instance
|
||||
from mypy.types import Type as MypyType
|
||||
from mypy.types import TypeOfAny
|
||||
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
|
||||
|
||||
def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||
lookup_kwargs = ctx.arg_names[1]
|
||||
provided_lookup_types = ctx.arg_types[1]
|
||||
|
||||
assert isinstance(ctx.type, Instance)
|
||||
|
||||
if not ctx.type.args or not isinstance(ctx.type.args[0], Instance):
|
||||
return ctx.default_return_type
|
||||
|
||||
model_cls_fullname = ctx.type.args[0].type.fullname()
|
||||
model_cls = django_context.get_model_class_by_fullname(model_cls_fullname)
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
for lookup_kwarg, provided_type in zip(lookup_kwargs, provided_lookup_types):
|
||||
if lookup_kwarg is None:
|
||||
continue
|
||||
if (isinstance(provided_type, Instance)
|
||||
and provided_type.type.has_base('django.db.models.expressions.Combinable')):
|
||||
provided_type = resolve_combinable_type(provided_type, django_context)
|
||||
|
||||
lookup_type = django_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg)
|
||||
# Managers as provided_type is not supported yet
|
||||
if (isinstance(provided_type, Instance)
|
||||
and helpers.has_any_of_bases(provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME,
|
||||
fullnames.QUERYSET_CLASS_FULLNAME))):
|
||||
return ctx.default_return_type
|
||||
|
||||
helpers.check_types_compatible(ctx,
|
||||
expected_type=lookup_type,
|
||||
actual_type=provided_type,
|
||||
error_message=f'Incompatible type for lookup {lookup_kwarg!r}:')
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def resolve_combinable_type(combinable_type: Instance, django_context: DjangoContext) -> MypyType:
|
||||
if combinable_type.type.fullname() != fullnames.F_EXPRESSION_FULLNAME:
|
||||
# Combinables aside from F expressions are unsupported
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
return django_context.resolve_f_expression_type(combinable_type)
|
||||
@@ -40,7 +40,7 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
|
||||
def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
|
||||
*, method: str, lookup: str) -> Optional[MypyType]:
|
||||
try:
|
||||
lookup_field = django_context.lookups_context.resolve_lookup_info_field(model_cls, lookup)
|
||||
lookup_field = django_context.resolve_lookup_into_field(model_cls, lookup)
|
||||
except FieldError as exc:
|
||||
ctx.api.fail(exc.args[0], ctx.context)
|
||||
return None
|
||||
@@ -48,10 +48,10 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
if isinstance(lookup_field, RelatedField) and lookup_field.column == lookup:
|
||||
related_model_cls = django_context.fields_context.get_related_model_cls(lookup_field)
|
||||
related_model_cls = django_context.get_field_related_model_cls(lookup_field)
|
||||
lookup_field = django_context.get_primary_key_field(related_model_cls)
|
||||
|
||||
field_get_type = django_context.fields_context.get_field_get_type(helpers.get_typechecker_api(ctx),
|
||||
field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx),
|
||||
lookup_field, method=method)
|
||||
return field_get_type
|
||||
|
||||
@@ -73,7 +73,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
|
||||
elif named:
|
||||
column_types: 'OrderedDict[str, MypyType]' = OrderedDict()
|
||||
for field in django_context.get_model_fields(model_cls):
|
||||
column_type = django_context.fields_context.get_field_get_type(typechecker_api, field,
|
||||
column_type = django_context.get_field_get_type(typechecker_api, field,
|
||||
method='values_list')
|
||||
column_types[field.attname] = column_type
|
||||
return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
|
||||
|
||||
@@ -213,3 +213,44 @@
|
||||
pass
|
||||
class Profile(models.Model):
|
||||
user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile')
|
||||
|
||||
|
||||
# TODO
|
||||
- case: f_expression_simple_case
|
||||
main: |
|
||||
from myapp.models import User
|
||||
from django.db import models
|
||||
User.objects.filter(username=models.F('username2'))
|
||||
User.objects.filter(username=models.F('age'))
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
|
||||
class User(models.Model):
|
||||
username = models.TextField()
|
||||
username2 = models.TextField()
|
||||
|
||||
age = models.IntegerField()
|
||||
|
||||
|
||||
# TODO
|
||||
- case: f_expression_with_expression_math_is_not_supported
|
||||
main: |
|
||||
from myapp.models import User
|
||||
from django.db import models
|
||||
User.objects.filter(username=models.F('username2') + 'hello')
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
class User(models.Model):
|
||||
username = models.TextField()
|
||||
username2 = models.TextField()
|
||||
age = models.IntegerField()
|
||||
|
||||
Reference in New Issue
Block a user