mirror of
https://github.com/davidhalter/django-stubs.git
synced 2026-05-24 17:28:41 +08:00
add support for get_user_model(), fixes #16
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
import typing
|
||||
from typing import Dict, Optional
|
||||
|
||||
from mypy.nodes import Expression, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import AssignmentStmt, Expression, ImportedName, Lvalue, MypyFile, NameExpr, Statement, SymbolNode, TypeInfo, \
|
||||
ClassDef
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType
|
||||
|
||||
@@ -146,3 +148,40 @@ def get_argument_type_by_name(ctx: FunctionContext, name: str) -> Optional[Type]
|
||||
# Either an error or no value passed.
|
||||
return None
|
||||
return arg_types[0]
|
||||
|
||||
|
||||
def get_setting_expr(api: TypeChecker, setting_name: str) -> Optional[Expression]:
|
||||
try:
|
||||
settings_sym = api.modules['django.conf'].names['settings']
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
settings_type: TypeInfo = settings_sym.type.type
|
||||
auth_user_model_sym = settings_type.get(setting_name)
|
||||
if not auth_user_model_sym:
|
||||
return None
|
||||
|
||||
module, _, name = auth_user_model_sym.fullname.rpartition('.')
|
||||
if module not in api.modules:
|
||||
return None
|
||||
|
||||
module_file = api.modules.get(module)
|
||||
for name_expr, value_expr in iter_over_assignments(module_file):
|
||||
if isinstance(name_expr, NameExpr) and name_expr.name == setting_name:
|
||||
return value_expr
|
||||
return None
|
||||
|
||||
|
||||
def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
|
||||
if isinstance(class_or_module, ClassDef):
|
||||
statements = class_or_module.defs.body
|
||||
else:
|
||||
statements = class_or_module.defs
|
||||
|
||||
for stmt in statements:
|
||||
if not isinstance(stmt, AssignmentStmt):
|
||||
continue
|
||||
if len(stmt.lvalues) > 1:
|
||||
# not supported yet
|
||||
continue
|
||||
yield stmt.lvalues[0], stmt.rvalue
|
||||
|
||||
Reference in New Issue
Block a user