mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-09 21:46:43 +08:00
Specific return types for values and values list (#53)
* Instead of using Literal types, overload QuerySet.values_list in the plugin. Fixes #43. - Add a couple of extra type checks that Django makes: 1) 'flat' and 'named' can't be used together. 2) 'flat' is not valid when values_list is called with more than one field. * Determine better row types for values_list/values based on fields specified. - In the case of values_list, we use a Row type with either a single primitive, Tuple, or NamedTuple. - In the case of values, we use a TypedDict. - In both cases, Any is used as a fallback for individual fields if those fields cannot be resolved. A couple other fixes I made along the way: - Don't create reverse relation for ForeignKeys with related_name='+' - Don't skip creating other related managers in AddRelatedManagers if a dynamic value is encountered for related_name parameter, or if the type cannot be determined. * Fix for TypedDict so that they are considered anonymous. * Clean up some comments. * Implement making TypedDict anonymous in a way that doesn't crash sometimes. * Fix flake8 errors. * Remove even uglier hack about making TypedDict anonymous. * Address review comments. Write a few better comments inside tests. * Fix crash when running with mypyc ("interpreted classes cannot inherit from compiled") due to the way I extended TypedDictType. - Implemented the hack in another way that works on mypyc. - Added a couple extra tests of accessing 'id' / 'pk' via values_list. * Fix flake8 errors. * Support annotation expressions (use type Any) for TypedDicts row types returned by values_list. - Bonus points: handle values_list gracefully (use type Any) where Tuples are returned where some of the fields arguments are not string literals.
This commit is contained in:
committed by
Maxim Kurnikov
parent
5c6be7ad12
commit
5b455b729a
@@ -1,10 +1,10 @@
|
||||
from typing import Optional, cast
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo, Var
|
||||
from mypy.nodes import ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import (
|
||||
AnyType, CallableType, Instance, TupleType, Type, TypeOfAny, UnionType,
|
||||
AnyType, CallableType, Instance, TupleType, Type, UnionType,
|
||||
)
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
@@ -88,23 +88,13 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext) -> Type:
|
||||
return helpers.reparametrize_instance(ctx.default_return_type, new_args=args)
|
||||
|
||||
|
||||
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> Type:
|
||||
node = type_info.get(private_field_name).node
|
||||
if isinstance(node, Var):
|
||||
descriptor_type = node.type
|
||||
if is_nullable:
|
||||
descriptor_type = helpers.make_optional(descriptor_type)
|
||||
return descriptor_type
|
||||
return AnyType(TypeOfAny.unannotated)
|
||||
|
||||
|
||||
def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
|
||||
default_return_type = cast(Instance, ctx.default_return_type)
|
||||
is_nullable = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'null'))
|
||||
set_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_set_type',
|
||||
is_nullable=is_nullable)
|
||||
get_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_get_type',
|
||||
is_nullable=is_nullable)
|
||||
set_type = helpers.get_private_descriptor_type(default_return_type.type, '_pyi_private_set_type',
|
||||
is_nullable=is_nullable)
|
||||
get_type = helpers.get_private_descriptor_type(default_return_type.type, '_pyi_private_get_type',
|
||||
is_nullable=is_nullable)
|
||||
return helpers.reparametrize_instance(default_return_type, [set_type, get_type])
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from mypy.plugin import FunctionContext, MethodContext
|
||||
from mypy.types import AnyType, Instance, Type, TypeOfAny
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
from mypy_django_plugin.transformers.fields import get_private_descriptor_type
|
||||
|
||||
|
||||
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
|
||||
@@ -162,8 +161,8 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo,
|
||||
if not pk_type:
|
||||
# extract set type of AutoField
|
||||
autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField')
|
||||
pk_type = get_private_descriptor_type(autofield_info, '_pyi_private_set_type',
|
||||
is_nullable=is_nullable)
|
||||
pk_type = helpers.get_private_descriptor_type(autofield_info, '_pyi_private_set_type',
|
||||
is_nullable=is_nullable)
|
||||
related_primary_key_type = pk_type
|
||||
|
||||
if is_init:
|
||||
@@ -185,8 +184,8 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo,
|
||||
|
||||
# if CharField(blank=True,...) and not nullable, then field can be None in __init__
|
||||
elif (
|
||||
helpers.has_any_of_bases(typ.type, (helpers.CHAR_FIELD_FULLNAME,)) and is_init and
|
||||
field_metadata.get('blank', False) and not field_metadata.get('null', False)
|
||||
helpers.has_any_of_bases(typ.type, (helpers.CHAR_FIELD_FULLNAME,)) and is_init and
|
||||
field_metadata.get('blank', False) and not field_metadata.get('null', False)
|
||||
):
|
||||
field_type = helpers.make_optional(field_type)
|
||||
|
||||
|
||||
@@ -172,17 +172,35 @@ class AddRelatedManagers(ModelClassInitializer):
|
||||
ref_to_fullname = module_name + '.' + exc.model_cls_name
|
||||
|
||||
if self.model_classdef.fullname == ref_to_fullname:
|
||||
related_manager_name = defn.name.lower() + '_set'
|
||||
related_name = defn.name.lower() + '_set'
|
||||
if 'related_name' in rvalue.arg_names:
|
||||
related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')]
|
||||
if not isinstance(related_name_expr, StrExpr):
|
||||
return None
|
||||
related_manager_name = related_name_expr.value
|
||||
continue
|
||||
related_name = related_name_expr.value
|
||||
if related_name == '+':
|
||||
# No backwards relation is desired
|
||||
continue
|
||||
|
||||
if 'related_query_name' in rvalue.arg_names:
|
||||
related_query_name_expr = rvalue.args[rvalue.arg_names.index('related_query_name')]
|
||||
if not isinstance(related_query_name_expr, StrExpr):
|
||||
related_query_name = None
|
||||
else:
|
||||
related_query_name = related_query_name_expr.value
|
||||
# TODO: Handle defaulting to model name if related_name is not set
|
||||
else:
|
||||
# No related_query_name specified, default to related_name
|
||||
related_query_name = related_name
|
||||
typ = get_related_field_type(rvalue, self.api, defn.info)
|
||||
if typ is None:
|
||||
return None
|
||||
self.add_new_node_to_model_class(related_manager_name, typ)
|
||||
continue
|
||||
self.add_new_node_to_model_class(related_name, typ)
|
||||
if related_query_name is not None:
|
||||
# Only create related_query_name if it is a string literal
|
||||
helpers.get_lookups_metadata(self.model_classdef.info)[related_query_name] = {
|
||||
'related_query_name_target': related_name
|
||||
}
|
||||
|
||||
|
||||
def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]:
|
||||
|
||||
139
mypy_django_plugin/transformers/queryset.py
Normal file
139
mypy_django_plugin/transformers/queryset.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Union, List, cast, Optional
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import StrExpr, TypeInfo
|
||||
from mypy.plugin import MethodContext, CheckerPluginInterface
|
||||
from mypy.types import Type, Instance, AnyType, TypeOfAny
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
from mypy_django_plugin.lookups import resolve_lookup, RelatedModelNode, LookupException
|
||||
|
||||
|
||||
def extract_proper_type_for_values_and_values_list(method_name: str, ctx: MethodContext) -> Type:
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
|
||||
object_type = ctx.type
|
||||
if not isinstance(object_type, Instance):
|
||||
return ctx.default_return_type
|
||||
|
||||
ret = ctx.default_return_type
|
||||
|
||||
any_type = AnyType(TypeOfAny.implementation_artifact)
|
||||
fields_arg_expr = ctx.args[ctx.callee_arg_names.index('fields')]
|
||||
|
||||
model_arg: Union[AnyType, Type] = ret.args[0] if len(ret.args) > 0 else any_type
|
||||
|
||||
column_names: List[Optional[str]] = []
|
||||
column_types: OrderedDict[str, Type] = OrderedDict()
|
||||
|
||||
fill_column_types = True
|
||||
|
||||
if len(fields_arg_expr) == 0:
|
||||
# values_list/values with no args is not yet supported, so default to Any types for field types
|
||||
# It should in the future include all model fields, "extra" fields and "annotated" fields
|
||||
fill_column_types = False
|
||||
|
||||
if isinstance(model_arg, Instance):
|
||||
model_type_info = model_arg.type
|
||||
else:
|
||||
model_type_info = None
|
||||
|
||||
# Figure out each field name passed to fields
|
||||
has_dynamic_column_names = False
|
||||
for field_expr in fields_arg_expr:
|
||||
if isinstance(field_expr, StrExpr):
|
||||
field_name = field_expr.value
|
||||
column_names.append(field_name)
|
||||
# Default to any type
|
||||
column_types[field_name] = any_type
|
||||
|
||||
if model_type_info:
|
||||
resolved_lookup_type = resolve_values_lookup(ctx.api, model_type_info, field_name)
|
||||
if resolved_lookup_type is not None:
|
||||
column_types[field_name] = resolved_lookup_type
|
||||
else:
|
||||
# Dynamic field names are partially supported for values_list, but not values
|
||||
column_names.append(None)
|
||||
has_dynamic_column_names = True
|
||||
|
||||
if method_name == 'values_list':
|
||||
flat = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'flat'))
|
||||
named = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'named'))
|
||||
|
||||
if named and flat:
|
||||
api.fail("'flat' and 'named' can't be used together.", ctx.context)
|
||||
return ret
|
||||
elif named:
|
||||
if fill_column_types and not has_dynamic_column_names:
|
||||
row_arg = helpers.make_named_tuple(api, fields=column_types, name="Row")
|
||||
else:
|
||||
row_arg = helpers.make_named_tuple(api, fields=OrderedDict(), name="Row")
|
||||
elif flat:
|
||||
if len(ctx.args[0]) > 1:
|
||||
api.fail("'flat' is not valid when values_list is called with more than one field.", ctx.context)
|
||||
return ret
|
||||
if fill_column_types and not has_dynamic_column_names:
|
||||
# Grab first element
|
||||
row_arg = column_types[column_names[0]]
|
||||
else:
|
||||
row_arg = any_type
|
||||
else:
|
||||
if fill_column_types:
|
||||
args = [
|
||||
# Fallback to Any if the column name is unknown (e.g. dynamic)
|
||||
column_types.get(column_name, any_type) if column_name is not None else any_type
|
||||
for column_name in column_names
|
||||
]
|
||||
else:
|
||||
args = [any_type]
|
||||
row_arg = helpers.make_tuple(api, fields=args)
|
||||
elif method_name == 'values':
|
||||
expression_arg_names = ctx.arg_names[ctx.callee_arg_names.index('expressions')]
|
||||
for expression_name in expression_arg_names:
|
||||
# Arbitrary additional annotation expressions are supported, but they all have type Any for now
|
||||
column_names.append(expression_name)
|
||||
column_types[expression_name] = any_type
|
||||
|
||||
if fill_column_types and not has_dynamic_column_names:
|
||||
row_arg = helpers.make_typeddict(api, fields=column_types, required_keys=set())
|
||||
else:
|
||||
return ctx.default_return_type
|
||||
else:
|
||||
raise Exception(f"extract_proper_type_for_values_list doesn't support method {method_name}")
|
||||
|
||||
new_type_args = [model_arg, row_arg]
|
||||
return helpers.reparametrize_instance(ret, new_type_args)
|
||||
|
||||
|
||||
def resolve_values_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, lookup: str) -> Optional[Type]:
|
||||
"""Resolves a values/values_list lookup if possible, to a Type."""
|
||||
try:
|
||||
nodes = resolve_lookup(api, model_type_info, lookup)
|
||||
except LookupException:
|
||||
nodes = []
|
||||
|
||||
if not nodes:
|
||||
return None
|
||||
|
||||
make_optional = False
|
||||
|
||||
for node in nodes:
|
||||
if isinstance(node, RelatedModelNode) and node.is_nullable:
|
||||
# All lookups following a relation which is nullable should be optional
|
||||
make_optional = True
|
||||
|
||||
node = nodes[-1]
|
||||
|
||||
node_type = node.typ
|
||||
if isinstance(node, RelatedModelNode):
|
||||
# Related models used in values/values_list get resolved to the primary key of the related model.
|
||||
# So, we lookup the pk of that model.
|
||||
pk_lookup_nodes = resolve_lookup(api, node_type.type, "pk")
|
||||
if not pk_lookup_nodes:
|
||||
return None
|
||||
node_type = pk_lookup_nodes[0].typ
|
||||
if make_optional:
|
||||
return helpers.make_optional(node_type)
|
||||
else:
|
||||
return node_type
|
||||
Reference in New Issue
Block a user