QuerySet.annotate improvements (#398)

* QuerySet.annotate returns self-type. Attribute access falls back to Any.

- QuerySets that have an annotated model do not report errors during .filter() when called with invalid fields.
- QuerySets that have an annotated model return ordinary dict rather than TypedDict for .values()
- QuerySets that have an annotated model return Any rather than typed Tuple for .values_list()

* Fix .annotate so it reuses existing annotated types. Fixes error in typechecking Django testsuite.

* Fix self-typecheck error

* Fix flake8

* Fix case of .values/.values_list before .annotate.

* Extra ignores for Django 2.2 tests (false positives due to tests assuming QuerySet.first() won't return None)

Fix mypy self-check.

* More tests + more precise typing in case annotate called before values_list.

Cleanup tests.

* Test and fix annotate in combination with values/values_list with no params.

* Remove line that does nothing :)

* Formatting fixes

* Address code review

* Fix quoting in tests after mypy changed things

* Use Final

* Use typing_extensions.Final

* Fixes after ValuesQuerySet -> _ValuesQuerySet refactor. Still not passing tests yet.

* Fix inheritance of _ValuesQuerySet and remove unneeded type ignores.

This allows the test
"annotate_values_or_values_list_before_or_after_annotate_broadens_type"
to pass.

* Make it possible to annotate user code with "annotated models", using PEP 583 Annotated type.

* Add docs

* Make QuerySet[_T] an external alias to _QuerySet[_T, _T].

This currently has the drawback that error messages display the internal type _QuerySet, with both type arguments.

See also discussion on #661 and #608.

Fixes #635: QuerySet methods on Managers (like .all()) now return QuerySets rather than Managers.

Address code review by @sobolevn.

* Support passing TypedDicts to WithAnnotations

* Add an example of an error to README regarding WithAnnotations + TypedDict.

* Fix runtime behavior of ValuesQuerySet alias (you can't extend Any, for example).

Fix some edge case with from_queryset after QuerySet changed to be an
alias to _QuerySet. Can't make a minimal test case as this only occurred
on a large internal codebase.

* Fix issue when using from_queryset in some cases when having an argument with a type annotation on the QuerySet.

The mypy docstring on anal_type says not to call defer() after it.
This commit is contained in:
Seth Yastrov
2021-07-23 15:15:15 +02:00
committed by GitHub
parent c69e720dd8
commit cfd69c0acc
25 changed files with 860 additions and 123 deletions

View File

@@ -19,6 +19,16 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
return
assert isinstance(base_manager_info, TypeInfo)
passed_queryset = ctx.call.args[0]
assert isinstance(passed_queryset, NameExpr)
derived_queryset_fullname = passed_queryset.fullname
if derived_queryset_fullname is None:
# In some cases, due to the way the semantic analyzer works, only passed_queryset.name is available.
# But it should be analyzed again, so this isn't a problem.
return
new_manager_info = semanal_api.basic_new_typeinfo(
ctx.name, basetype_or_fallback=Instance(base_manager_info, [AnyType(TypeOfAny.unannotated)]), line=ctx.call.line
)
@@ -28,11 +38,6 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
current_module = semanal_api.cur_mod_node
current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True)
passed_queryset = ctx.call.args[0]
assert isinstance(passed_queryset, NameExpr)
derived_queryset_fullname = passed_queryset.fullname
assert derived_queryset_fullname is not None
sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname)
assert sym is not None

View File

