from collections import OrderedDict
from typing import List, Optional, Sequence, Type, Union

from django.core.exceptions import FieldError
from django.db.models.base import Model
from django.db.models.fields.related import RelatedField
from mypy.nodes import Expression, NameExpr
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny

from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers


def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]:
    """Find the model type argument of a queryset-like instance.

    The instance itself is examined first, followed by each base listed on its
    type info. The first candidate whose leading type argument is an
    ``Instance`` having ``fullnames.MODEL_CLASS_FULLNAME`` as a base is used.

    Returns:
        That leading type argument, or ``None`` if no candidate qualifies.
    """
    for candidate in (queryset_type, *queryset_type.type.bases):
        type_args = candidate.args
        if not type_args:
            continue
        first_arg = type_args[0]
        if isinstance(first_arg, Instance) and first_arg.type.has_base(fullnames.MODEL_CLASS_FULLNAME):
            return first_arg
    return None


def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
    """Parametrize the default return type with the enclosing model class.

    When the call happens inside a class that has
    ``fullnames.MODEL_CLASS_FULLNAME`` as a base, the default return type is
    reparametrized with an instance of that class. Otherwise the default
    return type is returned untouched.
    """
    manager_type = ctx.default_return_type
    assert isinstance(manager_type, Instance)

    enclosing_class = helpers.get_typechecker_api(ctx).scope.active_class()
    if enclosing_class is not None and enclosing_class.has_base(fullnames.MODEL_CLASS_FULLNAME):
        return helpers.reparametrize_instance(manager_type, [Instance(enclosing_class, [])])
    return manager_type


def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
                               *, method: str, lookup: str) -> Optional[MypyType]:
    """Resolve ``lookup`` on ``model_cls`` and return the type read from that field.

    A ``FieldError`` raised while resolving the lookup is reported through
    ``ctx.api.fail`` (using the exception's first argument as the message) and
    ``None`` is returned.

    When the resolved field is a ``RelatedField`` whose ``column`` equals the
    lookup string, the primary-key field of the related model is used in its
    place.
    """
    try:
        resolved_field = django_context.lookups_context.resolve_lookup(model_cls, lookup)
    except FieldError as exc:
        ctx.api.fail(exc.args[0], ctx.context)
        return None

    if isinstance(resolved_field, RelatedField) and resolved_field.column == lookup:
        related_model = django_context.fields_context.get_related_model_cls(resolved_field)
        resolved_field = django_context.get_primary_key_field(related_model)

    typechecker_api = helpers.get_typechecker_api(ctx)
    return django_context.fields_context.get_field_get_type(typechecker_api, resolved_field, method=method)


def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
                             flat: bool, named: bool) -> MypyType:
    """Compute the type of a single row for a ``values_list`` call.

    The lookups come from the first positional argument group of the call.
    With no lookups given:

    * ``flat`` yields the type of the model's primary-key field;
    * ``named`` yields a one-off named tuple ``Row`` over all model fields;
    * otherwise every model field's ``attname`` is used as a lookup.

    An error ``AnyType`` is returned when the lookups cannot be resolved, when
    ``flat`` is combined with more than one lookup (also reported as a
    failure), or when any individual lookup fails to produce a type.
    """
    field_lookups = resolve_field_lookups(ctx.args[0], ctx, django_context)
    if field_lookups is None:
        return AnyType(TypeOfAny.from_error)

    typechecker_api = helpers.get_typechecker_api(ctx)

    if not field_lookups:
        if flat:
            pk_field = django_context.get_primary_key_field(model_cls)
            pk_type = get_field_type_from_lookup(ctx, django_context, model_cls,
                                                 lookup=pk_field.attname, method='values_list')
            assert pk_type is not None
            return pk_type

        if named:
            all_columns: 'OrderedDict[str, MypyType]' = OrderedDict()
            for model_field in django_context.get_model_fields(model_cls):
                all_columns[model_field.attname] = django_context.fields_context.get_field_get_type(
                    typechecker_api, model_field, method='values_list')
            return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', all_columns)

        # flat=False, named=False, all fields
        field_lookups = [model_field.attname
                         for model_field in django_context.get_model_fields(model_cls)]

    if flat and len(field_lookups) > 1:
        typechecker_api.fail("'flat' is not valid when 'values_list' is called with more than one field",
                             ctx.context)
        return AnyType(TypeOfAny.from_error)

    column_types: 'OrderedDict[str, MypyType]' = OrderedDict()
    for lookup_name in field_lookups:
        column_type = get_field_type_from_lookup(ctx, django_context, model_cls,
                                                 lookup=lookup_name, method='values_list')
        if column_type is None:
            return AnyType(TypeOfAny.from_error)
        column_types[lookup_name] = column_type

    if flat:
        assert len(column_types) == 1
        return next(iter(column_types.values()))
    if named:
        return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
    return helpers.make_tuple(typechecker_api, list(column_types.values()))


