mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-06 20:24:31 +08:00
add support for Apps.get_model for migrations
This commit is contained in:
@@ -27,10 +27,8 @@ def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> O
|
||||
return all_modules.get(models_module)
|
||||
|
||||
|
||||
def get_model_fullname_from_string(expr: StrExpr,
|
||||
def get_model_fullname(app_name: str, model_name: str,
|
||||
all_modules: Dict[str, MypyFile]) -> Optional[str]:
|
||||
app_name, model_name = expr.value.split('.')
|
||||
|
||||
models_file = get_models_file(app_name, all_modules)
|
||||
if models_file is None:
|
||||
# not imported so far, not supported
|
||||
@@ -47,6 +45,21 @@ def get_model_fullname_from_string(expr: StrExpr,
|
||||
return None
|
||||
|
||||
|
||||
class InvalidModelString(ValueError):
|
||||
def __init__(self, model_string: str):
|
||||
self.model_string = model_string
|
||||
|
||||
|
||||
def get_model_fullname_from_string(expr: StrExpr,
|
||||
all_modules: Dict[str, MypyFile]) -> Optional[str]:
|
||||
model_string = expr.value
|
||||
if '.' not in model_string:
|
||||
raise InvalidModelString(model_string)
|
||||
|
||||
app_name, model_name = model_string.split('.')
|
||||
return get_model_fullname(app_name, model_name, all_modules)
|
||||
|
||||
|
||||
def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]:
|
||||
if '.' not in name:
|
||||
return None
|
||||
|
||||
@@ -4,11 +4,12 @@ from typing import Callable, Optional, cast, Dict
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.options import Options
|
||||
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
|
||||
from mypy.plugin import Plugin, FunctionContext, ClassDefContext, MethodContext
|
||||
from mypy.types import Type, Instance
|
||||
|
||||
from mypy_django_plugin import helpers, monkeypatch
|
||||
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
|
||||
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
|
||||
from mypy_django_plugin.plugins.models import process_model_class
|
||||
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
|
||||
from mypy_django_plugin.plugins.settings import DjangoConfSettingsInitializerHook
|
||||
@@ -97,6 +98,12 @@ class DjangoPlugin(Plugin):
|
||||
manager_bases = self.get_current_manager_bases()
|
||||
if fullname in manager_bases:
|
||||
return determine_proper_manager_type
|
||||
|
||||
def get_method_hook(self, fullname: str
|
||||
) -> Optional[Callable[[MethodContext], Type]]:
|
||||
if fullname in {'django.apps.registry.Apps.get_model',
|
||||
'django.db.migrations.state.StateApps.get_model'}:
|
||||
return determine_model_cls_from_string_for_migrations
|
||||
return None
|
||||
|
||||
def get_base_class_hook(self, fullname: str
|
||||
|
||||
24
mypy_django_plugin/plugins/migrations.py
Normal file
24
mypy_django_plugin/plugins/migrations.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import cast
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.plugin import MethodContext
|
||||
from mypy.types import Type, Instance, TypeType
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def determine_model_cls_from_string_for_migrations(ctx: MethodContext) -> Type:
|
||||
app_label = ctx.args[ctx.callee_arg_names.index('app_label')][0].value
|
||||
model_name = ctx.args[ctx.callee_arg_names.index('model_name')][0].value
|
||||
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
model_fullname = helpers.get_model_fullname(app_label, model_name, all_modules=api.modules)
|
||||
|
||||
if model_fullname is None:
|
||||
return ctx.default_return_type
|
||||
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
|
||||
all_modules=api.modules)
|
||||
if model_info is None or not isinstance(model_info, TypeInfo):
|
||||
return ctx.default_return_type
|
||||
return TypeType(Instance(model_info, []))
|
||||
@@ -3,7 +3,7 @@ from abc import abstractmethod, ABCMeta
|
||||
from typing import cast, Iterator, Tuple, Optional, Dict
|
||||
|
||||
from mypy.nodes import ClassDef, AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr, MDEF, TypeInfo, Var, SymbolTableNode, \
|
||||
Lvalue, Expression, MypyFile
|
||||
Lvalue, Expression, MypyFile, Context
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.semanal import SemanticAnalyzerPass2
|
||||
from mypy.types import Instance
|
||||
@@ -128,9 +128,15 @@ class AddRelatedManagers(ModelClassInitializer):
|
||||
for defn in iter_over_classdefs(module_file):
|
||||
for lvalue, rvalue in iter_call_assignments(defn):
|
||||
if is_related_field(rvalue, module_file):
|
||||
try:
|
||||
ref_to_fullname = extract_ref_to_fullname(rvalue,
|
||||
module_file=module_file,
|
||||
all_modules=self.api.modules)
|
||||
except helpers.InvalidModelString as exc:
|
||||
self.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}',
|
||||
Context(line=rvalue.line))
|
||||
return None
|
||||
|
||||
if self.model_classdef.fullname == ref_to_fullname:
|
||||
if 'related_name' in rvalue.arg_names:
|
||||
related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')]
|
||||
|
||||
@@ -2,7 +2,7 @@ import typing
|
||||
from typing import Optional, cast
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import StrExpr, TypeInfo
|
||||
from mypy.nodes import StrExpr, TypeInfo, Context
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny
|
||||
|
||||
@@ -57,7 +57,12 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
|
||||
|
||||
|
||||
def extract_to_parameter_as_get_ret_type_for_related_field(ctx: FunctionContext) -> Type:
|
||||
try:
|
||||
referred_to_type = get_valid_to_value_or_none(ctx)
|
||||
except helpers.InvalidModelString as exc:
|
||||
ctx.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', ctx.context)
|
||||
return fill_typevars_with_any(ctx.default_return_type)
|
||||
|
||||
if referred_to_type is None:
|
||||
# couldn't extract to= value
|
||||
return fill_typevars_with_any(ctx.default_return_type)
|
||||
|
||||
31
test-data/typecheck/migrations.test
Normal file
31
test-data/typecheck/migrations.test
Normal file
@@ -0,0 +1,31 @@
|
||||
[CASE registry_apps_get_model]
|
||||
from django.apps.registry import Apps
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from myapp.models import User
|
||||
apps = Apps()
|
||||
model_cls = apps.get_model('myapp', 'User')
|
||||
reveal_type(model_cls) # E: Revealed type is 'Type[myapp.models.User]'
|
||||
reveal_type(model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[myapp.models.User]'
|
||||
|
||||
[file myapp/__init__.py]
|
||||
[file myapp/models.py]
|
||||
from django.db import models
|
||||
class User(models.Model):
|
||||
pass
|
||||
|
||||
[CASE state_apps_get_model]
|
||||
from django.db.migrations.state import StateApps
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from myapp.models import User
|
||||
apps = StateApps([], {})
|
||||
model_cls = apps.get_model('myapp', 'User')
|
||||
reveal_type(model_cls) # E: Revealed type is 'Type[myapp.models.User]'
|
||||
reveal_type(model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[myapp.models.User]'
|
||||
|
||||
[file myapp/__init__.py]
|
||||
[file myapp/models.py]
|
||||
from django.db import models
|
||||
class User(models.Model):
|
||||
pass
|
||||
Reference in New Issue
Block a user