@@ -1,19 +1,22 @@
from typing import Dict, List, Optional, Type, cast
from typing import Dict, List, Optional, Type, Union, cast
from django.db.models.base import Model
from django.db.models.fields import DateField, DateTimeField
from django.db.models.fields.related import ForeignKey
from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel
from mypy.checker import TypeChecker
from mypy.nodes import ARG_STAR2, Argument, Context, FuncDef, TypeInfo, Var
from mypy.plugin import AttributeContext, ClassDefContext
from mypy.plugin import AnalyzeTypeContext, AttributeContext, CheckerPluginInterface, ClassDefContext
from mypy.plugins import common
from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, Instance
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny
from mypy.types import TypedDictType, TypeOfAny
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME
from mypy_django_plugin.lib.helpers import add_new_class_for_module
from mypy_django_plugin.transformers import fields
from mypy_django_plugin.transformers.fields import get_field_descriptor_types
@@ -194,7 +197,6 @@ class AddManagers(ModelClassInitializer):
for manager_name, manager in model_cls._meta.managers_map.items():
manager_class_name = manager.__class__.__name__
manager_fullname = helpers.get_class_fullname(manager.__class__)
try:
manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname)
except helpers.IncompleteDefnException as exc:
@@ -390,3 +392,76 @@ def set_auth_user_model_boolean_fields(ctx: AttributeContext, django_context: Dj
boolinfo = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), bool)
assert boolinfo is not None
return Instance(boolinfo, [])
def handle_annotated_type(ctx: AnalyzeTypeContext, django_context: DjangoContext) -> MypyType:
args = ctx.type.args
type_arg = ctx.api.analyze_type(args[0])
api = cast(SemanticAnalyzer, ctx.api.api) # type: ignore
if not isinstance(type_arg, Instance):
return ctx.api.analyze_type(ctx.type)
fields_dict = None
if len(args) > 1:
second_arg_type = ctx.api.analyze_type(args[1])
if isinstance(second_arg_type, TypedDictType):
fields_dict = second_arg_type
elif isinstance(second_arg_type, Instance) and second_arg_type.type.fullname == ANNOTATIONS_FULLNAME:
annotations_type_arg = second_arg_type.args[0]
if isinstance(annotations_type_arg, TypedDictType):
fields_dict = annotations_type_arg
elif not isinstance(annotations_type_arg, AnyType):
ctx.api.fail("Only TypedDicts are supported as type arguments to Annotations", ctx.context)
return get_or_create_annotated_type(api, type_arg, fields_dict=fields_dict)
def get_or_create_annotated_type(
api: Union[SemanticAnalyzer, CheckerPluginInterface], model_type: Instance, fields_dict: Optional[TypedDictType]
) -> Instance:
"""
Get or create the type for a model for which you getting/setting any attr is allowed.
The generated type is an subclass of the model and django._AnyAttrAllowed.
The generated type is placed in the django_stubs_ext module, with the name WithAnnotations[ModelName].
If the user wanted to annotate their code using this type, then this is the annotation they would use.
This is a bit of a hack to make a pretty type for error messages and which would make sense for users.
"""
model_module_name = "django_stubs_ext"
if helpers.is_annotated_model_fullname(model_type.type.fullname):
# If it's already a generated class, we want to use the original model as a base
model_type = model_type.type.bases[0]
if fields_dict is not None:
type_name = f"WithAnnotations[{model_type.type.fullname}, {fields_dict}]"
else:
type_name = f"WithAnnotations[{model_type.type.fullname}]"
annotated_typeinfo = helpers.lookup_fully_qualified_typeinfo(
cast(TypeChecker, api), model_module_name + "." + type_name
)
if annotated_typeinfo is None:
model_module_file = api.modules[model_module_name] # type: ignore
if isinstance(api, SemanticAnalyzer):
annotated_model_type = api.named_type_or_none(ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])
assert annotated_model_type is not None
else:
annotated_model_type = api.named_generic_type(ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])
annotated_typeinfo = add_new_class_for_module(
model_module_file,
type_name,
bases=[model_type] if fields_dict is not None else [model_type, annotated_model_type],
fields=fields_dict.items if fields_dict is not None else None,
)
if fields_dict is not None:
# To allow structural subtyping, make it a Protocol
annotated_typeinfo.is_protocol = True
# Save for later to easily find which field types were annotated
annotated_typeinfo.metadata["annotated_field_types"] = fields_dict.items
annotated_type = Instance(annotated_typeinfo, [])
return annotated_type

View File

@@ -5,6 +5,7 @@ from mypy.types import TypeOfAny
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.helpers import is_annotated_model_fullname
def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
@@ -29,7 +30,11 @@ def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext)
):
provided_type = resolve_combinable_type(provided_type, django_context)
lookup_type = django_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg)
lookup_type: MypyType
if is_annotated_model_fullname(model_cls_fullname):
lookup_type = AnyType(TypeOfAny.implementation_artifact)
else:
lookup_type = django_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg)
# Managers as provided_type is not supported yet
if isinstance(provided_type, Instance) and helpers.has_any_of_bases(
provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME, fullnames.QUERYSET_CLASS_FULLNAME)

