diff --git a/django-stubs/db/models/query.pyi b/django-stubs/db/models/query.pyi index 0ac0632..3c1bddb 100644 --- a/django-stubs/db/models/query.pyi +++ b/django-stubs/db/models/query.pyi @@ -22,7 +22,6 @@ from typing import ( from django.db.models.base import Model from django.db.models.expressions import Combinable as Combinable, F as F from django.db.models.sql.query import Query, RawQuery -from typing_extensions import Literal from django.db import models from django.db.models import Manager @@ -99,18 +98,10 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized): self, raw_query: str, params: Any = ..., translations: Optional[Dict[str, str]] = ..., using: None = ... ) -> RawQuerySet: ... def values(self, *fields: Union[str, Combinable], **expressions: Any) -> QuerySet[_T, Dict[str, Any]]: ... - @overload + # The type of values_list is overridden to be more specific in the mypy django plugin def values_list( - self, *fields: Union[str, Combinable], flat: Literal[False] = ..., named: Literal[True] - ) -> QuerySet[_T, NamedTuple]: ... - @overload - def values_list( - self, *fields: Union[str, Combinable], flat: Literal[True], named: Literal[False] = ... + self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ... ) -> QuerySet[_T, Any]: ... - @overload - def values_list( - self, *fields: Union[str, Combinable], flat: Literal[False] = ..., named: Literal[False] = ... - ) -> QuerySet[_T, Tuple]: ... def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet[_T, datetime.date]: ... def datetimes( self, field_name: str, kind: str, order: str = ..., tzinfo: None = ... diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 991b018..8247e18 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -229,6 +229,38 @@ def extract_proper_type_for_get_form(ctx: MethodContext) -> Type: return ctx.default_return_type +def extract_proper_type_for_values_list(ctx: MethodContext) -> Type: + object_type = ctx.type + if not isinstance(object_type, Instance): + return ctx.default_return_type + + flat = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'flat')) + named = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'named')) + + ret = ctx.default_return_type + + any_type = AnyType(TypeOfAny.implementation_artifact) + if named and flat: + ctx.api.fail("'flat' and 'named' can't be used together.", ctx.context) + return ret + elif named: + # TODO: Fill in namedtuple fields/types + row_arg = ctx.api.named_generic_type('typing.NamedTuple', []) + elif flat: + # TODO: Figure out row_arg type dependent on the argument passed in + if len(ctx.args[0]) > 1: + ctx.api.fail("'flat' is not valid when values_list is called with more than one field.", ctx.context) + return ret + row_arg = any_type + else: + # TODO: Figure out tuple argument types dependent on the arguments passed in + row_arg = ctx.api.named_generic_type('builtins.tuple', [any_type]) + + first_arg = ret.args[0] if len(ret.args) > 0 else any_type + new_type_args = [first_arg, row_arg] + return helpers.reparametrize_instance(ret, new_type_args) + + class DjangoPlugin(Plugin): def __init__(self, options: Options) -> None: super().__init__(options) @@ -318,6 +350,11 @@ class DjangoPlugin(Plugin): if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): return extract_proper_type_for_get_form + if method_name == 'values_list': + sym = self.lookup_fully_qualified(class_name) + if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.QUERYSET_CLASS_FULLNAME): + return extract_proper_type_for_values_list + if fullname in {'django.apps.registry.Apps.get_model', 'django.db.migrations.state.StateApps.get_model'}: return determine_model_cls_from_string_for_migrations diff --git a/test-data/typecheck/queryset.test b/test-data/typecheck/queryset.test index 23fff7a..89926af 100644 --- a/test-data/typecheck/queryset.test +++ b/test-data/typecheck/queryset.test @@ -13,6 +13,8 @@ class BlogQuerySet(models.QuerySet[Blog]): blog_qs: models.QuerySet[Blog] reveal_type(blog_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog, main.Blog]' +Blog.objects.values_list('id', flat=True, named=True) # E: 'flat' and 'named' can't be used together. +Blog.objects.values_list('id', 'extra_arg', flat=True) # E: 'flat' is not valid when values_list is called with more than one field. reveal_type(Blog.objects.in_bulk([1])) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' reveal_type(Blog.objects.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'