mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-07 04:34:29 +08:00
192 lines
8.2 KiB
Python
192 lines
8.2 KiB
Python
from collections import OrderedDict
|
|
from typing import List, Optional, Sequence, Type, Union, cast
|
|
|
|
from django.core.exceptions import FieldError
|
|
from django.db.models.base import Model
|
|
from mypy.newsemanal.typeanal import TypeAnalyser
|
|
from mypy.nodes import Expression, NameExpr, TypeInfo
|
|
from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext
|
|
from mypy.types import AnyType, Instance
|
|
from mypy.types import Type as MypyType
|
|
from mypy.types import TypeOfAny
|
|
|
|
from mypy_django_plugin.django.context import DjangoContext
|
|
from mypy_django_plugin.lib import fullnames, helpers
|
|
|
|
|
|
def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
    """Re-parametrize a manager's return type with the class it is created in.

    The default return type is handed back unchanged unless the call happens
    inside a class body whose class has ``fullnames.MODEL_CLASS_FULLNAME``
    as a base. In that case the result of ``helpers.reparametrize_instance``
    with an argument-less ``Instance`` of that class is returned.
    """
    manager_type = ctx.default_return_type
    assert isinstance(manager_type, Instance)

    enclosing_class = helpers.get_typechecker_api(ctx).scope.active_class()
    declared_inside_model = (enclosing_class is not None
                             and enclosing_class.has_base(fullnames.MODEL_CLASS_FULLNAME))
    if not declared_inside_model:
        # Not inside a model class body: nothing to specialise.
        return manager_type

    model_instance = Instance(enclosing_class, [])
    return helpers.reparametrize_instance(manager_type, [model_instance])
|
|
|
|
|
|
def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
                               *, method: str, lookup: str) -> Optional[MypyType]:
    """Return the type produced by reading ``lookup`` on ``model_cls``.

    The lookup is resolved through ``django_context.lookups_context``; the
    resulting field is then passed to ``fields_context.get_field_get_type``
    together with ``method``.

    Returns ``None`` when resolution raises ``FieldError``. In that case the
    first exception argument is reported via ``ctx.api.fail``.
    """
    try:
        resolved_field = django_context.lookups_context.resolve_lookup(model_cls, lookup)
    except FieldError as exc:
        # Surface Django's own message as a type-checking error.
        ctx.api.fail(exc.args[0], ctx.context)
        return None

    return django_context.fields_context.get_field_get_type(
        helpers.get_typechecker_api(ctx), resolved_field, method=method)
|
|
|
|
|
|
def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
                             flat: bool, named: bool) -> MypyType:
    """Compute the row type of a ``values_list`` call on ``model_cls``.

    The lookups are taken from the first positional argument group of the
    call (``ctx.args[0]``).

    * If the lookups cannot be resolved, or any single lookup fails,
      ``AnyType(TypeOfAny.from_error)`` is returned.
    * With no lookups: ``flat`` yields the primary key's type, ``named``
      yields a ``'Row'`` named tuple over all model fields, otherwise all
      model fields are used as lookups.
    * With lookups: ``flat`` yields the single column type (and is reported
      as an error for more than one lookup), ``named`` yields a ``'Row'``
      named tuple, otherwise a plain tuple of the column types.
    """
    requested_lookups = resolve_field_lookups(ctx.args[0], ctx, django_context)
    if requested_lookups is None:
        return AnyType(TypeOfAny.from_error)

    api = helpers.get_typechecker_api(ctx)

    if not requested_lookups:
        if flat:
            # values_list(flat=True) without fields: fall back to the primary key.
            pk_field = django_context.get_primary_key_field(model_cls)
            pk_type = get_field_type_from_lookup(ctx, django_context, model_cls,
                                                 lookup=pk_field.attname, method='values_list')
            assert pk_type is not None
            return pk_type

        if named:
            every_column: 'OrderedDict[str, MypyType]' = OrderedDict()
            for model_field in django_context.get_model_fields(model_cls):
                every_column[model_field.attname] = django_context.fields_context.get_field_get_type(
                    api, model_field, method='values_list')
            return helpers.make_oneoff_named_tuple(api, 'Row', every_column)

        # flat=False, named=False, all fields
        requested_lookups = [model_field.attname
                             for model_field in django_context.get_model_fields(model_cls)]

    if flat and len(requested_lookups) > 1:
        api.fail("'flat' is not valid when 'values_list' is called with more than one field", ctx.context)
        return AnyType(TypeOfAny.from_error)

    columns: 'OrderedDict[str, MypyType]' = OrderedDict()
    for lookup_name in requested_lookups:
        column_type = get_field_type_from_lookup(ctx, django_context, model_cls,
                                                 lookup=lookup_name, method='values_list')
        if column_type is None:
            # The failure has already been reported by get_field_type_from_lookup.
            return AnyType(TypeOfAny.from_error)
        columns[lookup_name] = column_type

    if flat:
        assert len(columns) == 1
        return next(iter(columns.values()))
    if named:
        return helpers.make_oneoff_named_tuple(api, 'Row', columns)
    return helpers.make_tuple(api, list(columns.values()))
