mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-08 21:14:49 +08:00
add support for get_user_model(), fixes #16
This commit is contained in:
@@ -1,7 +1,9 @@
|
|||||||
import typing
|
import typing
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from mypy.nodes import Expression, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo
|
from mypy.checker import TypeChecker
|
||||||
|
from mypy.nodes import AssignmentStmt, Expression, ImportedName, Lvalue, MypyFile, NameExpr, Statement, SymbolNode, TypeInfo, \
|
||||||
|
ClassDef
|
||||||
from mypy.plugin import FunctionContext
|
from mypy.plugin import FunctionContext
|
||||||
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType
|
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType
|
||||||
|
|
||||||
@@ -146,3 +148,40 @@ def get_argument_type_by_name(ctx: FunctionContext, name: str) -> Optional[Type]
|
|||||||
# Either an error or no value passed.
|
# Either an error or no value passed.
|
||||||
return None
|
return None
|
||||||
return arg_types[0]
|
return arg_types[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_setting_expr(api: TypeChecker, setting_name: str) -> Optional[Expression]:
|
||||||
|
try:
|
||||||
|
settings_sym = api.modules['django.conf'].names['settings']
|
||||||
|
except KeyError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
settings_type: TypeInfo = settings_sym.type.type
|
||||||
|
auth_user_model_sym = settings_type.get(setting_name)
|
||||||
|
if not auth_user_model_sym:
|
||||||
|
return None
|
||||||
|
|
||||||
|
module, _, name = auth_user_model_sym.fullname.rpartition('.')
|
||||||
|
if module not in api.modules:
|
||||||
|
return None
|
||||||
|
|
||||||
|
module_file = api.modules.get(module)
|
||||||
|
for name_expr, value_expr in iter_over_assignments(module_file):
|
||||||
|
if isinstance(name_expr, NameExpr) and name_expr.name == setting_name:
|
||||||
|
return value_expr
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
|
||||||
|
if isinstance(class_or_module, ClassDef):
|
||||||
|
statements = class_or_module.defs.body
|
||||||
|
else:
|
||||||
|
statements = class_or_module.defs
|
||||||
|
|
||||||
|
for stmt in statements:
|
||||||
|
if not isinstance(stmt, AssignmentStmt):
|
||||||
|
continue
|
||||||
|
if len(stmt.lvalues) > 1:
|
||||||
|
# not supported yet
|
||||||
|
continue
|
||||||
|
yield stmt.lvalues[0], stmt.rvalue
|
||||||
|
|||||||
@@ -5,13 +5,13 @@ from mypy.checker import TypeChecker
|
|||||||
from mypy.nodes import TypeInfo
|
from mypy.nodes import TypeInfo
|
||||||
from mypy.options import Options
|
from mypy.options import Options
|
||||||
from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
|
from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
|
||||||
from mypy.types import Instance, Type
|
from mypy.types import Instance, Type, TypeType
|
||||||
|
|
||||||
from mypy_django_plugin import helpers, monkeypatch
|
from mypy_django_plugin import helpers, monkeypatch
|
||||||
from mypy_django_plugin.config import Config
|
from mypy_django_plugin.config import Config
|
||||||
from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class
|
from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class
|
||||||
from mypy_django_plugin.plugins.init_create import redefine_and_typecheck_model_init, redefine_and_typecheck_model_create
|
from mypy_django_plugin.plugins.init_create import redefine_and_typecheck_model_init, redefine_and_typecheck_model_create
|
||||||
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
|
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations, get_string_value_from_expr
|
||||||
from mypy_django_plugin.plugins.models import process_model_class
|
from mypy_django_plugin.plugins.models import process_model_class
|
||||||
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
|
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
|
||||||
from mypy_django_plugin.plugins.settings import AddSettingValuesToDjangoConfObject
|
from mypy_django_plugin.plugins.settings import AddSettingValuesToDjangoConfObject
|
||||||
@@ -56,6 +56,32 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def return_user_model_hook(ctx: FunctionContext) -> Type:
|
||||||
|
api = cast(TypeChecker, ctx.api)
|
||||||
|
setting_expr = helpers.get_setting_expr(api, 'AUTH_USER_MODEL')
|
||||||
|
if setting_expr is None:
|
||||||
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
app_label, _, model_class_name = get_string_value_from_expr(setting_expr).rpartition('.')
|
||||||
|
if app_label is None:
|
||||||
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
model_fullname = helpers.get_model_fullname(app_label, model_class_name,
|
||||||
|
all_modules=api.modules)
|
||||||
|
if model_fullname is None:
|
||||||
|
api.fail(f'"{app_label}.{model_class_name}" model class is not imported so far. Try to import it '
|
||||||
|
f'(under if TYPE_CHECKING) at the beginning of the current file',
|
||||||
|
context=ctx.context)
|
||||||
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
|
||||||
|
all_modules=api.modules)
|
||||||
|
if model_info is None or not isinstance(model_info, TypeInfo):
|
||||||
|
return ctx.default_return_type
|
||||||
|
return TypeType(Instance(model_info, []))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DjangoPlugin(Plugin):
|
class DjangoPlugin(Plugin):
|
||||||
def __init__(self, options: Options) -> None:
|
def __init__(self, options: Options) -> None:
|
||||||
super().__init__(options)
|
super().__init__(options)
|
||||||
@@ -105,6 +131,9 @@ class DjangoPlugin(Plugin):
|
|||||||
|
|
||||||
def get_function_hook(self, fullname: str
|
def get_function_hook(self, fullname: str
|
||||||
) -> Optional[Callable[[FunctionContext], Type]]:
|
) -> Optional[Callable[[FunctionContext], Type]]:
|
||||||
|
if fullname == 'django.contrib.auth.get_user_model':
|
||||||
|
return return_user_model_hook
|
||||||
|
|
||||||
if fullname in {helpers.FOREIGN_KEY_FULLNAME,
|
if fullname in {helpers.FOREIGN_KEY_FULLNAME,
|
||||||
helpers.ONETOONE_FIELD_FULLNAME,
|
helpers.ONETOONE_FIELD_FULLNAME,
|
||||||
helpers.MANYTOMANY_FIELD_FULLNAME}:
|
helpers.MANYTOMANY_FIELD_FULLNAME}:
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from mypy.semanal import SemanticAnalyzerPass2
|
|||||||
from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny
|
from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny
|
||||||
|
|
||||||
from mypy_django_plugin import helpers
|
from mypy_django_plugin import helpers
|
||||||
|
from mypy_django_plugin.helpers import iter_over_assignments
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -55,16 +56,6 @@ class ModelClassInitializer(metaclass=ABCMeta):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def iter_over_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, Expression]]:
|
|
||||||
for stmt in klass.defs.body:
|
|
||||||
if not isinstance(stmt, AssignmentStmt):
|
|
||||||
continue
|
|
||||||
if len(stmt.lvalues) > 1:
|
|
||||||
# not supported yet
|
|
||||||
continue
|
|
||||||
yield stmt.lvalues[0], stmt.rvalue
|
|
||||||
|
|
||||||
|
|
||||||
def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]:
|
def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]:
|
||||||
for lvalue, rvalue in iter_over_assignments(klass):
|
for lvalue, rvalue in iter_over_assignments(klass):
|
||||||
if isinstance(rvalue, CallExpr):
|
if isinstance(rvalue, CallExpr):
|
||||||
|
|||||||
@@ -137,18 +137,3 @@ class AbstractBase2(models.Model):
|
|||||||
class Child(AbstractBase1, AbstractBase2):
|
class Child(AbstractBase1, AbstractBase2):
|
||||||
pass
|
pass
|
||||||
[out]
|
[out]
|
||||||
|
|
||||||
[CASE get_object_or_404_returns_proper_types]
|
|
||||||
from django.shortcuts import get_object_or_404, get_list_or_404
|
|
||||||
from django.db import models
|
|
||||||
|
|
||||||
class MyModel(models.Model):
|
|
||||||
pass
|
|
||||||
reveal_type(get_object_or_404(MyModel)) # E: Revealed type is 'main.MyModel*'
|
|
||||||
reveal_type(get_object_or_404(MyModel.objects)) # E: Revealed type is 'main.MyModel*'
|
|
||||||
reveal_type(get_object_or_404(MyModel.objects.get_queryset())) # E: Revealed type is 'main.MyModel*'
|
|
||||||
|
|
||||||
reveal_type(get_list_or_404(MyModel)) # E: Revealed type is 'builtins.list[main.MyModel*]'
|
|
||||||
reveal_type(get_list_or_404(MyModel.objects)) # E: Revealed type is 'builtins.list[main.MyModel*]'
|
|
||||||
reveal_type(get_list_or_404(MyModel.objects.get_queryset())) # E: Revealed type is 'builtins.list[main.MyModel*]'
|
|
||||||
[out]
|
|
||||||
56
test-data/typecheck/shortcuts.test
Normal file
56
test-data/typecheck/shortcuts.test
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
[CASE get_object_or_404_returns_proper_types]
|
||||||
|
from django.shortcuts import get_object_or_404, get_list_or_404
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class MyModel(models.Model):
|
||||||
|
pass
|
||||||
|
reveal_type(get_object_or_404(MyModel)) # E: Revealed type is 'main.MyModel*'
|
||||||
|
reveal_type(get_object_or_404(MyModel.objects)) # E: Revealed type is 'main.MyModel*'
|
||||||
|
reveal_type(get_object_or_404(MyModel.objects.get_queryset())) # E: Revealed type is 'main.MyModel*'
|
||||||
|
|
||||||
|
reveal_type(get_list_or_404(MyModel)) # E: Revealed type is 'builtins.list[main.MyModel*]'
|
||||||
|
reveal_type(get_list_or_404(MyModel.objects)) # E: Revealed type is 'builtins.list[main.MyModel*]'
|
||||||
|
reveal_type(get_list_or_404(MyModel.objects.get_queryset())) # E: Revealed type is 'builtins.list[main.MyModel*]'
|
||||||
|
[out]
|
||||||
|
|
||||||
|
[CASE get_user_model_returns_proper_class]
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from myapp.models import MyUser
|
||||||
|
from django.contrib.auth import get_user_model
|
||||||
|
|
||||||
|
UserModel = get_user_model()
|
||||||
|
reveal_type(UserModel.objects) # E: Revealed type is 'django.db.models.manager.Manager[myapp.models.MyUser]'
|
||||||
|
|
||||||
|
[env DJANGO_SETTINGS_MODULE=mysettings]
|
||||||
|
[file mysettings.py]
|
||||||
|
INSTALLED_APPS = ('myapp',)
|
||||||
|
AUTH_USER_MODEL = 'myapp.MyUser'
|
||||||
|
|
||||||
|
[file myapp/__init__.py]
|
||||||
|
[file myapp/models.py]
|
||||||
|
from django.db import models
|
||||||
|
class MyUser(models.Model):
|
||||||
|
pass
|
||||||
|
[out]
|
||||||
|
|
||||||
|
[CASE return_type_model_and_show_error_if_model_not_yet_imported]
|
||||||
|
from django.contrib.auth import get_user_model
|
||||||
|
|
||||||
|
UserModel = get_user_model()
|
||||||
|
reveal_type(UserModel.objects)
|
||||||
|
|
||||||
|
[env DJANGO_SETTINGS_MODULE=mysettings]
|
||||||
|
[file mysettings.py]
|
||||||
|
INSTALLED_APPS = ('myapp',)
|
||||||
|
AUTH_USER_MODEL = 'myapp.MyUser'
|
||||||
|
|
||||||
|
[file myapp/__init__.py]
|
||||||
|
[file myapp/models.py]
|
||||||
|
from django.db import models
|
||||||
|
class MyUser(models.Model):
|
||||||
|
pass
|
||||||
|
[out]
|
||||||
|
main:3: error: "myapp.MyUser" model class is not imported so far. Try to import it (under if TYPE_CHECKING) at the beginning of the current file
|
||||||
|
main:4: error: Revealed type is 'Any'
|
||||||
|
main:4: error: "Type[Model]" has no attribute "objects"
|
||||||
Reference in New Issue
Block a user