mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-06 20:24:31 +08:00
195 lines
8.9 KiB
Python
195 lines
8.9 KiB
Python
from typing import Dict, Optional, Set, cast
|
|
|
|
from mypy.checker import TypeChecker
|
|
from mypy.nodes import TypeInfo, Var
|
|
from mypy.plugin import FunctionContext, MethodContext
|
|
from mypy.types import AnyType, Instance, Type, TypeOfAny
|
|
|
|
from mypy_django_plugin.lib import metadata, fullnames, helpers
|
|
|
|
|
|
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
|
|
pointer_args: Set[str] = set()
|
|
for base in model.bases:
|
|
if base.type.has_base(fullnames.MODEL_CLASS_FULLNAME):
|
|
parent_name = base.type.name().lower()
|
|
pointer_args.add(f'{parent_name}_ptr')
|
|
pointer_args.add(f'{parent_name}_ptr_id')
|
|
return pointer_args
|
|
|
|
|
|
def redefine_and_typecheck_model_init(ctx: FunctionContext) -> Type:
|
|
assert isinstance(ctx.default_return_type, Instance)
|
|
|
|
api = cast(TypeChecker, ctx.api)
|
|
model: TypeInfo = ctx.default_return_type.type
|
|
|
|
expected_types = extract_expected_types(ctx, model, is_init=True)
|
|
|
|
# order is preserved, can be used for positionals
|
|
positional_names = list(expected_types.keys())
|
|
positional_names.remove('pk')
|
|
|
|
visited_positionals = set()
|
|
# check positionals
|
|
for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
|
|
actual_pos_name = positional_names[i]
|
|
api.check_subtype(actual_pos_type, expected_types[actual_pos_name],
|
|
ctx.context,
|
|
'Incompatible type for "{}" of "{}"'.format(actual_pos_name,
|
|
model.name()),
|
|
'got', 'expected')
|
|
visited_positionals.add(actual_pos_name)
|
|
|
|
# extract name of base models for _ptr
|
|
base_pointer_args = extract_base_pointer_args(model)
|
|
|
|
# check kwargs
|
|
for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])):
|
|
if actual_name in base_pointer_args:
|
|
# parent_ptr args are not supported
|
|
continue
|
|
if actual_name in visited_positionals:
|
|
continue
|
|
if actual_name is None:
|
|
# unpacked dict as kwargs is not supported
|
|
continue
|
|
if actual_name not in expected_types:
|
|
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
|
|
model.name()),
|
|
ctx.context)
|
|
continue
|
|
api.check_subtype(actual_type, expected_types[actual_name],
|
|
ctx.context,
|
|
'Incompatible type for "{}" of "{}"'.format(actual_name,
|
|
model.name()),
|
|
'got', 'expected')
|
|
return ctx.default_return_type
|
|
|
|
|
|
def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type:
|
|
api = cast(TypeChecker, ctx.api)
|
|
if isinstance(ctx.type, Instance) and len(ctx.type.args) > 0:
|
|
model_generic_arg = ctx.type.args[0]
|
|
else:
|
|
model_generic_arg = ctx.default_return_type
|
|
|
|
if isinstance(model_generic_arg, AnyType):
|
|
return ctx.default_return_type
|
|
|
|
model: TypeInfo = model_generic_arg.type
|
|
|
|
# extract name of base models for _ptr
|
|
base_pointer_args = extract_base_pointer_args(model)
|
|
expected_types = extract_expected_types(ctx, model)
|
|
|
|
for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]):
|
|
if actual_name in base_pointer_args:
|
|
# parent_ptr args are not supported
|
|
continue
|
|
if actual_name is None:
|
|
# unpacked dict as kwargs is not supported
|
|
continue
|
|
if actual_name not in expected_types:
|
|
api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
|
|
model.name()),
|
|
ctx.context)
|
|
continue
|
|
api.check_subtype(actual_type, expected_types[actual_name],
|
|
ctx.context,
|
|
'Incompatible type for "{}" of "{}"'.format(actual_name,
|
|
model.name()),
|
|
'got', 'expected')
|
|
|
|
return ctx.default_return_type
|
|
|
|
|
|
def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
|
|
field_metadata = metadata.get_fields_metadata(model).get(field_name, {})
|
|
if 'choices' in field_metadata:
|
|
return field_metadata['choices']
|
|
return None
|
|
|
|
|
|
def extract_expected_types(ctx: FunctionContext, model: TypeInfo,
|
|
is_init: bool = False) -> Dict[str, Type]:
|
|
api = cast(TypeChecker, ctx.api)
|
|
|
|
expected_types: Dict[str, Type] = {}
|
|
primary_key_type = helpers.extract_explicit_set_type_of_model_primary_key(model)
|
|
if not primary_key_type:
|
|
# no explicit primary key, set pk to Any and add id
|
|
primary_key_type = AnyType(TypeOfAny.special_form)
|
|
if is_init:
|
|
expected_types['id'] = helpers.make_optional(ctx.api.named_generic_type('builtins.int', []))
|
|
else:
|
|
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
|
|
|
|
expected_types['pk'] = primary_key_type
|
|
|
|
for base in model.mro:
|
|
# extract all fields for all models in MRO
|
|
for name, sym in base.names.items():
|
|
# do not redefine special attrs
|
|
if name in {'_meta', 'pk'}:
|
|
continue
|
|
|
|
if isinstance(sym.node, Var):
|
|
typ = sym.node.type
|
|
if typ is None or isinstance(typ, AnyType):
|
|
# types are not ready, fallback to Any
|
|
expected_types[name] = AnyType(TypeOfAny.from_unimported_type)
|
|
expected_types[name + '_id'] = AnyType(TypeOfAny.from_unimported_type)
|
|
|
|
elif isinstance(typ, Instance):
|
|
field_type = helpers.extract_field_setter_type(typ)
|
|
if field_type is None:
|
|
continue
|
|
|
|
if helpers.has_any_of_bases(typ.type, (fullnames.FOREIGN_KEY_FULLNAME,
|
|
fullnames.ONETOONE_FIELD_FULLNAME)):
|
|
related_primary_key_type = AnyType(TypeOfAny.implementation_artifact)
|
|
# in case it's optional, we need Instance type
|
|
referred_to_model = typ.args[1]
|
|
is_nullable = helpers.is_optional(referred_to_model)
|
|
if is_nullable:
|
|
referred_to_model = helpers.make_required(typ.args[1])
|
|
|
|
if isinstance(referred_to_model, Instance) and referred_to_model.type.has_base(
|
|
fullnames.MODEL_CLASS_FULLNAME):
|
|
pk_type = helpers.extract_explicit_set_type_of_model_primary_key(referred_to_model.type)
|
|
if not pk_type:
|
|
# extract set type of AutoField
|
|
autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField')
|
|
pk_type = helpers.get_private_descriptor_type(autofield_info, '_pyi_private_set_type',
|
|
is_nullable=is_nullable)
|
|
related_primary_key_type = pk_type
|
|
|
|
if is_init:
|
|
related_primary_key_type = helpers.make_optional(related_primary_key_type)
|
|
|
|
expected_types[name + '_id'] = related_primary_key_type
|
|
|
|
field_metadata = metadata.get_fields_metadata(model).get(name, {})
|
|
if field_type:
|
|
# related fields could be None in __init__ (but should be specified before save())
|
|
if helpers.has_any_of_bases(typ.type, (fullnames.FOREIGN_KEY_FULLNAME,
|
|
fullnames.ONETOONE_FIELD_FULLNAME)) and is_init:
|
|
field_type = helpers.make_optional(field_type)
|
|
|
|
# if primary_key=True and default specified
|
|
elif field_metadata.get('primary_key', False) and field_metadata.get('default_specified',
|
|
False):
|
|
field_type = helpers.make_optional(field_type)
|
|
|
|
# if CharField(blank=True,...) and not nullable, then field can be None in __init__
|
|
elif (
|
|
helpers.has_any_of_bases(typ.type, (fullnames.CHAR_FIELD_FULLNAME,)) and is_init and
|
|
field_metadata.get('blank', False) and not field_metadata.get('null', False)
|
|
):
|
|
field_type = helpers.make_optional(field_type)
|
|
|
|
expected_types[name] = field_type
|
|
|
|
return expected_types
|