diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index e6b208f..7454b11 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -27,10 +27,8 @@ def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> O return all_modules.get(models_module) -def get_model_fullname_from_string(expr: StrExpr, - all_modules: Dict[str, MypyFile]) -> Optional[str]: - app_name, model_name = expr.value.split('.') - +def get_model_fullname(app_name: str, model_name: str, + all_modules: Dict[str, MypyFile]) -> Optional[str]: models_file = get_models_file(app_name, all_modules) if models_file is None: # not imported so far, not supported @@ -47,6 +45,21 @@ def get_model_fullname_from_string(expr: StrExpr, return None +class InvalidModelString(ValueError): + def __init__(self, model_string: str): + self.model_string = model_string + + +def get_model_fullname_from_string(expr: StrExpr, + all_modules: Dict[str, MypyFile]) -> Optional[str]: + model_string = expr.value + if '.' not in model_string: + raise InvalidModelString(model_string) + + app_name, model_name = model_string.split('.') + return get_model_fullname(app_name, model_name, all_modules) + + def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]: if '.' not in name: return None diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index eb58acc..98d23b1 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -4,11 +4,12 @@ from typing import Callable, Optional, cast, Dict from mypy.checker import TypeChecker from mypy.nodes import TypeInfo from mypy.options import Options -from mypy.plugin import Plugin, FunctionContext, ClassDefContext +from mypy.plugin import Plugin, FunctionContext, ClassDefContext, MethodContext from mypy.types import Type, Instance from mypy_django_plugin import helpers, monkeypatch from mypy_django_plugin.plugins.fields import determine_type_of_array_field +from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations from mypy_django_plugin.plugins.models import process_model_class from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with from mypy_django_plugin.plugins.settings import DjangoConfSettingsInitializerHook @@ -97,6 +98,12 @@ class DjangoPlugin(Plugin): manager_bases = self.get_current_manager_bases() if fullname in manager_bases: return determine_proper_manager_type + + def get_method_hook(self, fullname: str + ) -> Optional[Callable[[MethodContext], Type]]: + if fullname in {'django.apps.registry.Apps.get_model', + 'django.db.migrations.state.StateApps.get_model'}: + return determine_model_cls_from_string_for_migrations return None def get_base_class_hook(self, fullname: str diff --git a/mypy_django_plugin/plugins/migrations.py b/mypy_django_plugin/plugins/migrations.py new file mode 100644 index 0000000..26ca8c2 --- /dev/null +++ b/mypy_django_plugin/plugins/migrations.py @@ -0,0 +1,24 @@ +from typing import cast + +from mypy.checker import TypeChecker +from mypy.nodes import TypeInfo +from mypy.plugin import MethodContext +from mypy.types import Type, Instance, TypeType + +from mypy_django_plugin import helpers + + +def determine_model_cls_from_string_for_migrations(ctx: MethodContext) -> Type: + app_label = ctx.args[ctx.callee_arg_names.index('app_label')][0].value + model_name = ctx.args[ctx.callee_arg_names.index('model_name')][0].value + + api = cast(TypeChecker, ctx.api) + model_fullname = helpers.get_model_fullname(app_label, model_name, all_modules=api.modules) + + if model_fullname is None: + return ctx.default_return_type + model_info = helpers.lookup_fully_qualified_generic(model_fullname, + all_modules=api.modules) + if model_info is None or not isinstance(model_info, TypeInfo): + return ctx.default_return_type + return TypeType(Instance(model_info, [])) diff --git a/mypy_django_plugin/plugins/models.py b/mypy_django_plugin/plugins/models.py index 8afbb15..438e006 100644 --- a/mypy_django_plugin/plugins/models.py +++ b/mypy_django_plugin/plugins/models.py @@ -3,7 +3,7 @@ from abc import abstractmethod, ABCMeta from typing import cast, Iterator, Tuple, Optional, Dict from mypy.nodes import ClassDef, AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr, MDEF, TypeInfo, Var, SymbolTableNode, \ - Lvalue, Expression, MypyFile + Lvalue, Expression, MypyFile, Context from mypy.plugin import ClassDefContext from mypy.semanal import SemanticAnalyzerPass2 from mypy.types import Instance @@ -128,9 +128,15 @@ class AddRelatedManagers(ModelClassInitializer): for defn in iter_over_classdefs(module_file): for lvalue, rvalue in iter_call_assignments(defn): if is_related_field(rvalue, module_file): - ref_to_fullname = extract_ref_to_fullname(rvalue, - module_file=module_file, - all_modules=self.api.modules) + try: + ref_to_fullname = extract_ref_to_fullname(rvalue, + module_file=module_file, + all_modules=self.api.modules) + except helpers.InvalidModelString as exc: + self.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', + Context(line=rvalue.line)) + return None + if self.model_classdef.fullname == ref_to_fullname: if 'related_name' in rvalue.arg_names: related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')] diff --git a/mypy_django_plugin/plugins/related_fields.py b/mypy_django_plugin/plugins/related_fields.py index e1a3dd3..afdcf85 100644 --- a/mypy_django_plugin/plugins/related_fields.py +++ b/mypy_django_plugin/plugins/related_fields.py @@ -2,7 +2,7 @@ import typing from typing import Optional, cast from mypy.checker import TypeChecker -from mypy.nodes import StrExpr, TypeInfo +from mypy.nodes import StrExpr, TypeInfo, Context from mypy.plugin import FunctionContext from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny @@ -57,7 +57,12 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: def extract_to_parameter_as_get_ret_type_for_related_field(ctx: FunctionContext) -> Type: - referred_to_type = get_valid_to_value_or_none(ctx) + try: + referred_to_type = get_valid_to_value_or_none(ctx) + except helpers.InvalidModelString as exc: + ctx.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', ctx.context) + return fill_typevars_with_any(ctx.default_return_type) + if referred_to_type is None: # couldn't extract to= value return fill_typevars_with_any(ctx.default_return_type) diff --git a/test-data/typecheck/migrations.test b/test-data/typecheck/migrations.test new file mode 100644 index 0000000..60c6cf4 --- /dev/null +++ b/test-data/typecheck/migrations.test @@ -0,0 +1,31 @@ +[CASE registry_apps_get_model] +from django.apps.registry import Apps +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from myapp.models import User +apps = Apps() +model_cls = apps.get_model('myapp', 'User') +reveal_type(model_cls) # E: Revealed type is 'Type[myapp.models.User]' +reveal_type(model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[myapp.models.User]' + +[file myapp/__init__.py] +[file myapp/models.py] +from django.db import models +class User(models.Model): + pass + +[CASE state_apps_get_model] +from django.db.migrations.state import StateApps +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from myapp.models import User +apps = StateApps([], {}) +model_cls = apps.get_model('myapp', 'User') +reveal_type(model_cls) # E: Revealed type is 'Type[myapp.models.User]' +reveal_type(model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[myapp.models.User]' + +[file myapp/__init__.py] +[file myapp/models.py] +from django.db import models +class User(models.Model): + pass