QuerySet.annotate improvements (#398)

* QuerySet.annotate returns self-type. Attribute access falls back to Any.

- QuerySets that have an annotated model do not report errors during .filter() when called with invalid fields.
- QuerySets that have an annotated model return ordinary dict rather than TypedDict for .values()
- QuerySets that have an annotated model return Any rather than typed Tuple for .values_list()

* Fix .annotate so it reuses existing annotated types. Fixes error in typechecking Django testsuite.

* Fix self-typecheck error

* Fix flake8

* Fix case of .values/.values_list before .annotate.

* Extra ignores for Django 2.2 tests (false positives due to tests assuming QuerySet.first() won't return None)

Fix mypy self-check.

* More tests + more precise typing in case annotate called before values_list.

Cleanup tests.

* Test and fix annotate in combination with values/values_list with no params.

* Remove line that does nothing :)

* Formatting fixes

* Address code review

* Fix quoting in tests after mypy changed things

* Use Final

* Use typing_extensions.Final

* Fixes after ValuesQuerySet -> _ValuesQuerySet refactor. Still not passing tests yet.

* Fix inheritance of _ValuesQuerySet and remove unneeded type ignores.

This allows the test
"annotate_values_or_values_list_before_or_after_annotate_broadens_type"
to pass.

* Make it possible to annotate user code with "annotated models", using PEP 583 Annotated type.

* Add docs

* Make QuerySet[_T] an external alias to _QuerySet[_T, _T].

This currently has the drawback that error messages display the internal type _QuerySet, with both type arguments.

See also discussion on #661 and #608.

Fixes #635: QuerySet methods on Managers (like .all()) now return QuerySets rather than Managers.

Address code review by @sobolevn.

* Support passing TypedDicts to WithAnnotations

* Add an example of an error to README regarding WithAnnotations + TypedDict.

* Fix runtime behavior of ValuesQuerySet alias (you can't extend Any, for example).

Fix some edge case with from_queryset after QuerySet changed to be an
alias to _QuerySet. Can't make a minimal test case as this only occurred
on a large internal codebase.

* Fix issue when using from_queryset in some cases when having an argument with a type annotation on the QuerySet.

The mypy docstring on anal_type says not to call defer() after it.
This commit is contained in:
Seth Yastrov
2021-07-23 15:15:15 +02:00
committed by GitHub
parent c69e720dd8
commit cfd69c0acc
25 changed files with 860 additions and 123 deletions

View File

@@ -1,18 +1,21 @@
from collections import OrderedDict
from typing import List, Optional, Sequence, Type
from typing import Dict, List, Optional, Sequence, Type
from django.core.exceptions import FieldError
from django.db.models.base import Model
from django.db.models.fields.related import RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel
from mypy.nodes import Expression, NameExpr
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, Expression, NameExpr
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance
from mypy.types import AnyType, Instance, TupleType
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny
from mypy.types import TypedDictType, TypeOfAny, get_proper_type
from mypy_django_plugin.django.context import DjangoContext, LookupsAreUnsupported
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import ANY_ATTR_ALLOWED_CLASS_FULLNAME
from mypy_django_plugin.lib.helpers import is_annotated_model_fullname
from mypy_django_plugin.transformers.models import get_or_create_annotated_type
def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]:
@@ -38,12 +41,19 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
def get_field_type_from_lookup(
ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], *, method: str, lookup: str
ctx: MethodContext,
django_context: DjangoContext,
model_cls: Type[Model],
*,
method: str,
lookup: str,
silent_on_error: bool = False,
) -> Optional[MypyType]:
try:
lookup_field = django_context.resolve_lookup_into_field(model_cls, lookup)
except FieldError as exc:
ctx.api.fail(exc.args[0], ctx.context)
if not silent_on_error:
ctx.api.fail(exc.args[0], ctx.context)
return None
except LookupsAreUnsupported:
return AnyType(TypeOfAny.explicit)
@@ -61,7 +71,13 @@ def get_field_type_from_lookup(
def get_values_list_row_type(
ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], flat: bool, named: bool
ctx: MethodContext,
django_context: DjangoContext,
model_cls: Type[Model],
*,
is_annotated: bool,
flat: bool,
named: bool,
) -> MypyType:
field_lookups = resolve_field_lookups(ctx.args[0], django_context)
if field_lookups is None:
@@ -81,9 +97,20 @@ def get_values_list_row_type(
for field in django_context.get_model_fields(model_cls):
column_type = django_context.get_field_get_type(typechecker_api, field, method="values_list")
column_types[field.attname] = column_type
return helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types)
if is_annotated:
# Return a NamedTuple with a fallback so that it's possible to access any field
return helpers.make_oneoff_named_tuple(
typechecker_api,
"Row",
column_types,
extra_bases=[typechecker_api.named_generic_type(ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])],
)
else:
return helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types)
else:
# flat=False, named=False, all fields
if is_annotated:
return typechecker_api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.special_form)])
field_lookups = []
for field in django_context.get_model_fields(model_cls):
field_lookups.append(field.attname)
@@ -95,10 +122,13 @@ def get_values_list_row_type(
column_types = OrderedDict()
for field_lookup in field_lookups:
lookup_field_type = get_field_type_from_lookup(
ctx, django_context, model_cls, lookup=field_lookup, method="values_list"
ctx, django_context, model_cls, lookup=field_lookup, method="values_list", silent_on_error=is_annotated
)
if lookup_field_type is None:
return AnyType(TypeOfAny.from_error)
if is_annotated:
lookup_field_type = AnyType(TypeOfAny.from_omitted_generics)
else:
return AnyType(TypeOfAny.from_error)
column_types[field_lookup] = lookup_field_type
if flat:
@@ -115,7 +145,8 @@ def get_values_list_row_type(
def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
# called on the Instance, returns QuerySet of something
assert isinstance(ctx.type, Instance)
assert isinstance(ctx.default_return_type, Instance)
default_return_type = get_proper_type(ctx.default_return_type)
assert isinstance(default_return_type, Instance)
model_type = _extract_model_type_from_queryset(ctx.type)
if model_type is None:
@@ -123,7 +154,7 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname)
if model_cls is None:
return ctx.default_return_type
return default_return_type
flat_expr = helpers.get_call_argument_by_name(ctx, "flat")
if flat_expr is not None and isinstance(flat_expr, NameExpr):
@@ -139,14 +170,89 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
if flat and named:
ctx.api.fail("'flat' and 'named' can't be used together", ctx.context)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
return helpers.reparametrize_instance(default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
# account for possible None
flat = flat or False
named = named or False
row_type = get_values_list_row_type(ctx, django_context, model_cls, flat=flat, named=named)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
is_annotated = is_annotated_model_fullname(model_type.type.fullname)
row_type = get_values_list_row_type(
ctx, django_context, model_cls, is_annotated=is_annotated, flat=flat, named=named
)
return helpers.reparametrize_instance(default_return_type, [model_type, row_type])
def gather_kwargs(ctx: MethodContext) -> Optional[Dict[str, MypyType]]:
num_args = len(ctx.arg_kinds)
kwargs = {}
named = (ARG_NAMED, ARG_NAMED_OPT)
for i in range(num_args):
if not ctx.arg_kinds[i]:
continue
if any(kind not in named for kind in ctx.arg_kinds[i]):
# Only named arguments supported
return None
for j in range(len(ctx.arg_names[i])):
name = ctx.arg_names[i][j]
assert name is not None
kwargs[name] = ctx.arg_types[i][j]
return kwargs
def extract_proper_type_queryset_annotate(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
# called on the Instance, returns QuerySet of something
assert isinstance(ctx.type, Instance)
default_return_type = get_proper_type(ctx.default_return_type)
assert isinstance(default_return_type, Instance)
model_type = _extract_model_type_from_queryset(ctx.type)
if model_type is None:
return AnyType(TypeOfAny.from_omitted_generics)
api = ctx.api
field_types = model_type.type.metadata.get("annotated_field_types")
kwargs = gather_kwargs(ctx)
if kwargs:
# For now, we don't try to resolve the output_field of the field would be, but use Any.
added_field_types = {name: AnyType(TypeOfAny.implementation_artifact) for name, typ in kwargs.items()}
if field_types is not None:
# Annotate was called more than once, so add/update existing field types
field_types.update(added_field_types)
else:
field_types = added_field_types
fields_dict = None
if field_types is not None:
fields_dict = helpers.make_typeddict(
api, fields=OrderedDict(field_types), required_keys=set(field_types.keys())
)
annotated_type = get_or_create_annotated_type(api, model_type, fields_dict=fields_dict)
row_type: MypyType
if len(default_return_type.args) > 1:
original_row_type: MypyType = default_return_type.args[1]
row_type = original_row_type
if isinstance(original_row_type, TypedDictType):
row_type = api.named_generic_type(
"builtins.dict", [api.named_generic_type("builtins.str", []), AnyType(TypeOfAny.from_omitted_generics)]
)
elif isinstance(original_row_type, TupleType):
fallback: Instance = original_row_type.partial_fallback
if fallback is not None and fallback.type.has_base("typing.NamedTuple"):
# TODO: Use a NamedTuple which contains the known fields, but also
# falls back to allowing any attribute access.
row_type = AnyType(TypeOfAny.implementation_artifact)
else:
row_type = api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.from_omitted_generics)])
elif isinstance(original_row_type, Instance) and original_row_type.type.has_base(
fullnames.MODEL_CLASS_FULLNAME
):
row_type = annotated_type
else:
row_type = annotated_type
return helpers.reparametrize_instance(default_return_type, [annotated_type, row_type])
def resolve_field_lookups(lookup_exprs: Sequence[Expression], django_context: DjangoContext) -> Optional[List[str]]:
@@ -162,7 +268,8 @@ def resolve_field_lookups(lookup_exprs: Sequence[Expression], django_context: Dj
def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
# called on QuerySet, return QuerySet of something
assert isinstance(ctx.type, Instance)
assert isinstance(ctx.default_return_type, Instance)
default_return_type = get_proper_type(ctx.default_return_type)
assert isinstance(default_return_type, Instance)
model_type = _extract_model_type_from_queryset(ctx.type)
if model_type is None:
@@ -170,7 +277,10 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname)
if model_cls is None:
return ctx.default_return_type
return default_return_type
if is_annotated_model_fullname(model_type.type.fullname):
return default_return_type
field_lookups = resolve_field_lookups(ctx.args[0], django_context)
if field_lookups is None:
@@ -186,9 +296,9 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
ctx, django_context, model_cls, lookup=field_lookup, method="values"
)
if field_lookup_type is None:
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
return helpers.reparametrize_instance(default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
column_types[field_lookup] = field_lookup_type
row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys()))
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
return helpers.reparametrize_instance(default_return_type, [model_type, row_type])