mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-06 20:24:31 +08:00
218 lines
8.1 KiB
Python
218 lines
8.1 KiB
Python
import typing
|
|
from typing import Dict, Optional
|
|
|
|
from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var, AssignmentStmt, \
|
|
CallExpr
|
|
from mypy.plugin import FunctionContext
|
|
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType
|
|
|
|
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
|
|
FIELD_FULLNAME = 'django.db.models.fields.Field'
|
|
GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey'
|
|
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
|
|
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
|
|
MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField'
|
|
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
|
|
|
|
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
|
|
BASE_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.BaseManager'
|
|
MANAGER_CLASS_FULLNAME = 'django.db.models.manager.Manager'
|
|
RELATED_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.RelatedManager'
|
|
|
|
MANAGER_CLASSES = {
|
|
MANAGER_CLASS_FULLNAME,
|
|
RELATED_MANAGER_CLASS_FULLNAME,
|
|
BASE_MANAGER_CLASS_FULLNAME,
|
|
QUERYSET_CLASS_FULLNAME
|
|
}
|
|
|
|
|
|
def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]:
|
|
models_module = '.'.join([app_name, 'models'])
|
|
return all_modules.get(models_module)
|
|
|
|
|
|
def get_model_fullname(app_name: str, model_name: str,
|
|
all_modules: Dict[str, MypyFile]) -> Optional[str]:
|
|
models_file = get_models_file(app_name, all_modules)
|
|
if models_file is None:
|
|
# not imported so far, not supported
|
|
return None
|
|
sym = models_file.names.get(model_name)
|
|
if not sym:
|
|
return None
|
|
|
|
if isinstance(sym.node, TypeInfo):
|
|
return sym.node.fullname()
|
|
elif isinstance(sym.node, ImportedName):
|
|
return sym.node.target_fullname
|
|
else:
|
|
return None
|
|
|
|
|
|
class InvalidModelString(ValueError):
|
|
def __init__(self, model_string: str):
|
|
self.model_string = model_string
|
|
|
|
|
|
class SelfReference(ValueError):
|
|
pass
|
|
|
|
|
|
def get_model_fullname_from_string(model_string: str,
|
|
all_modules: Dict[str, MypyFile]) -> Optional[str]:
|
|
if model_string == 'self':
|
|
raise SelfReference()
|
|
|
|
if '.' not in model_string:
|
|
raise InvalidModelString(model_string)
|
|
|
|
app_name, model_name = model_string.split('.')
|
|
return get_model_fullname(app_name, model_name, all_modules)
|
|
|
|
|
|
def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]:
|
|
if '.' not in name:
|
|
return None
|
|
module, cls_name = name.rsplit('.', 1)
|
|
|
|
module_file = all_modules.get(module)
|
|
if module_file is None:
|
|
return None
|
|
sym = module_file.names.get(cls_name)
|
|
if sym is None:
|
|
return None
|
|
return sym.node
|
|
|
|
|
|
def parse_bool(expr: Expression) -> Optional[bool]:
|
|
if isinstance(expr, NameExpr):
|
|
if expr.fullname == 'builtins.True':
|
|
return True
|
|
if expr.fullname == 'builtins.False':
|
|
return False
|
|
return None
|
|
|
|
|
|
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
|
|
return Instance(instance.type, args=new_typevars)
|
|
|
|
|
|
def fill_typevars_with_any(instance: Instance) -> Type:
|
|
return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)])
|
|
|
|
|
|
def extract_typevar_value(tp: Instance, typevar_name: str) -> Type:
|
|
if typevar_name in {'_T', '_T_co'}:
|
|
if '_T' in tp.type.type_vars:
|
|
return tp.args[tp.type.type_vars.index('_T')]
|
|
if '_T_co' in tp.type.type_vars:
|
|
return tp.args[tp.type.type_vars.index('_T_co')]
|
|
return tp.args[tp.type.type_vars.index(typevar_name)]
|
|
|
|
|
|
def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance:
|
|
typevar_values: typing.List[Type] = []
|
|
for typevar_arg in type_to_fill.args:
|
|
if isinstance(typevar_arg, TypeVarType):
|
|
typevar_values.append(extract_typevar_value(tp, typevar_arg.name))
|
|
return reparametrize_with(type_to_fill, typevar_values)
|
|
|
|
|
|
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
|
|
if tp.type.has_base(FIELD_FULLNAME):
|
|
set_method = tp.type.get_method('__set__')
|
|
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
|
|
if 'value' in set_method.type.arg_names:
|
|
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
|
|
if isinstance(set_value_type, Instance):
|
|
set_value_type = fill_typevars(tp, set_value_type)
|
|
return set_value_type
|
|
elif isinstance(set_value_type, UnionType):
|
|
items_no_typevars = []
|
|
for item in set_value_type.items:
|
|
if isinstance(item, Instance):
|
|
item = fill_typevars(tp, item)
|
|
items_no_typevars.append(item)
|
|
return UnionType(items_no_typevars)
|
|
|
|
get_method = tp.type.get_method('__get__')
|
|
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
|
|
return get_method.type.ret_type
|
|
# GenericForeignKey
|
|
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
|
|
return AnyType(TypeOfAny.special_form)
|
|
return None
|
|
|
|
|
|
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
|
|
# only primary keys defined in current class for now
|
|
for stmt in model.defn.defs.body:
|
|
if isinstance(stmt, AssignmentStmt) and isinstance(stmt.rvalue, CallExpr):
|
|
name_expr = stmt.lvalues[0]
|
|
if isinstance(name_expr, NameExpr):
|
|
name = name_expr.name
|
|
if 'primary_key' in stmt.rvalue.arg_names:
|
|
is_primary_key = stmt.rvalue.args[stmt.rvalue.arg_names.index('primary_key')]
|
|
if is_primary_key:
|
|
return extract_field_setter_type(model.names[name].type)
|
|
return None
|
|
|
|
|
|
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
|
|
expected_types: Dict[str, Type] = {}
|
|
|
|
primary_key_type = extract_primary_key_type(model)
|
|
if not primary_key_type:
|
|
# no explicit primary key, set pk to Any and add id
|
|
primary_key_type = AnyType(TypeOfAny.special_form)
|
|
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
|
|
|
|
expected_types['pk'] = primary_key_type
|
|
|
|
for base in model.mro:
|
|
for name, sym in base.names.items():
|
|
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
|
|
tp = sym.node.type
|
|
field_type = extract_field_setter_type(tp)
|
|
if tp.type.fullname() in {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME}:
|
|
ref_to_model = tp.args[0]
|
|
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME):
|
|
primary_key_type = extract_primary_key_type(ref_to_model.type)
|
|
if not primary_key_type:
|
|
primary_key_type = AnyType(TypeOfAny.special_form)
|
|
expected_types[name + '_id'] = primary_key_type
|
|
if field_type:
|
|
expected_types[name] = field_type
|
|
return expected_types
|
|
|
|
|
|
def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]:
|
|
"""Return the expression for the specific argument.
|
|
|
|
This helper should only be used with non-star arguments.
|
|
"""
|
|
if name not in ctx.callee_arg_names:
|
|
return None
|
|
idx = ctx.callee_arg_names.index(name)
|
|
args = ctx.args[idx]
|
|
if len(args) != 1:
|
|
# Either an error or no value passed.
|
|
return None
|
|
return args[0]
|
|
|
|
|
|
def get_argument_type_by_name(ctx: FunctionContext, name: str) -> Optional[Type]:
|
|
"""Return the type for the specific argument.
|
|
|
|
This helper should only be used with non-star arguments.
|
|
"""
|
|
if name not in ctx.callee_arg_names:
|
|
return None
|
|
idx = ctx.callee_arg_names.index(name)
|
|
arg_types = ctx.arg_types[idx]
|
|
if len(arg_types) != 1:
|
|
# Either an error or no value passed.
|
|
return None
|
|
return arg_types[0]
|