add support for get_user_model(), fixes #16

This commit is contained in:
Maxim Kurnikov
2019-02-13 15:56:21 +03:00
parent 2720b74242
commit b7f7713c5a
5 changed files with 128 additions and 28 deletions

View File

@@ -1,7 +1,9 @@
import typing
from typing import Dict, Optional
from mypy.nodes import Expression, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo
from mypy.checker import TypeChecker
from mypy.nodes import AssignmentStmt, Expression, ImportedName, Lvalue, MypyFile, NameExpr, Statement, SymbolNode, TypeInfo, \
ClassDef
from mypy.plugin import FunctionContext
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType
@@ -146,3 +148,40 @@ def get_argument_type_by_name(ctx: FunctionContext, name: str) -> Optional[Type]
# Either an error or no value passed.
return None
return arg_types[0]
def get_setting_expr(api: TypeChecker, setting_name: str) -> Optional[Expression]:
try:
settings_sym = api.modules['django.conf'].names['settings']
except KeyError:
return None
settings_type: TypeInfo = settings_sym.type.type
auth_user_model_sym = settings_type.get(setting_name)
if not auth_user_model_sym:
return None
module, _, name = auth_user_model_sym.fullname.rpartition('.')
if module not in api.modules:
return None
module_file = api.modules.get(module)
for name_expr, value_expr in iter_over_assignments(module_file):
if isinstance(name_expr, NameExpr) and name_expr.name == setting_name:
return value_expr
return None
def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
if isinstance(class_or_module, ClassDef):
statements = class_or_module.defs.body
else:
statements = class_or_module.defs
for stmt in statements:
if not isinstance(stmt, AssignmentStmt):
continue
if len(stmt.lvalues) > 1:
# not supported yet
continue
yield stmt.lvalues[0], stmt.rvalue

View File

@@ -5,13 +5,13 @@ from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo
from mypy.options import Options
from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import Instance, Type
from mypy.types import Instance, Type, TypeType
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.config import Config
from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class
from mypy_django_plugin.plugins.init_create import redefine_and_typecheck_model_init, redefine_and_typecheck_model_create
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations, get_string_value_from_expr
from mypy_django_plugin.plugins.models import process_model_class
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
from mypy_django_plugin.plugins.settings import AddSettingValuesToDjangoConfObject
@@ -56,6 +56,32 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
return ret
def return_user_model_hook(ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
setting_expr = helpers.get_setting_expr(api, 'AUTH_USER_MODEL')
if setting_expr is None:
return ctx.default_return_type
app_label, _, model_class_name = get_string_value_from_expr(setting_expr).rpartition('.')
if app_label is None:
return ctx.default_return_type
model_fullname = helpers.get_model_fullname(app_label, model_class_name,
all_modules=api.modules)
if model_fullname is None:
api.fail(f'"{app_label}.{model_class_name}" model class is not imported so far. Try to import it '
f'(under if TYPE_CHECKING) at the beginning of the current file',
context=ctx.context)
return ctx.default_return_type
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
all_modules=api.modules)
if model_info is None or not isinstance(model_info, TypeInfo):
return ctx.default_return_type
return TypeType(Instance(model_info, []))
class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None:
super().__init__(options)
@@ -105,6 +131,9 @@ class DjangoPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == 'django.contrib.auth.get_user_model':
return return_user_model_hook
if fullname in {helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME,
helpers.MANYTOMANY_FIELD_FULLNAME}:

View File

@@ -10,6 +10,7 @@ from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny
from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import iter_over_assignments
@dataclasses.dataclass
@@ -55,16 +56,6 @@ class ModelClassInitializer(metaclass=ABCMeta):
raise NotImplementedError()
def iter_over_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, Expression]]:
for stmt in klass.defs.body:
if not isinstance(stmt, AssignmentStmt):
continue
if len(stmt.lvalues) > 1:
# not supported yet
continue
yield stmt.lvalues[0], stmt.rvalue
def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]:
for lvalue, rvalue in iter_over_assignments(klass):
if isinstance(rvalue, CallExpr):