def _read_bool_keyword(ctx: MethodContext, name: str) -> bool:
    """Read the call argument ``name`` as a boolean, defaulting to ``False``.

    Only a ``NameExpr`` argument is handed to ``helpers.parse_bool``; a missing
    argument, any other expression kind, or a falsy parse result gives ``False``.
    """
    expr = helpers.get_call_argument_by_name(ctx, name)
    if not isinstance(expr, NameExpr):
        return False
    # account for possible None
    return helpers.parse_bool(expr) or False


def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
    """Build the return type for a ``values_list`` call on a queryset instance.

    The default return type is reparametrized with the queryset's model type
    and the computed row type. Using ``flat`` and ``named`` together is
    reported as a failure and produces an error ``AnyType`` as the row type.
    """
    # called on the Instance, returns QuerySet of something
    assert isinstance(ctx.type, Instance)
    assert isinstance(ctx.default_return_type, Instance)

    model_type = _extract_model_type_from_queryset(ctx.type)
    if model_type is None:
        return AnyType(TypeOfAny.from_omitted_generics)

    model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
    if model_cls is None:
        return ctx.default_return_type

    flat = _read_bool_keyword(ctx, 'flat')
    named = _read_bool_keyword(ctx, 'named')

    if flat and named:
        ctx.api.fail("'flat' and 'named' can't be used together", ctx.context)
        error_row = AnyType(TypeOfAny.from_error)
        return helpers.reparametrize_instance(ctx.default_return_type, [model_type, error_row])

    row_type = get_values_list_row_type(ctx, django_context, model_cls, flat=flat, named=named)
    return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])


def resolve_field_lookups(lookup_exprs: Sequence[Expression], ctx: Union[FunctionContext, MethodContext],
                          django_context: DjangoContext) -> Optional[List[str]]:
    """Turn lookup expressions into their string values.

    Each expression is passed to ``helpers.resolve_string_attribute_value``.

    Returns:
        The resolved strings in order, or ``None`` as soon as one expression
        cannot be resolved.
    """
    resolved: List[str] = []
    for lookup_expr in lookup_exprs:
        lookup_value = helpers.resolve_string_attribute_value(lookup_expr, ctx, django_context)
        if lookup_value is None:
            return None
        resolved.append(lookup_value)
    return resolved


def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
    """Build the return type for a ``values`` call on a queryset instance.

    The row type is a typed dict keyed by lookup name, with every key passed
    as required. With no lookups given, every model field's ``attname`` is
    used. The default return type is reparametrized with the model type and
    that row type; a lookup that fails to produce a type yields an error
    ``AnyType`` as the row type instead.
    """
    # called on QuerySet, return QuerySet of something
    assert isinstance(ctx.type, Instance)
    assert isinstance(ctx.default_return_type, Instance)

    model_type = _extract_model_type_from_queryset(ctx.type)
    if model_type is None:
        return AnyType(TypeOfAny.from_omitted_generics)

    model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
    if model_cls is None:
        return ctx.default_return_type

    field_lookups = resolve_field_lookups(ctx.args[0], ctx, django_context)
    if field_lookups is None:
        return AnyType(TypeOfAny.from_error)

    if not field_lookups:
        field_lookups = [model_field.attname
                         for model_field in django_context.get_model_fields(model_cls)]

    column_types: 'OrderedDict[str, MypyType]' = OrderedDict()
    for lookup_name in field_lookups:
        column_type = get_field_type_from_lookup(ctx, django_context, model_cls,
                                                 lookup=lookup_name, method='values')
        if column_type is None:
            error_row = AnyType(TypeOfAny.from_error)
            return helpers.reparametrize_instance(ctx.default_return_type, [model_type, error_row])
        column_types[lookup_name] = column_type

    required_keys = set(column_types)
    row_type = helpers.make_typeddict(ctx.api, column_types, required_keys)
    return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])