mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-07 04:34:29 +08:00
add support for Apps.get_model for migrations
This commit is contained in:
@@ -27,10 +27,8 @@ def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> O
|
|||||||
return all_modules.get(models_module)
|
return all_modules.get(models_module)
|
||||||
|
|
||||||
|
|
||||||
def get_model_fullname_from_string(expr: StrExpr,
|
def get_model_fullname(app_name: str, model_name: str,
|
||||||
all_modules: Dict[str, MypyFile]) -> Optional[str]:
|
all_modules: Dict[str, MypyFile]) -> Optional[str]:
|
||||||
app_name, model_name = expr.value.split('.')
|
|
||||||
|
|
||||||
models_file = get_models_file(app_name, all_modules)
|
models_file = get_models_file(app_name, all_modules)
|
||||||
if models_file is None:
|
if models_file is None:
|
||||||
# not imported so far, not supported
|
# not imported so far, not supported
|
||||||
@@ -47,6 +45,21 @@ def get_model_fullname_from_string(expr: StrExpr,
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidModelString(ValueError):
|
||||||
|
def __init__(self, model_string: str):
|
||||||
|
self.model_string = model_string
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_fullname_from_string(expr: StrExpr,
|
||||||
|
all_modules: Dict[str, MypyFile]) -> Optional[str]:
|
||||||
|
model_string = expr.value
|
||||||
|
if '.' not in model_string:
|
||||||
|
raise InvalidModelString(model_string)
|
||||||
|
|
||||||
|
app_name, model_name = model_string.split('.')
|
||||||
|
return get_model_fullname(app_name, model_name, all_modules)
|
||||||
|
|
||||||
|
|
||||||
def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]:
|
def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]:
|
||||||
if '.' not in name:
|
if '.' not in name:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -4,11 +4,12 @@ from typing import Callable, Optional, cast, Dict
|
|||||||
from mypy.checker import TypeChecker
|
from mypy.checker import TypeChecker
|
||||||
from mypy.nodes import TypeInfo
|
from mypy.nodes import TypeInfo
|
||||||
from mypy.options import Options
|
from mypy.options import Options
|
||||||
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
|
from mypy.plugin import Plugin, FunctionContext, ClassDefContext, MethodContext
|
||||||
from mypy.types import Type, Instance
|
from mypy.types import Type, Instance
|
||||||
|
|
||||||
from mypy_django_plugin import helpers, monkeypatch
|
from mypy_django_plugin import helpers, monkeypatch
|
||||||
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
|
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
|
||||||
|
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
|
||||||
from mypy_django_plugin.plugins.models import process_model_class
|
from mypy_django_plugin.plugins.models import process_model_class
|
||||||
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
|
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
|
||||||
from mypy_django_plugin.plugins.settings import DjangoConfSettingsInitializerHook
|
from mypy_django_plugin.plugins.settings import DjangoConfSettingsInitializerHook
|
||||||
@@ -97,6 +98,12 @@ class DjangoPlugin(Plugin):
|
|||||||
manager_bases = self.get_current_manager_bases()
|
manager_bases = self.get_current_manager_bases()
|
||||||
if fullname in manager_bases:
|
if fullname in manager_bases:
|
||||||
return determine_proper_manager_type
|
return determine_proper_manager_type
|
||||||
|
|
||||||
|
def get_method_hook(self, fullname: str
|
||||||
|
) -> Optional[Callable[[MethodContext], Type]]:
|
||||||
|
if fullname in {'django.apps.registry.Apps.get_model',
|
||||||
|
'django.db.migrations.state.StateApps.get_model'}:
|
||||||
|
return determine_model_cls_from_string_for_migrations
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_base_class_hook(self, fullname: str
|
def get_base_class_hook(self, fullname: str
|
||||||
|
|||||||
24
mypy_django_plugin/plugins/migrations.py
Normal file
24
mypy_django_plugin/plugins/migrations.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from mypy.checker import TypeChecker
|
||||||
|
from mypy.nodes import TypeInfo
|
||||||
|
from mypy.plugin import MethodContext
|
||||||
|
from mypy.types import Type, Instance, TypeType
|
||||||
|
|
||||||
|
from mypy_django_plugin import helpers
|
||||||
|
|
||||||
|
|
||||||
|
def determine_model_cls_from_string_for_migrations(ctx: MethodContext) -> Type:
|
||||||
|
app_label = ctx.args[ctx.callee_arg_names.index('app_label')][0].value
|
||||||
|
model_name = ctx.args[ctx.callee_arg_names.index('model_name')][0].value
|
||||||
|
|
||||||
|
api = cast(TypeChecker, ctx.api)
|
||||||
|
model_fullname = helpers.get_model_fullname(app_label, model_name, all_modules=api.modules)
|
||||||
|
|
||||||
|
if model_fullname is None:
|
||||||
|
return ctx.default_return_type
|
||||||
|
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
|
||||||
|
all_modules=api.modules)
|
||||||
|
if model_info is None or not isinstance(model_info, TypeInfo):
|
||||||
|
return ctx.default_return_type
|
||||||
|
return TypeType(Instance(model_info, []))
|
||||||
@@ -3,7 +3,7 @@ from abc import abstractmethod, ABCMeta
|
|||||||
from typing import cast, Iterator, Tuple, Optional, Dict
|
from typing import cast, Iterator, Tuple, Optional, Dict
|
||||||
|
|
||||||
from mypy.nodes import ClassDef, AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr, MDEF, TypeInfo, Var, SymbolTableNode, \
|
from mypy.nodes import ClassDef, AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr, MDEF, TypeInfo, Var, SymbolTableNode, \
|
||||||
Lvalue, Expression, MypyFile
|
Lvalue, Expression, MypyFile, Context
|
||||||
from mypy.plugin import ClassDefContext
|
from mypy.plugin import ClassDefContext
|
||||||
from mypy.semanal import SemanticAnalyzerPass2
|
from mypy.semanal import SemanticAnalyzerPass2
|
||||||
from mypy.types import Instance
|
from mypy.types import Instance
|
||||||
@@ -128,9 +128,15 @@ class AddRelatedManagers(ModelClassInitializer):
|
|||||||
for defn in iter_over_classdefs(module_file):
|
for defn in iter_over_classdefs(module_file):
|
||||||
for lvalue, rvalue in iter_call_assignments(defn):
|
for lvalue, rvalue in iter_call_assignments(defn):
|
||||||
if is_related_field(rvalue, module_file):
|
if is_related_field(rvalue, module_file):
|
||||||
|
try:
|
||||||
ref_to_fullname = extract_ref_to_fullname(rvalue,
|
ref_to_fullname = extract_ref_to_fullname(rvalue,
|
||||||
module_file=module_file,
|
module_file=module_file,
|
||||||
all_modules=self.api.modules)
|
all_modules=self.api.modules)
|
||||||
|
except helpers.InvalidModelString as exc:
|
||||||
|
self.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}',
|
||||||
|
Context(line=rvalue.line))
|
||||||
|
return None
|
||||||
|
|
||||||
if self.model_classdef.fullname == ref_to_fullname:
|
if self.model_classdef.fullname == ref_to_fullname:
|
||||||
if 'related_name' in rvalue.arg_names:
|
if 'related_name' in rvalue.arg_names:
|
||||||
related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')]
|
related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import typing
|
|||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
from mypy.checker import TypeChecker
|
from mypy.checker import TypeChecker
|
||||||
from mypy.nodes import StrExpr, TypeInfo
|
from mypy.nodes import StrExpr, TypeInfo, Context
|
||||||
from mypy.plugin import FunctionContext
|
from mypy.plugin import FunctionContext
|
||||||
from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny
|
from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny
|
||||||
|
|
||||||
@@ -57,7 +57,12 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
|
|||||||
|
|
||||||
|
|
||||||
def extract_to_parameter_as_get_ret_type_for_related_field(ctx: FunctionContext) -> Type:
|
def extract_to_parameter_as_get_ret_type_for_related_field(ctx: FunctionContext) -> Type:
|
||||||
|
try:
|
||||||
referred_to_type = get_valid_to_value_or_none(ctx)
|
referred_to_type = get_valid_to_value_or_none(ctx)
|
||||||
|
except helpers.InvalidModelString as exc:
|
||||||
|
ctx.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', ctx.context)
|
||||||
|
return fill_typevars_with_any(ctx.default_return_type)
|
||||||
|
|
||||||
if referred_to_type is None:
|
if referred_to_type is None:
|
||||||
# couldn't extract to= value
|
# couldn't extract to= value
|
||||||
return fill_typevars_with_any(ctx.default_return_type)
|
return fill_typevars_with_any(ctx.default_return_type)
|
||||||
|
|||||||
31
test-data/typecheck/migrations.test
Normal file
31
test-data/typecheck/migrations.test
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
[CASE registry_apps_get_model]
|
||||||
|
from django.apps.registry import Apps
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from myapp.models import User
|
||||||
|
apps = Apps()
|
||||||
|
model_cls = apps.get_model('myapp', 'User')
|
||||||
|
reveal_type(model_cls) # E: Revealed type is 'Type[myapp.models.User]'
|
||||||
|
reveal_type(model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[myapp.models.User]'
|
||||||
|
|
||||||
|
[file myapp/__init__.py]
|
||||||
|
[file myapp/models.py]
|
||||||
|
from django.db import models
|
||||||
|
class User(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
[CASE state_apps_get_model]
|
||||||
|
from django.db.migrations.state import StateApps
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from myapp.models import User
|
||||||
|
apps = StateApps([], {})
|
||||||
|
model_cls = apps.get_model('myapp', 'User')
|
||||||
|
reveal_type(model_cls) # E: Revealed type is 'Type[myapp.models.User]'
|
||||||
|
reveal_type(model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[myapp.models.User]'
|
||||||
|
|
||||||
|
[file myapp/__init__.py]
|
||||||
|
[file myapp/models.py]
|
||||||
|
from django.db import models
|
||||||
|
class User(models.Model):
|
||||||
|
pass
|
||||||
Reference in New Issue
Block a user