mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-07 04:34:29 +08:00
remove -semanal suffix
This commit is contained in:
166
mypy_django_plugin/transformers/querysets.py
Normal file
166
mypy_django_plugin/transformers/querysets.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.fields.related import ForeignKey
|
||||
from mypy.nodes import NameExpr
|
||||
from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext
|
||||
from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny
|
||||
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
|
||||
|
||||
def set_first_generic_param_as_default_for_second(ctx: AnalyzeTypeContext, fullname: str) -> MypyType:
|
||||
if not ctx.type.args:
|
||||
try:
|
||||
return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit),
|
||||
AnyType(TypeOfAny.explicit)])
|
||||
except KeyError:
|
||||
# really should never happen
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
args = ctx.type.args
|
||||
if len(args) == 1:
|
||||
args = [args[0], args[0]]
|
||||
|
||||
analyzed_args = [ctx.api.analyze_type(arg) for arg in args]
|
||||
ctx.api.analyze_type(ctx.type)
|
||||
try:
|
||||
return ctx.api.named_type(fullname, analyzed_args)
|
||||
except KeyError:
|
||||
return AnyType(TypeOfAny.explicit)
|
||||
|
||||
|
||||
def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
|
||||
ret = ctx.default_return_type
|
||||
assert isinstance(ret, Instance)
|
||||
|
||||
if not ctx.api.tscope.classes:
|
||||
# not in class
|
||||
return ret
|
||||
outer_model_info = ctx.api.tscope.classes[0]
|
||||
if not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
|
||||
return ret
|
||||
|
||||
return helpers.reparametrize_instance(ret, [Instance(outer_model_info, [])])
|
||||
|
||||
|
||||
def get_lookup_field_get_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
|
||||
lookup: str, method: str) -> Optional[Tuple[str, MypyType]]:
|
||||
try:
|
||||
lookup_field = django_context.lookups_context.resolve_lookup(model_cls, lookup)
|
||||
except FieldError as exc:
|
||||
ctx.api.fail(exc.args[0], ctx.context)
|
||||
return None
|
||||
|
||||
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, method)
|
||||
return lookup_field.attname, field_get_type
|
||||
|
||||
|
||||
def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
|
||||
flat: bool, named: bool) -> MypyType:
|
||||
field_lookups = [expr.value for expr in ctx.args[0]]
|
||||
if len(field_lookups) == 0:
|
||||
if flat:
|
||||
primary_key_field = django_context.get_primary_key_field(model_cls)
|
||||
_, field_get_type = get_lookup_field_get_type(ctx, django_context, model_cls,
|
||||
primary_key_field.attname, 'values_list')
|
||||
return field_get_type
|
||||
elif named:
|
||||
column_types = OrderedDict()
|
||||
for field in django_context.get_model_fields(model_cls):
|
||||
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, field, 'values_list')
|
||||
column_types[field.attname] = field_get_type
|
||||
return helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types)
|
||||
else:
|
||||
# flat=False, named=False, all fields
|
||||
field_lookups = []
|
||||
for field in model_cls._meta.get_fields():
|
||||
field_lookups.append(field.attname)
|
||||
|
||||
if len(field_lookups) > 1 and flat:
|
||||
ctx.api.fail("'flat' is not valid when 'values_list' is called with more than one field", ctx.context)
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
column_types = OrderedDict()
|
||||
for field_lookup in field_lookups:
|
||||
result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values_list')
|
||||
if result is None:
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
field_name, field_get_type = result
|
||||
column_types[field_name] = field_get_type
|
||||
|
||||
if flat:
|
||||
assert len(column_types) == 1
|
||||
row_type = next(iter(column_types.values()))
|
||||
elif named:
|
||||
row_type = helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types)
|
||||
else:
|
||||
row_type = helpers.make_tuple(ctx.api, list(column_types.values()))
|
||||
|
||||
return row_type
|
||||
|
||||
|
||||
def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||
assert isinstance(ctx.type, Instance)
|
||||
assert isinstance(ctx.type.args[0], Instance)
|
||||
|
||||
model_type = ctx.type.args[0]
|
||||
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
flat_expr = helpers.get_call_argument_by_name(ctx, 'flat')
|
||||
if flat_expr is not None and isinstance(flat_expr, NameExpr):
|
||||
flat = helpers.parse_bool(flat_expr)
|
||||
else:
|
||||
flat = False
|
||||
|
||||
named_expr = helpers.get_call_argument_by_name(ctx, 'named')
|
||||
if named_expr is not None and isinstance(named_expr, NameExpr):
|
||||
named = helpers.parse_bool(named_expr)
|
||||
else:
|
||||
named = False
|
||||
|
||||
if flat and named:
|
||||
ctx.api.fail("'flat' and 'named' can't be used together", ctx.context)
|
||||
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
|
||||
|
||||
row_type = get_values_list_row_type(ctx, django_context, model_cls,
|
||||
flat=flat, named=named)
|
||||
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
|
||||
|
||||
|
||||
def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||
assert isinstance(ctx.type, Instance)
|
||||
assert isinstance(ctx.type.args[0], Instance)
|
||||
|
||||
model_type = ctx.type.args[0]
|
||||
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
field_lookups = [expr.value for expr in ctx.args[0]]
|
||||
if len(field_lookups) == 0:
|
||||
for field in model_cls._meta.get_fields():
|
||||
field_lookups.append(field.attname)
|
||||
|
||||
column_types = OrderedDict()
|
||||
for field_lookup in field_lookups:
|
||||
try:
|
||||
lookup_field = django_context.lookups_context.resolve_lookup(model_cls, field_lookup)
|
||||
except FieldError as exc:
|
||||
ctx.api.fail(exc.args[0], ctx.context)
|
||||
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
|
||||
|
||||
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, 'values')
|
||||
field_name = lookup_field.attname
|
||||
if isinstance(lookup_field, ForeignKey) and field_lookup == lookup_field.name:
|
||||
field_name = lookup_field.name
|
||||
|
||||
column_types[field_name] = field_get_type
|
||||
|
||||
row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys()))
|
||||
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
|
||||
Reference in New Issue
Block a user