stability fixes

This commit is contained in:
Maxim Kurnikov
2019-08-24 17:04:50 +03:00
parent c91a6d1d5b
commit e95b40ef52
7 changed files with 70 additions and 30 deletions

View File

@@ -1,7 +1,7 @@
import os import os
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, Iterator, Optional, Set, TYPE_CHECKING, Tuple, Type from typing import Dict, Iterator, Optional, Set, TYPE_CHECKING, Tuple, Type, Union
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models.base import Model from django.db.models.base import Model
@@ -113,11 +113,13 @@ class DjangoFieldsContext:
is_nullable = self.get_field_nullability(field, method) is_nullable = self.get_field_nullability(field, method)
if isinstance(field, RelatedField): if isinstance(field, RelatedField):
related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
if method == 'values': if method == 'values':
primary_key_field = self.django_context.get_primary_key_field(field.related_model) primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
return self.get_field_get_type(api, primary_key_field, method=method) return self.get_field_get_type(api, primary_key_field, method=method)
model_info = helpers.lookup_class_typeinfo(api, field.related_model) model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
if model_info is None: if model_info is None:
return AnyType(TypeOfAny.unannotated) return AnyType(TypeOfAny.unannotated)
@@ -126,6 +128,17 @@ class DjangoFieldsContext:
return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
is_nullable=is_nullable) is_nullable=is_nullable)
def get_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) -> Type[Model]:
if isinstance(field, RelatedField):
related_model_cls = field.remote_field.model
else:
related_model_cls = field.field.model
if isinstance(related_model_cls, str):
related_model_cls = self.django_context.apps_registry.get_model(related_model_cls)
return related_model_cls
class DjangoLookupsContext: class DjangoLookupsContext:
def __init__(self, django_context: 'DjangoContext'): def __init__(self, django_context: 'DjangoContext'):
@@ -144,12 +157,12 @@ class DjangoLookupsContext:
return self.django_context.get_primary_key_field(currently_observed_model) return self.django_context.get_primary_key_field(currently_observed_model)
current_field = currently_observed_model._meta.get_field(field_part) current_field = currently_observed_model._meta.get_field(field_part)
if not isinstance(current_field, (ForeignObjectRel, RelatedField)):
continue
currently_observed_model = self.django_context.fields_context.get_related_model_cls(current_field)
if isinstance(current_field, ForeignObjectRel): if isinstance(current_field, ForeignObjectRel):
currently_observed_model = current_field.related_model
current_field = self.django_context.get_primary_key_field(currently_observed_model) current_field = self.django_context.get_primary_key_field(currently_observed_model)
else:
if isinstance(current_field, RelatedField):
currently_observed_model = current_field.related_model
# if it is None, solve_lookup_type() will fail earlier # if it is None, solve_lookup_type() will fail earlier
assert current_field is not None assert current_field is not None
@@ -213,7 +226,8 @@ class DjangoContext:
from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.fields import GenericForeignKey
expected_types = {} expected_types = {}
# add pk # add pk if not abstract=True
if not model_cls._meta.abstract:
primary_key_field = self.get_primary_key_field(model_cls) primary_key_field = self.get_primary_key_field(model_cls)
field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method) field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method)
expected_types['pk'] = field_set_type expected_types['pk'] = field_set_type
@@ -232,9 +246,9 @@ class DjangoContext:
expected_types[field_name] = AnyType(TypeOfAny.unannotated) expected_types[field_name] = AnyType(TypeOfAny.unannotated)
continue continue
related_model = field.related_model related_model = self.fields_context.get_related_model_cls(field)
if related_model._meta.proxy_for_model: if related_model._meta.proxy_for_model is not None:
related_model = field.related_model._meta.proxy_for_model related_model = related_model._meta.proxy_for_model
related_model_info = helpers.lookup_class_typeinfo(api, related_model) related_model_info = helpers.lookup_class_typeinfo(api, related_model)
if related_model_info is None: if related_model_info is None:

View File

@@ -1,19 +1,15 @@
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union, cast
from mypy import checker from mypy import checker
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.mro import calculate_mro from mypy.mro import calculate_mro
from mypy.nodes import ( from mypy.nodes import (Block, ClassDef, Expression, GDEF, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode,
GDEF, MDEF, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, SymbolTable, SymbolTableNode, TypeInfo, Var)
SymbolTableNode, TypeInfo, Var,
)
from mypy.plugin import ( from mypy.plugin import (
AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext, AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext,
) )
from mypy.types import AnyType, Instance, NoneTyp, TupleType from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType
from mypy.types import Type as MypyType
from mypy.types import TypedDictType, TypeOfAny, UnionType
if TYPE_CHECKING: if TYPE_CHECKING:
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext

