From e95b40ef524f5a11bae0050d807bda05c10e55cd Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Sat, 24 Aug 2019 17:04:50 +0300 Subject: [PATCH] stability fixes --- mypy_django_plugin/django/context.py | 42 +++++++++++++------- mypy_django_plugin/lib/helpers.py | 12 ++---- mypy_django_plugin/main.py | 6 ++- mypy_django_plugin/transformers/fields.py | 9 +++-- mypy_django_plugin/transformers/models.py | 6 ++- mypy_django_plugin/transformers/querysets.py | 3 +- test-data/typecheck/models/test_init.yml | 22 ++++++++++ 7 files changed, 70 insertions(+), 30 deletions(-) diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index 2cdd3b2..b8c11c9 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -1,7 +1,7 @@ import os from collections import defaultdict from contextlib import contextmanager -from typing import Dict, Iterator, Optional, Set, TYPE_CHECKING, Tuple, Type +from typing import Dict, Iterator, Optional, Set, TYPE_CHECKING, Tuple, Type, Union from django.core.exceptions import FieldError from django.db.models.base import Model @@ -113,11 +113,13 @@ class DjangoFieldsContext: is_nullable = self.get_field_nullability(field, method) if isinstance(field, RelatedField): + related_model_cls = self.django_context.fields_context.get_related_model_cls(field) + if method == 'values': - primary_key_field = self.django_context.get_primary_key_field(field.related_model) + primary_key_field = self.django_context.get_primary_key_field(related_model_cls) return self.get_field_get_type(api, primary_key_field, method=method) - model_info = helpers.lookup_class_typeinfo(api, field.related_model) + model_info = helpers.lookup_class_typeinfo(api, related_model_cls) if model_info is None: return AnyType(TypeOfAny.unannotated) @@ -126,6 +128,17 @@ class DjangoFieldsContext: return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', is_nullable=is_nullable) + def get_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) -> Type[Model]: + if isinstance(field, RelatedField): + related_model_cls = field.remote_field.model + else: + related_model_cls = field.field.model + + if isinstance(related_model_cls, str): + related_model_cls = self.django_context.apps_registry.get_model(related_model_cls) + + return related_model_cls + class DjangoLookupsContext: def __init__(self, django_context: 'DjangoContext'): @@ -144,12 +157,12 @@ class DjangoLookupsContext: return self.django_context.get_primary_key_field(currently_observed_model) current_field = currently_observed_model._meta.get_field(field_part) + if not isinstance(current_field, (ForeignObjectRel, RelatedField)): + continue + + currently_observed_model = self.django_context.fields_context.get_related_model_cls(current_field) if isinstance(current_field, ForeignObjectRel): - currently_observed_model = current_field.related_model current_field = self.django_context.get_primary_key_field(currently_observed_model) - else: - if isinstance(current_field, RelatedField): - currently_observed_model = current_field.related_model # if it is None, solve_lookup_type() will fail earlier assert current_field is not None @@ -213,10 +226,11 @@ class DjangoContext: from django.contrib.contenttypes.fields import GenericForeignKey expected_types = {} - # add pk - primary_key_field = self.get_primary_key_field(model_cls) - field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method) - expected_types['pk'] = field_set_type + # add pk if not abstract=True + if not model_cls._meta.abstract: + primary_key_field = self.get_primary_key_field(model_cls) + field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method) + expected_types['pk'] = field_set_type for field in model_cls._meta.get_fields(): if isinstance(field, Field): @@ -232,9 +246,9 @@ class DjangoContext: expected_types[field_name] = AnyType(TypeOfAny.unannotated) continue - related_model = field.related_model - if related_model._meta.proxy_for_model: - related_model = field.related_model._meta.proxy_for_model + related_model = self.fields_context.get_related_model_cls(field) + if related_model._meta.proxy_for_model is not None: + related_model = related_model._meta.proxy_for_model related_model_info = helpers.lookup_class_typeinfo(api, related_model) if related_model_info is None: diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 3555e5a..9e7b7a5 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -1,19 +1,15 @@ from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union, cast from mypy import checker from mypy.checker import TypeChecker from mypy.mro import calculate_mro -from mypy.nodes import ( - GDEF, MDEF, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, - SymbolTableNode, TypeInfo, Var, -) +from mypy.nodes import (Block, ClassDef, Expression, GDEF, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, + SymbolTable, SymbolTableNode, TypeInfo, Var) from mypy.plugin import ( AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext, ) -from mypy.types import AnyType, Instance, NoneTyp, TupleType -from mypy.types import Type as MypyType -from mypy.types import TypedDictType, TypeOfAny, UnionType +from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType if TYPE_CHECKING: from mypy_django_plugin.django.context import DjangoContext diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 2c76bf1..3e485da 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -148,12 +148,14 @@ class NewSemanalDjangoPlugin(Plugin): # forward relations for field in self.django_context.get_model_fields(model_class): if isinstance(field, RelatedField): - related_model_module = field.related_model.__module__ + related_model_cls = self.django_context.fields_context.get_related_model_cls(field) + related_model_module = related_model_cls.__module__ if related_model_module != file.fullname(): deps.add(self._new_dependency(related_model_module)) # reverse relations for relation in model_class._meta.related_objects: - related_model_module = relation.related_model.__module__ + related_model_cls = self.django_context.fields_context.get_related_model_cls(relation) + related_model_module = related_model_cls.__module__ if related_model_module != file.fullname(): deps.add(self._new_dependency(related_model_module)) return list(deps) diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index b0793d7..1efcfe1 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -44,9 +44,12 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context assert isinstance(current_field, RelatedField) - related_model = related_model_to_set = current_field.related_model - if related_model_to_set._meta.proxy_for_model: - related_model_to_set = related_model._meta.proxy_for_model + related_model_cls = django_context.fields_context.get_related_model_cls(current_field) + + related_model = related_model_cls + related_model_to_set = related_model_cls + if related_model_to_set._meta.proxy_for_model is not None: + related_model_to_set = related_model_to_set._meta.proxy_for_model typechecker_api = helpers.get_typechecker_api(ctx) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 9318b0c..c1272d7 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -100,7 +100,8 @@ class AddRelatedModelsId(ModelClassInitializer): def run_with_model_cls(self, model_cls: Type[Model]) -> None: for field in model_cls._meta.get_fields(): if isinstance(field, ForeignKey): - rel_primary_key_field = self.django_context.get_primary_key_field(field.related_model) + related_model_cls = self.django_context.fields_context.get_related_model_cls(field) + rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls) field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__) is_nullable = self.django_context.fields_context.get_field_nullability(field, None) set_type, get_type = get_field_descriptor_types(field_info, is_nullable) @@ -156,7 +157,8 @@ class AddManagers(ModelClassInitializer): # no reverse accessor continue - related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(relation.related_model) + related_model_cls = self.django_context.fields_context.get_related_model_cls(relation) + related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls) if isinstance(relation, OneToOneRel): self.add_new_node_to_model_class(attname, Instance(related_model_info, [])) diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index 7a8a7b4..b598fa3 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -42,7 +42,8 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext return None if isinstance(lookup_field, RelatedField) and lookup_field.column == lookup: - lookup_field = django_context.get_primary_key_field(lookup_field.related_model) + related_model_cls = django_context.fields_context.get_related_model_cls(lookup_field) + lookup_field = django_context.get_primary_key_field(related_model_cls) field_get_type = django_context.fields_context.get_field_get_type(helpers.get_typechecker_api(ctx), lookup_field, method=method) diff --git a/test-data/typecheck/models/test_init.yml b/test-data/typecheck/models/test_init.yml index b42ec66..40ac583 100644 --- a/test-data/typecheck/models/test_init.yml +++ b/test-data/typecheck/models/test_init.yml @@ -233,3 +233,25 @@ name = models.CharField(primary_key=True, max_length=100) class Book(models.Model): publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE) + + +- case: init_in_abstract_model_classmethod_should_not_throw_error_for_valid_fields + main: | + from myapp.models import MyModel + MyModel.base_init() + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class AbstractModel(models.Model): + class Meta: + abstract = True + text = models.CharField(max_length=100) + @classmethod + def base_init(cls) -> 'AbstractModel': + return cls(text='mytext') + class MyModel(AbstractModel): + pass \ No newline at end of file