diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index b7e26e5..8942cf4 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -1,11 +1,12 @@ import os -from typing import Callable, Optional, Set, Union, cast, Dict +from typing import Callable, Dict, Optional, Union, cast from mypy.checker import TypeChecker from mypy.nodes import MemberExpr, TypeInfo from mypy.options import Options from mypy.plugin import AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType, UnionType + from mypy_django_plugin import helpers, monkeypatch from mypy_django_plugin.config import Config from mypy_django_plugin.transformers import fields, init_create @@ -211,10 +212,6 @@ class DjangoPlugin(Plugin): def get_function_hook(self, fullname: str ) -> Optional[Callable[[FunctionContext], Type]]: - sym = self.lookup_fully_qualified(fullname) - if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FIELD_FULLNAME): - return fields.adjust_return_type_of_field_instantiation - if fullname == 'django.contrib.auth.get_user_model': return return_user_model_hook @@ -223,24 +220,34 @@ class DjangoPlugin(Plugin): return determine_proper_manager_type sym = self.lookup_fully_qualified(fullname) - if sym and isinstance(sym.node, TypeInfo): + if sym is not None and isinstance(sym.node, TypeInfo): + if sym.node.has_base(helpers.FIELD_FULLNAME): + return fields.adjust_return_type_of_field_instantiation + if sym.node.metadata.get('django', {}).get('generated_init'): return init_create.redefine_and_typecheck_model_init def get_method_hook(self, fullname: str ) -> Optional[Callable[[MethodContext], Type]]: + if fullname in {'django.apps.registry.Apps.get_model', + 'django.db.migrations.state.StateApps.get_model'}: + return determine_model_cls_from_string_for_migrations + manager_classes = self._get_current_manager_bases() class_fullname, _, method_name = fullname.rpartition('.') if class_fullname in manager_classes and method_name == 'create': return init_create.redefine_and_typecheck_model_create - - if fullname in {'django.apps.registry.Apps.get_model', - 'django.db.migrations.state.StateApps.get_model'}: - return determine_model_cls_from_string_for_migrations return None def get_base_class_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: + if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS: + settings_modules = ['django.conf.global_settings'] + if self.django_settings_module: + settings_modules.append(self.django_settings_module) + return AddSettingValuesToDjangoConfObject(settings_modules, + self.config.ignore_missing_settings) + if fullname in self._get_current_model_bases(): return transform_model_class @@ -250,17 +257,13 @@ class DjangoPlugin(Plugin): if fullname in self._get_current_modelform_bases(): return transform_modelform_class - if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS: - settings_modules = ['django.conf.global_settings'] - if self.django_settings_module: - settings_modules.append(self.django_settings_module) - return AddSettingValuesToDjangoConfObject(settings_modules, - self.config.ignore_missing_settings) - return None def get_attribute_hook(self, fullname: str ) -> Optional[Callable[[AttributeContext], Type]]: + if fullname == 'builtins.object.id': + return return_integer_type_for_id_for_non_defined_primary_key_in_models + module, _, name = fullname.rpartition('.') sym = self.lookup_fully_qualified('django.conf.LazySettings') if sym and isinstance(sym.node, TypeInfo): @@ -268,9 +271,6 @@ class DjangoPlugin(Plugin): if module == 'builtins.object' and name in metadata: return ExtractSettingType(module_fullname=metadata[name]) - if fullname == 'builtins.object.id': - return return_integer_type_for_id_for_non_defined_primary_key_in_models - return extract_and_return_primary_key_of_bound_related_field_parameter