View File

@@ -1,18 +1,21 @@
from collections import OrderedDict
from typing import List, Optional, Sequence, Type
from typing import Dict, List, Optional, Sequence, Type
from django.core.exceptions import FieldError
from django.db.models.base import Model
from django.db.models.fields.related import RelatedField
from django.db.models.fields.reverse_related import ForeignObjectRel
from mypy.nodes import Expression, NameExpr
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, Expression, NameExpr
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance
from mypy.types import AnyType, Instance, TupleType
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny
from mypy.types import TypedDictType, TypeOfAny, get_proper_type
from mypy_django_plugin.django.context import DjangoContext, LookupsAreUnsupported
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import ANY_ATTR_ALLOWED_CLASS_FULLNAME
from mypy_django_plugin.lib.helpers import is_annotated_model_fullname
from mypy_django_plugin.transformers.models import get_or_create_annotated_type
def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]:
@@ -38,12 +41,19 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
def get_field_type_from_lookup(
ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], *, method: str, lookup: str
ctx: MethodContext,
django_context: DjangoContext,
model_cls: Type[Model],
*,
method: str,
lookup: str,
silent_on_error: bool = False,
) -> Optional[MypyType]:
try:
lookup_field = django_context.resolve_lookup_into_field(model_cls, lookup)
except FieldError as exc:
ctx.api.fail(exc.args[0], ctx.context)
if not silent_on_error:
ctx.api.fail(exc.args[0], ctx.context)
return None
except LookupsAreUnsupported:
return AnyType(TypeOfAny.explicit)
@@ -61,7 +71,13 @@ def get_field_type_from_lookup(
def get_values_list_row_type(
ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], flat: bool, named: bool
ctx: MethodContext,
django_context: DjangoContext,
model_cls: Type[Model],
*,
is_annotated: bool,
flat: bool,
named: bool,
) -> MypyType:
field_lookups = resolve_field_lookups(ctx.args[0], django_context)
if field_lookups is None:
@@ -81,9 +97,20 @@ def get_values_list_row_type(
for field in django_context.get_model_fields(model_cls):
column_type = django_context.get_field_get_type(typechecker_api, field, method="values_list")
column_types[field.attname] = column_type
return helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types)
if is_annotated:
# Return a NamedTuple with a fallback so that it's possible to access any field
return helpers.make_oneoff_named_tuple(
typechecker_api,
"Row",
column_types,
extra_bases=[typechecker_api.named_generic_type(ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])],
)
else:
return helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types)
else:
# flat=False, named=False, all fields
if is_annotated:
return typechecker_api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.special_form)])
field_lookups = []
for field in django_context.get_model_fields(model_cls):
field_lookups.append(field.attname)
@@ -95,10 +122,13 @@ def get_values_list_row_type(
column_types = OrderedDict()
for field_lookup in field_lookups:
lookup_field_type = get_field_type_from_lookup(
ctx, django_context, model_cls, lookup=field_lookup, method="values_list"
ctx, django_context, model_cls, lookup=field_lookup, method="values_list", silent_on_error=is_annotated
)
if lookup_field_type is None:
return AnyType(TypeOfAny.from_error)
if is_annotated:
lookup_field_type = AnyType(TypeOfAny.from_omitted_generics)
else:
return AnyType(TypeOfAny.from_error)
column_types[field_lookup] = lookup_field_type
if flat:
@@ -115,7 +145,8 @@ def get_values_list_row_type(
def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
# called on the Instance, returns QuerySet of something
assert isinstance(ctx.type, Instance)
assert isinstance(ctx.default_return_type, Instance)
default_return_type = get_proper_type(ctx.default_return_type)
assert isinstance(default_return_type, Instance)
model_type = _extract_model_type_from_queryset(ctx.type)
if model_type is None:
@@ -123,7 +154,7 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname)
if model_cls is None:
return ctx.default_return_type
return default_return_type
flat_expr = helpers.get_call_argument_by_name(ctx, "flat")
if flat_expr is not None and isinstance(flat_expr, NameExpr):
@@ -139,14 +170,89 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
if flat and named:
ctx.api.fail("'flat' and 'named' can't be used together", ctx.context)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
return helpers.reparametrize_instance(default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
# account for possible None
flat = flat or False
named = named or False
row_type = get_values_list_row_type(ctx, django_context, model_cls, flat=flat, named=named)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
is_annotated = is_annotated_model_fullname(model_type.type.fullname)
row_type = get_values_list_row_type(
ctx, django_context, model_cls, is_annotated=is_annotated, flat=flat, named=named
)
return helpers.reparametrize_instance(default_return_type, [model_type, row_type])
def gather_kwargs(ctx: MethodContext) -> Optional[Dict[str, MypyType]]:
num_args = len(ctx.arg_kinds)
kwargs = {}
named = (ARG_NAMED, ARG_NAMED_OPT)
for i in range(num_args):
if not ctx.arg_kinds[i]:
continue
if any(kind not in named for kind in ctx.arg_kinds[i]):
# Only named arguments supported
return None
for j in range(len(ctx.arg_names[i])):
name = ctx.arg_names[i][j]
assert name is not None
kwargs[name] = ctx.arg_types[i][j]
return kwargs
def extract_proper_type_queryset_annotate(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
# called on the Instance, returns QuerySet of something
assert isinstance(ctx.type, Instance)
default_return_type = get_proper_type(ctx.default_return_type)
assert isinstance(default_return_type, Instance)
model_type = _extract_model_type_from_queryset(ctx.type)
if model_type is None:
return AnyType(TypeOfAny.from_omitted_generics)
api = ctx.api
field_types = model_type.type.metadata.get("annotated_field_types")
kwargs = gather_kwargs(ctx)
if kwargs:
# For now, we don't try to resolve the output_field of the field would be, but use Any.
added_field_types = {name: AnyType(TypeOfAny.implementation_artifact) for name, typ in kwargs.items()}
if field_types is not None:
# Annotate was called more than once, so add/update existing field types
field_types.update(added_field_types)
else:
field_types = added_field_types
fields_dict = None
if field_types is not None:
fields_dict = helpers.make_typeddict(
api, fields=OrderedDict(field_types), required_keys=set(field_types.keys())
)
annotated_type = get_or_create_annotated_type(api, model_type, fields_dict=fields_dict)
row_type: MypyType
if len(default_return_type.args) > 1:
original_row_type: MypyType = default_return_type.args[1]
row_type = original_row_type
if isinstance(original_row_type, TypedDictType):
row_type = api.named_generic_type(
"builtins.dict", [api.named_generic_type("builtins.str", []), AnyType(TypeOfAny.from_omitted_generics)]
)
elif isinstance(original_row_type, TupleType):
fallback: Instance = original_row_type.partial_fallback
if fallback is not None and fallback.type.has_base("typing.NamedTuple"):
# TODO: Use a NamedTuple which contains the known fields, but also
# falls back to allowing any attribute access.
row_type = AnyType(TypeOfAny.implementation_artifact)
else:
row_type = api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.from_omitted_generics)])
elif isinstance(original_row_type, Instance) and original_row_type.type.has_base(
fullnames.MODEL_CLASS_FULLNAME
):
row_type = annotated_type
else:
row_type = annotated_type
return helpers.reparametrize_instance(default_return_type, [annotated_type, row_type])
def resolve_field_lookups(lookup_exprs: Sequence[Expression], django_context: DjangoContext) -> Optional[List[str]]:
@@ -162,7 +268,8 @@ def resolve_field_lookups(lookup_exprs: Sequence[Expression], django_context: Dj
def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
# called on QuerySet, return QuerySet of something
assert isinstance(ctx.type, Instance)
assert isinstance(ctx.default_return_type, Instance)
default_return_type = get_proper_type(ctx.default_return_type)
assert isinstance(default_return_type, Instance)
model_type = _extract_model_type_from_queryset(ctx.type)
if model_type is None:
@@ -170,7 +277,10 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname)
if model_cls is None:
return ctx.default_return_type
return default_return_type
if is_annotated_model_fullname(model_type.type.fullname):
return default_return_type
field_lookups = resolve_field_lookups(ctx.args[0], django_context)
if field_lookups is None:
@@ -186,9 +296,9 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
ctx, django_context, model_cls, lookup=field_lookup, method="values"
)
if field_lookup_type is None:
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
return helpers.reparametrize_instance(default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
column_types[field_lookup] = field_lookup_type
row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys()))
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
return helpers.reparametrize_instance(default_return_type, [model_type, row_type])