Instead of using Literal types, overload QuerySet.values_list in the plugin. Fixes #43. (#44)

- Add a couple of extra type checks that Django makes:
  1) 'flat' and 'named' can't be used together.
  2) 'flat' is not valid when values_list is called with more than one field.
This commit is contained in:
Seth Yastrov
2019-03-13 20:10:37 +01:00
committed by Maxim Kurnikov
parent 7c57143310
commit b1a04d2f7d
3 changed files with 41 additions and 11 deletions

View File

@@ -22,7 +22,6 @@ from typing import (
from django.db.models.base import Model
from django.db.models.expressions import Combinable as Combinable, F as F
from django.db.models.sql.query import Query, RawQuery
from typing_extensions import Literal
from django.db import models
from django.db.models import Manager
@@ -99,18 +98,10 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
self, raw_query: str, params: Any = ..., translations: Optional[Dict[str, str]] = ..., using: None = ...
) -> RawQuerySet: ...
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> QuerySet[_T, Dict[str, Any]]: ...
@overload
# The type of values_list is overridden to be more specific in the mypy django plugin
def values_list(
self, *fields: Union[str, Combinable], flat: Literal[False] = ..., named: Literal[True]
) -> QuerySet[_T, NamedTuple]: ...
@overload
def values_list(
self, *fields: Union[str, Combinable], flat: Literal[True], named: Literal[False] = ...
self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...
) -> QuerySet[_T, Any]: ...
@overload
def values_list(
self, *fields: Union[str, Combinable], flat: Literal[False] = ..., named: Literal[False] = ...
) -> QuerySet[_T, Tuple]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet[_T, datetime.date]: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...

View File

@@ -229,6 +229,38 @@ def extract_proper_type_for_get_form(ctx: MethodContext) -> Type:
return ctx.default_return_type
def extract_proper_type_for_values_list(ctx: MethodContext) -> Type:
object_type = ctx.type
if not isinstance(object_type, Instance):
return ctx.default_return_type
flat = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'flat'))
named = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'named'))
ret = ctx.default_return_type
any_type = AnyType(TypeOfAny.implementation_artifact)
if named and flat:
ctx.api.fail("'flat' and 'named' can't be used together.", ctx.context)
return ret
elif named:
# TODO: Fill in namedtuple fields/types
row_arg = ctx.api.named_generic_type('typing.NamedTuple', [])
elif flat:
# TODO: Figure out row_arg type dependent on the argument passed in
if len(ctx.args[0]) > 1:
ctx.api.fail("'flat' is not valid when values_list is called with more than one field.", ctx.context)
return ret
row_arg = any_type
else:
# TODO: Figure out tuple argument types dependent on the arguments passed in
row_arg = ctx.api.named_generic_type('builtins.tuple', [any_type])
first_arg = ret.args[0] if len(ret.args) > 0 else any_type
new_type_args = [first_arg, row_arg]
return helpers.reparametrize_instance(ret, new_type_args)
class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None:
super().__init__(options)
@@ -318,6 +350,11 @@ class DjangoPlugin(Plugin):
if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME):
return extract_proper_type_for_get_form
if method_name == 'values_list':
sym = self.lookup_fully_qualified(class_name)
if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.QUERYSET_CLASS_FULLNAME):
return extract_proper_type_for_values_list
if fullname in {'django.apps.registry.Apps.get_model',
'django.db.migrations.state.StateApps.get_model'}:
return determine_model_cls_from_string_for_migrations

View File

@@ -13,6 +13,8 @@ class BlogQuerySet(models.QuerySet[Blog]):
blog_qs: models.QuerySet[Blog]
reveal_type(blog_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog, main.Blog]'
Blog.objects.values_list('id', flat=True, named=True) # E: 'flat' and 'named' can't be used together.
Blog.objects.values_list('id', 'extra_arg', flat=True) # E: 'flat' is not valid when values_list is called with more than one field.
reveal_type(Blog.objects.in_bulk([1])) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'
reveal_type(Blog.objects.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'