|
|
|
|
|
|
def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
    """Refine the return type of a ``values_list`` call on a queryset instance.

    Returns ``AnyType(TypeOfAny.from_omitted_generics)`` when the receiver's
    first type argument is not an ``Instance``, and the default return type
    when no model class is found for it. Otherwise the default return type is
    re-parametrized with ``[model_type, row_type]``. The row type is an
    error ``AnyType`` when both ``flat`` and ``named`` are truthy (this is
    also reported via ``ctx.api.fail``).
    """
    # called on the Instance, returns QuerySet of something
    receiver_type = ctx.type
    fallback_type = ctx.default_return_type
    assert isinstance(receiver_type, Instance)
    assert isinstance(fallback_type, Instance)

    # bail if queryset of Any or other non-instances
    model_type = receiver_type.args[0]
    if not isinstance(model_type, Instance):
        return AnyType(TypeOfAny.from_omitted_generics)

    model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
    if model_cls is None:
        return fallback_type

    def read_bool_argument(argument_name: str) -> Optional[bool]:
        # Only a NameExpr argument is parsed; a missing argument or any other
        # expression counts as False.
        argument_expr = helpers.get_call_argument_by_name(ctx, argument_name)
        if isinstance(argument_expr, NameExpr):
            return helpers.parse_bool(argument_expr)
        return False

    flat = read_bool_argument('flat')
    named = read_bool_argument('named')

    if flat and named:
        ctx.api.fail("'flat' and 'named' can't be used together", ctx.context)
        return helpers.reparametrize_instance(fallback_type, [model_type, AnyType(TypeOfAny.from_error)])

    # account for possible None
    row_type = get_values_list_row_type(ctx, django_context, model_cls,
                                        flat=flat or False, named=named or False)
    return helpers.reparametrize_instance(fallback_type, [model_type, row_type])
|
|
|
|
|
|
def resolve_field_lookups(lookup_exprs: Sequence[Expression], ctx: Union[FunctionContext, MethodContext],
                          django_context: DjangoContext) -> Optional[List[str]]:
    """Turn lookup expressions into their string values, in order.

    Each expression is passed to ``helpers.resolve_string_attribute_value``.
    As soon as one of them resolves to ``None`` the whole call returns
    ``None`` and the remaining expressions are not examined.
    """
    resolved: List[str] = []
    for expr in lookup_exprs:
        value = helpers.resolve_string_attribute_value(expr, ctx, django_context)
        if value is None:
            return None
        resolved.append(value)
    return resolved
|
|
|
|
|
|
def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
    """Refine the return type of a ``values`` call on a queryset.

    Returns ``AnyType(TypeOfAny.from_omitted_generics)`` when the receiver's
    first type argument is not an ``Instance``, the default return type when
    no model class is found, and ``AnyType(TypeOfAny.from_error)`` when the
    lookup expressions cannot be resolved. Otherwise the default return type
    is re-parametrized with ``[model_type, row_type]``. The row type is a
    typed dict built from the lookups, or an error ``AnyType`` if one lookup
    fails. With no lookups given, all model fields are used.
    """
    # called on QuerySet, return QuerySet of something
    receiver_type = ctx.type
    fallback_type = ctx.default_return_type
    assert isinstance(receiver_type, Instance)
    assert isinstance(fallback_type, Instance)

    # if queryset of non-instance type
    model_type = receiver_type.args[0]
    if not isinstance(model_type, Instance):
        return AnyType(TypeOfAny.from_omitted_generics)

    model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
    if model_cls is None:
        return fallback_type

    lookups = resolve_field_lookups(ctx.args[0], ctx, django_context)
    if lookups is None:
        return AnyType(TypeOfAny.from_error)

    if not lookups:
        lookups.extend(model_field.attname
                       for model_field in django_context.get_model_fields(model_cls))

    columns: 'OrderedDict[str, MypyType]' = OrderedDict()
    for lookup_name in lookups:
        column_type = get_field_type_from_lookup(ctx, django_context, model_cls,
                                                 lookup=lookup_name, method='values')
        if column_type is None:
            # The failure has already been reported by get_field_type_from_lookup.
            return helpers.reparametrize_instance(fallback_type, [model_type, AnyType(TypeOfAny.from_error)])
        columns[lookup_name] = column_type

    # Every collected key is passed as the third argument to make_typeddict.
    row_type = helpers.make_typeddict(ctx.api, columns, set(columns))
    return helpers.reparametrize_instance(fallback_type, [model_type, row_type])
|
|
|
|
|
|
def set_first_generic_param_as_default_for_second(ctx: AnalyzeTypeContext, fullname: str) -> MypyType:
    """Build an ``Instance`` of ``fullname`` with two type arguments.

    * No arguments written: both become ``AnyType(TypeOfAny.explicit)``.
    * One argument written: it is analyzed twice and used for both slots.
    * Otherwise: every written argument is analyzed and passed through.
    """
    analyser = cast(TypeAnalyser, ctx.api)

    info = helpers.lookup_fully_qualified_typeinfo(analyser.api, fullname)  # type: ignore
    assert isinstance(info, TypeInfo)

    written_args = ctx.type.args
    if not written_args:
        return Instance(info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)])

    if len(written_args) == 1:
        # Reuse the single argument as the default for the second parameter.
        written_args = [written_args[0]] * 2

    return Instance(info, [analyser.analyze_type(written_arg) for written_arg in written_args])
|