diff --git a/django-stubs/core/handlers/base.pyi b/django-stubs/core/handlers/base.pyi index 2be4169..93461d1 100644 --- a/django-stubs/core/handlers/base.pyi +++ b/django-stubs/core/handlers/base.pyi @@ -6,10 +6,6 @@ from django.http.response import HttpResponse, HttpResponseBase logger: Any class BaseHandler: - _view_middleware: None = ... - _template_response_middleware: None = ... - _exception_middleware: None = ... - _middleware_chain: None = ... def load_middleware(self) -> None: ... def make_view_atomic(self, view: Callable) -> Callable: ... def get_exception_response(self, request: Any, resolver: Any, status_code: Any, exception: Any): ... diff --git a/django-stubs/core/management/utils.pyi b/django-stubs/core/management/utils.pyi index 27e13d3..23b5696 100644 --- a/django-stubs/core/management/utils.pyi +++ b/django-stubs/core/management/utils.pyi @@ -5,6 +5,6 @@ from django.db.models.base import Model def popen_wrapper(args: List[str], stdout_encoding: str = ...) -> Tuple[str, str, int]: ... def handle_extensions(extensions: List[str]) -> Set[str]: ... -def find_command(cmd: str, path: None = ..., pathext: None = ...) -> Optional[str]: ... +def find_command(cmd: str, path: Optional[str] = ..., pathext: Optional[str] = ...) -> Optional[str]: ... def get_random_secret_key(): ... def parse_apps_and_model_labels(labels: List[str]) -> Tuple[Set[Type[Model]], Set[AppConfig]]: ... diff --git a/django-stubs/db/backends/utils.pyi b/django-stubs/db/backends/utils.pyi index a6eba96..ffbcdfa 100644 --- a/django-stubs/db/backends/utils.pyi +++ b/django-stubs/db/backends/utils.pyi @@ -1,6 +1,7 @@ +import types from datetime import date, datetime, time from decimal import Decimal -from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Union, Type from uuid import UUID logger: Any @@ -16,7 +17,12 @@ class CursorWrapper: def __getattr__(self, attr: str) -> Any: ... def __iter__(self) -> None: ... def __enter__(self) -> CursorWrapper: ... - def __exit__(self, type: None, value: None, traceback: None) -> None: ... + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + tb: Optional[types.TracebackType], + ) -> None: ... def callproc(self, procname: str, params: List[Any] = ..., kparams: Dict[str, int] = ...) -> Any: ... def execute( self, sql: str, params: Optional[Union[Sequence[_SQLType], Mapping[str, _SQLType]]] = ... diff --git a/django-stubs/db/models/fields/related.pyi b/django-stubs/db/models/fields/related.pyi index 74ab259..aab35fb 100644 --- a/django-stubs/db/models/fields/related.pyi +++ b/django-stubs/db/models/fields/related.pyi @@ -80,9 +80,9 @@ class ForeignObject(RelatedField[_ST, _GT]): on_delete: Callable[..., None], from_fields: Sequence[str], to_fields: Sequence[str], - rel: None = ..., + rel: Optional[ForeignObjectRel] = ..., related_name: Optional[str] = ..., - related_query_name: None = ..., + related_query_name: Optional[str] = ..., limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any]]] = ..., parent_link: bool = ..., db_constraint: bool = ..., diff --git a/django-stubs/db/models/options.pyi b/django-stubs/db/models/options.pyi index 47458d8..84a0607 100644 --- a/django-stubs/db/models/options.pyi +++ b/django-stubs/db/models/options.pyi @@ -71,7 +71,7 @@ class Options(Generic[_M]): abstract: bool = ... managed: bool = ... proxy: bool = ... - proxy_for_model: None = ... + proxy_for_model: Optional[Type[Model]] = ... concrete_model: Optional[Type[Model]] = ... swappable: None = ... parents: collections.OrderedDict = ... diff --git a/django-stubs/db/models/query.pyi b/django-stubs/db/models/query.pyi index 31caebf..e086582 100644 --- a/django-stubs/db/models/query.pyi +++ b/django-stubs/db/models/query.pyi @@ -92,7 +92,7 @@ class QuerySet(Generic[_T], Collection[_T], Sized): ) -> ValuesQuerySet[_T, Any]: ... def dates(self, field_name: str, kind: str, order: str = ...) -> ValuesQuerySet[_T, datetime.date]: ... def datetimes( - self, field_name: str, kind: str, order: str = ..., tzinfo: None = ... + self, field_name: str, kind: str, order: str = ..., tzinfo: Optional[datetime.tzinfo] = ... ) -> ValuesQuerySet[_T, datetime.datetime]: ... def none(self) -> QuerySet[_T]: ... def all(self) -> QuerySet[_T]: ... diff --git a/django-stubs/test/testcases.pyi b/django-stubs/test/testcases.pyi index 9c17811..8769a95 100644 --- a/django-stubs/test/testcases.pyi +++ b/django-stubs/test/testcases.pyi @@ -119,15 +119,17 @@ class SimpleTestCase(unittest.TestCase): field_kwargs: None = ..., empty_value: str = ..., ) -> Any: ... - def assertHTMLEqual(self, html1: str, html2: str, msg: None = ...) -> None: ... - def assertHTMLNotEqual(self, html1: str, html2: str, msg: None = ...) -> None: ... + def assertHTMLEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ... + def assertHTMLNotEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ... def assertInHTML( self, needle: str, haystack: SafeText, count: Optional[int] = ..., msg_prefix: str = ... ) -> None: ... - def assertJSONEqual(self, raw: str, expected_data: Union[Dict[str, str], bool, str], msg: None = ...) -> None: ... - def assertJSONNotEqual(self, raw: str, expected_data: str, msg: None = ...) -> None: ... - def assertXMLEqual(self, xml1: str, xml2: str, msg: None = ...) -> None: ... - def assertXMLNotEqual(self, xml1: str, xml2: str, msg: None = ...) -> None: ... + def assertJSONEqual( + self, raw: str, expected_data: Union[Dict[str, str], bool, str], msg: Optional[str] = ... + ) -> None: ... + def assertJSONNotEqual(self, raw: str, expected_data: str, msg: Optional[str] = ...) -> None: ... + def assertXMLEqual(self, xml1: str, xml2: str, msg: Optional[str] = ...) -> None: ... + def assertXMLNotEqual(self, xml1: str, xml2: str, msg: Optional[str] = ...) -> None: ... class TransactionTestCase(SimpleTestCase): reset_sequences: bool = ... @@ -141,7 +143,7 @@ class TransactionTestCase(SimpleTestCase): values: Union[List[None], List[Tuple[str, str]], List[date], List[int], List[str], Set[str], QuerySet], transform: Union[Callable, Type[str]] = ..., ordered: bool = ..., - msg: None = ..., + msg: Optional[str] = ..., ) -> None: ... def assertNumQueries( self, num: int, func: Optional[Union[Callable, Type[list]]] = ..., *args: Any, using: Any = ..., **kwargs: Any diff --git a/django-stubs/utils/timesince.pyi b/django-stubs/utils/timesince.pyi index 8db447a..04e5b6a 100644 --- a/django-stubs/utils/timesince.pyi +++ b/django-stubs/utils/timesince.pyi @@ -1,8 +1,10 @@ from datetime import date -from typing import Any, Optional +from typing import Any, Optional, Dict -TIME_STRINGS: Any +TIME_STRINGS: Dict[str, str] TIMESINCE_CHUNKS: Any -def timesince(d: date, now: Optional[date] = ..., reversed: bool = ..., time_strings: None = ...) -> str: ... -def timeuntil(d: date, now: Optional[date] = ..., time_strings: None = ...) -> str: ... +def timesince( + d: date, now: Optional[date] = ..., reversed: bool = ..., time_strings: Optional[Dict[str, str]] = ... +) -> str: ... +def timeuntil(d: date, now: Optional[date] = ..., time_strings: Optional[Dict[str, str]] = ...) -> str: ... diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index 2cdd3b2..b8c11c9 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -1,7 +1,7 @@ import os from collections import defaultdict from contextlib import contextmanager -from typing import Dict, Iterator, Optional, Set, TYPE_CHECKING, Tuple, Type +from typing import Dict, Iterator, Optional, Set, TYPE_CHECKING, Tuple, Type, Union from django.core.exceptions import FieldError from django.db.models.base import Model @@ -113,11 +113,13 @@ class DjangoFieldsContext: is_nullable = self.get_field_nullability(field, method) if isinstance(field, RelatedField): + related_model_cls = self.django_context.fields_context.get_related_model_cls(field) + if method == 'values': - primary_key_field = self.django_context.get_primary_key_field(field.related_model) + primary_key_field = self.django_context.get_primary_key_field(related_model_cls) return self.get_field_get_type(api, primary_key_field, method=method) - model_info = helpers.lookup_class_typeinfo(api, field.related_model) + model_info = helpers.lookup_class_typeinfo(api, related_model_cls) if model_info is None: return AnyType(TypeOfAny.unannotated) @@ -126,6 +128,17 @@ class DjangoFieldsContext: return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', is_nullable=is_nullable) + def get_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) -> Type[Model]: + if isinstance(field, RelatedField): + related_model_cls = field.remote_field.model + else: + related_model_cls = field.field.model + + if isinstance(related_model_cls, str): + related_model_cls = self.django_context.apps_registry.get_model(related_model_cls) + + return related_model_cls + class DjangoLookupsContext: def __init__(self, django_context: 'DjangoContext'): @@ -144,12 +157,12 @@ class DjangoLookupsContext: return self.django_context.get_primary_key_field(currently_observed_model) current_field = currently_observed_model._meta.get_field(field_part) + if not isinstance(current_field, (ForeignObjectRel, RelatedField)): + continue + + currently_observed_model = self.django_context.fields_context.get_related_model_cls(current_field) if isinstance(current_field, ForeignObjectRel): - currently_observed_model = current_field.related_model current_field = self.django_context.get_primary_key_field(currently_observed_model) - else: - if isinstance(current_field, RelatedField): - currently_observed_model = current_field.related_model # if it is None, solve_lookup_type() will fail earlier assert current_field is not None @@ -213,10 +226,11 @@ class DjangoContext: from django.contrib.contenttypes.fields import GenericForeignKey expected_types = {} - # add pk - primary_key_field = self.get_primary_key_field(model_cls) - field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method) - expected_types['pk'] = field_set_type + # add pk if not abstract=True + if not model_cls._meta.abstract: + primary_key_field = self.get_primary_key_field(model_cls) + field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method) + expected_types['pk'] = field_set_type for field in model_cls._meta.get_fields(): if isinstance(field, Field): @@ -232,9 +246,9 @@ class DjangoContext: expected_types[field_name] = AnyType(TypeOfAny.unannotated) continue - related_model = field.related_model - if related_model._meta.proxy_for_model: - related_model = field.related_model._meta.proxy_for_model + related_model = self.fields_context.get_related_model_cls(field) + if related_model._meta.proxy_for_model is not None: + related_model = related_model._meta.proxy_for_model related_model_info = helpers.lookup_class_typeinfo(api, related_model) if related_model_info is None: diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 3555e5a..9e7b7a5 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -1,19 +1,15 @@ from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union, cast from mypy import checker from mypy.checker import TypeChecker from mypy.mro import calculate_mro -from mypy.nodes import ( - GDEF, MDEF, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, - SymbolTableNode, TypeInfo, Var, -) +from mypy.nodes import (Block, ClassDef, Expression, GDEF, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, + SymbolTable, SymbolTableNode, TypeInfo, Var) from mypy.plugin import ( AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext, ) -from mypy.types import AnyType, Instance, NoneTyp, TupleType -from mypy.types import Type as MypyType -from mypy.types import TypedDictType, TypeOfAny, UnionType +from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType if TYPE_CHECKING: from mypy_django_plugin.django.context import DjangoContext diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 2c76bf1..3e485da 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -148,12 +148,14 @@ class NewSemanalDjangoPlugin(Plugin): # forward relations for field in self.django_context.get_model_fields(model_class): if isinstance(field, RelatedField): - related_model_module = field.related_model.__module__ + related_model_cls = self.django_context.fields_context.get_related_model_cls(field) + related_model_module = related_model_cls.__module__ if related_model_module != file.fullname(): deps.add(self._new_dependency(related_model_module)) # reverse relations for relation in model_class._meta.related_objects: - related_model_module = relation.related_model.__module__ + related_model_cls = self.django_context.fields_context.get_related_model_cls(relation) + related_model_module = related_model_cls.__module__ if related_model_module != file.fullname(): deps.add(self._new_dependency(related_model_module)) return list(deps) diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index b0793d7..1efcfe1 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -44,9 +44,12 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context assert isinstance(current_field, RelatedField) - related_model = related_model_to_set = current_field.related_model - if related_model_to_set._meta.proxy_for_model: - related_model_to_set = related_model._meta.proxy_for_model + related_model_cls = django_context.fields_context.get_related_model_cls(current_field) + + related_model = related_model_cls + related_model_to_set = related_model_cls + if related_model_to_set._meta.proxy_for_model is not None: + related_model_to_set = related_model_to_set._meta.proxy_for_model typechecker_api = helpers.get_typechecker_api(ctx) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 9318b0c..c1272d7 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -100,7 +100,8 @@ class AddRelatedModelsId(ModelClassInitializer): def run_with_model_cls(self, model_cls: Type[Model]) -> None: for field in model_cls._meta.get_fields(): if isinstance(field, ForeignKey): - rel_primary_key_field = self.django_context.get_primary_key_field(field.related_model) + related_model_cls = self.django_context.fields_context.get_related_model_cls(field) + rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls) field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__) is_nullable = self.django_context.fields_context.get_field_nullability(field, None) set_type, get_type = get_field_descriptor_types(field_info, is_nullable) @@ -156,7 +157,8 @@ class AddManagers(ModelClassInitializer): # no reverse accessor continue - related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(relation.related_model) + related_model_cls = self.django_context.fields_context.get_related_model_cls(relation) + related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls) if isinstance(relation, OneToOneRel): self.add_new_node_to_model_class(attname, Instance(related_model_info, [])) diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index 7a8a7b4..b598fa3 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -42,7 +42,8 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext return None if isinstance(lookup_field, RelatedField) and lookup_field.column == lookup: - lookup_field = django_context.get_primary_key_field(lookup_field.related_model) + related_model_cls = django_context.fields_context.get_related_model_cls(lookup_field) + lookup_field = django_context.get_primary_key_field(related_model_cls) field_get_type = django_context.fields_context.get_field_get_type(helpers.get_typechecker_api(ctx), lookup_field, method=method) diff --git a/test-data/typecheck/models/test_init.yml b/test-data/typecheck/models/test_init.yml index b42ec66..40ac583 100644 --- a/test-data/typecheck/models/test_init.yml +++ b/test-data/typecheck/models/test_init.yml @@ -233,3 +233,25 @@ name = models.CharField(primary_key=True, max_length=100) class Book(models.Model): publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE) + + +- case: init_in_abstract_model_classmethod_should_not_throw_error_for_valid_fields + main: | + from myapp.models import MyModel + MyModel.base_init() + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class AbstractModel(models.Model): + class Meta: + abstract = True + text = models.CharField(max_length=100) + @classmethod + def base_init(cls) -> 'AbstractModel': + return cls(text='mytext') + class MyModel(AbstractModel): + pass \ No newline at end of file