mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-06 12:14:28 +08:00
finish strict_optional support, enable it for typechecking of django tests
This commit is contained in:
@@ -2,13 +2,14 @@ import typing
|
||||
from typing import Dict, Optional
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import AssignmentStmt, ClassDef, Expression, FuncDef, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, \
|
||||
from mypy.nodes import AssignmentStmt, ClassDef, Expression, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, \
|
||||
TypeInfo
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import AnyType, CallableType, Instance, NoneTyp, Type, TypeOfAny, TypeVarType, UnionType
|
||||
from mypy.types import AnyType, Instance, NoneTyp, Type, TypeOfAny, TypeVarType, UnionType
|
||||
|
||||
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
|
||||
FIELD_FULLNAME = 'django.db.models.fields.Field'
|
||||
CHAR_FIELD_FULLNAME = 'django.db.models.fields.CharField'
|
||||
ARRAY_FIELD_FULLNAME = 'django.contrib.postgres.fields.array.ArrayField'
|
||||
AUTO_FIELD_FULLNAME = 'django.db.models.fields.AutoField'
|
||||
GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey'
|
||||
@@ -263,9 +264,12 @@ def is_optional(typ: Type) -> bool:
|
||||
return any([isinstance(item, NoneTyp) for item in typ.items])
|
||||
|
||||
|
||||
|
||||
def has_any_of_bases(info: TypeInfo, bases: typing.Sequence[str]) -> bool:
|
||||
for base_fullname in bases:
|
||||
if info.has_base(base_fullname):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_none_expr(expr: Expression) -> bool:
|
||||
return isinstance(expr, NameExpr) and expr.fullname == 'builtins.None'
|
||||
|
||||
@@ -106,6 +106,9 @@ def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is
|
||||
def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
|
||||
default_return_type = cast(Instance, ctx.default_return_type)
|
||||
is_nullable = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'null'))
|
||||
if not is_nullable and default_return_type.type.has_base(helpers.CHAR_FIELD_FULLNAME):
|
||||
# blank=True for CharField can be interpreted as null=True
|
||||
is_nullable = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'blank'))
|
||||
|
||||
set_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_set_type',
|
||||
is_nullable=is_nullable)
|
||||
@@ -197,3 +200,8 @@ def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> None
|
||||
if blank_arg:
|
||||
is_blankable = helpers.parse_bool(blank_arg)
|
||||
fields_metadata[field_name]['blank'] = is_blankable
|
||||
|
||||
# default
|
||||
default_arg = helpers.get_argument_by_name(ctx, 'default')
|
||||
if default_arg and not helpers.is_none_expr(default_arg):
|
||||
fields_metadata[field_name]['default_specified'] = True
|
||||
|
||||
@@ -25,12 +25,13 @@ def redefine_and_typecheck_model_init(ctx: FunctionContext) -> Type:
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
model: TypeInfo = ctx.default_return_type.type
|
||||
|
||||
expected_types = extract_expected_types(ctx, model)
|
||||
# order is preserved, can use for positionals
|
||||
expected_types = extract_expected_types(ctx, model, is_init=True)
|
||||
|
||||
# order is preserved, can be used for positionals
|
||||
positional_names = list(expected_types.keys())
|
||||
positional_names.remove('pk')
|
||||
visited_positionals = set()
|
||||
|
||||
visited_positionals = set()
|
||||
# check positionals
|
||||
for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
|
||||
actual_pos_name = positional_names[i]
|
||||
@@ -111,7 +112,8 @@ def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
|
||||
def extract_expected_types(ctx: FunctionContext, model: TypeInfo,
|
||||
is_init: bool = False) -> Dict[str, Type]:
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
|
||||
expected_types: Dict[str, Type] = {}
|
||||
@@ -119,7 +121,11 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, T
|
||||
if not primary_key_type:
|
||||
# no explicit primary key, set pk to Any and add id
|
||||
primary_key_type = AnyType(TypeOfAny.special_form)
|
||||
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
|
||||
if is_init:
|
||||
expected_types['id'] = helpers.make_optional(ctx.api.named_generic_type('builtins.int', []))
|
||||
else:
|
||||
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
|
||||
|
||||
expected_types['pk'] = primary_key_type
|
||||
|
||||
for base in model.mro:
|
||||
@@ -141,8 +147,9 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, T
|
||||
if field_type is None:
|
||||
continue
|
||||
|
||||
if typ.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
|
||||
primary_key_type = AnyType(TypeOfAny.implementation_artifact)
|
||||
if helpers.has_any_of_bases(typ.type, (helpers.FOREIGN_KEY_FULLNAME,
|
||||
helpers.ONETOONE_FIELD_FULLNAME)):
|
||||
related_primary_key_type = AnyType(TypeOfAny.implementation_artifact)
|
||||
# in case it's optional, we need Instance type
|
||||
referred_to_model = typ.args[1]
|
||||
is_nullable = helpers.is_optional(referred_to_model)
|
||||
@@ -156,11 +163,24 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, T
|
||||
autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField')
|
||||
pk_type = get_private_descriptor_type(autofield_info, '_pyi_private_set_type',
|
||||
is_nullable=is_nullable)
|
||||
primary_key_type = pk_type
|
||||
related_primary_key_type = pk_type
|
||||
|
||||
expected_types[name + '_id'] = primary_key_type
|
||||
if is_init:
|
||||
related_primary_key_type = helpers.make_optional(related_primary_key_type)
|
||||
|
||||
expected_types[name + '_id'] = related_primary_key_type
|
||||
|
||||
field_metadata = get_fields_metadata(model).get(name, {})
|
||||
if field_type:
|
||||
# related fields could be None in __init__ (but should be specified before save())
|
||||
if helpers.has_any_of_bases(typ.type, (helpers.FOREIGN_KEY_FULLNAME,
|
||||
helpers.ONETOONE_FIELD_FULLNAME)) and is_init:
|
||||
field_type = helpers.make_optional(field_type)
|
||||
|
||||
# if primary_key=True and default specified
|
||||
elif field_metadata.get('primary_key', False) and field_metadata.get('default_specified', False):
|
||||
field_type = helpers.make_optional(field_type)
|
||||
|
||||
expected_types[name] = field_type
|
||||
|
||||
return expected_types
|
||||
|
||||
Reference in New Issue
Block a user