mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-08 21:14:49 +08:00
add BaseManager.create() typechecking
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
import typing
|
||||
from typing import Dict, Optional
|
||||
|
||||
from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var, AssignmentStmt, \
|
||||
CallExpr
|
||||
from mypy.nodes import Expression, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType
|
||||
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType
|
||||
|
||||
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
|
||||
FIELD_FULLNAME = 'django.db.models.fields.Field'
|
||||
@@ -119,74 +118,6 @@ def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance:
|
||||
return reparametrize_with(type_to_fill, typevar_values)
|
||||
|
||||
|
||||
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
|
||||
if tp.type.has_base(FIELD_FULLNAME):
|
||||
set_method = tp.type.get_method('__set__')
|
||||
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
|
||||
if 'value' in set_method.type.arg_names:
|
||||
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
|
||||
if isinstance(set_value_type, Instance):
|
||||
set_value_type = fill_typevars(tp, set_value_type)
|
||||
return set_value_type
|
||||
elif isinstance(set_value_type, UnionType):
|
||||
items_no_typevars = []
|
||||
for item in set_value_type.items:
|
||||
if isinstance(item, Instance):
|
||||
item = fill_typevars(tp, item)
|
||||
items_no_typevars.append(item)
|
||||
return UnionType(items_no_typevars)
|
||||
|
||||
get_method = tp.type.get_method('__get__')
|
||||
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
|
||||
return get_method.type.ret_type
|
||||
# GenericForeignKey
|
||||
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
|
||||
return AnyType(TypeOfAny.special_form)
|
||||
return None
|
||||
|
||||
|
||||
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
|
||||
# only primary keys defined in current class for now
|
||||
for stmt in model.defn.defs.body:
|
||||
if isinstance(stmt, AssignmentStmt) and isinstance(stmt.rvalue, CallExpr):
|
||||
name_expr = stmt.lvalues[0]
|
||||
if isinstance(name_expr, NameExpr):
|
||||
name = name_expr.name
|
||||
if 'primary_key' in stmt.rvalue.arg_names:
|
||||
is_primary_key = stmt.rvalue.args[stmt.rvalue.arg_names.index('primary_key')]
|
||||
if is_primary_key:
|
||||
return extract_field_setter_type(model.names[name].type)
|
||||
return None
|
||||
|
||||
|
||||
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
|
||||
expected_types: Dict[str, Type] = {}
|
||||
|
||||
primary_key_type = extract_primary_key_type(model)
|
||||
if not primary_key_type:
|
||||
# no explicit primary key, set pk to Any and add id
|
||||
primary_key_type = AnyType(TypeOfAny.special_form)
|
||||
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
|
||||
|
||||
expected_types['pk'] = primary_key_type
|
||||
|
||||
for base in model.mro:
|
||||
for name, sym in base.names.items():
|
||||
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
|
||||
tp = sym.node.type
|
||||
field_type = extract_field_setter_type(tp)
|
||||
if tp.type.fullname() in {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME}:
|
||||
ref_to_model = tp.args[0]
|
||||
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME):
|
||||
primary_key_type = extract_primary_key_type(ref_to_model.type)
|
||||
if not primary_key_type:
|
||||
primary_key_type = AnyType(TypeOfAny.special_form)
|
||||
expected_types[name + '_id'] = primary_key_type
|
||||
if field_type:
|
||||
expected_types[name] = field_type
|
||||
return expected_types
|
||||
|
||||
|
||||
def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]:
|
||||
"""Return the expression for the specific argument.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user