View File

@@ -148,12 +148,14 @@ class NewSemanalDjangoPlugin(Plugin):
# forward relations # forward relations
for field in self.django_context.get_model_fields(model_class): for field in self.django_context.get_model_fields(model_class):
if isinstance(field, RelatedField): if isinstance(field, RelatedField):
related_model_module = field.related_model.__module__ related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname(): if related_model_module != file.fullname():
deps.add(self._new_dependency(related_model_module)) deps.add(self._new_dependency(related_model_module))
# reverse relations # reverse relations
for relation in model_class._meta.related_objects: for relation in model_class._meta.related_objects:
related_model_module = relation.related_model.__module__ related_model_cls = self.django_context.fields_context.get_related_model_cls(relation)
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname(): if related_model_module != file.fullname():
deps.add(self._new_dependency(related_model_module)) deps.add(self._new_dependency(related_model_module))
return list(deps) return list(deps)

View File

@@ -44,9 +44,12 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
assert isinstance(current_field, RelatedField) assert isinstance(current_field, RelatedField)
related_model = related_model_to_set = current_field.related_model related_model_cls = django_context.fields_context.get_related_model_cls(current_field)
if related_model_to_set._meta.proxy_for_model:
related_model_to_set = related_model._meta.proxy_for_model related_model = related_model_cls
related_model_to_set = related_model_cls
if related_model_to_set._meta.proxy_for_model is not None:
related_model_to_set = related_model_to_set._meta.proxy_for_model
typechecker_api = helpers.get_typechecker_api(ctx) typechecker_api = helpers.get_typechecker_api(ctx)

View File

@@ -100,7 +100,8 @@ class AddRelatedModelsId(ModelClassInitializer):
def run_with_model_cls(self, model_cls: Type[Model]) -> None: def run_with_model_cls(self, model_cls: Type[Model]) -> None:
for field in model_cls._meta.get_fields(): for field in model_cls._meta.get_fields():
if isinstance(field, ForeignKey): if isinstance(field, ForeignKey):
rel_primary_key_field = self.django_context.get_primary_key_field(field.related_model) related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__) field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__)
is_nullable = self.django_context.fields_context.get_field_nullability(field, None) is_nullable = self.django_context.fields_context.get_field_nullability(field, None)
set_type, get_type = get_field_descriptor_types(field_info, is_nullable) set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
@@ -156,7 +157,8 @@ class AddManagers(ModelClassInitializer):
# no reverse accessor # no reverse accessor
continue continue
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(relation.related_model) related_model_cls = self.django_context.fields_context.get_related_model_cls(relation)
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls)
if isinstance(relation, OneToOneRel): if isinstance(relation, OneToOneRel):
self.add_new_node_to_model_class(attname, Instance(related_model_info, [])) self.add_new_node_to_model_class(attname, Instance(related_model_info, []))

View File

@@ -42,7 +42,8 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext
return None return None
if isinstance(lookup_field, RelatedField) and lookup_field.column == lookup: if isinstance(lookup_field, RelatedField) and lookup_field.column == lookup:
lookup_field = django_context.get_primary_key_field(lookup_field.related_model) related_model_cls = django_context.fields_context.get_related_model_cls(lookup_field)
lookup_field = django_context.get_primary_key_field(related_model_cls)
field_get_type = django_context.fields_context.get_field_get_type(helpers.get_typechecker_api(ctx), field_get_type = django_context.fields_context.get_field_get_type(helpers.get_typechecker_api(ctx),
lookup_field, method=method) lookup_field, method=method)

View File

@@ -233,3 +233,25 @@
name = models.CharField(primary_key=True, max_length=100) name = models.CharField(primary_key=True, max_length=100)
class Book(models.Model): class Book(models.Model):
publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE) publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE)
- case: init_in_abstract_model_classmethod_should_not_throw_error_for_valid_fields
main: |
from myapp.models import MyModel
MyModel.base_init()
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class AbstractModel(models.Model):
class Meta:
abstract = True
text = models.CharField(max_length=100)
@classmethod
def base_init(cls) -> 'AbstractModel':
return cls(text='mytext')
class MyModel(AbstractModel):
pass