more values(), values_list() cases

This commit is contained in:
Maxim Kurnikov
2019-07-18 02:29:36 +03:00
parent b81fbdeaa9
commit 0e72b2e6fc
8 changed files with 187 additions and 42 deletions

View File

@@ -103,7 +103,10 @@ class DjangoFieldsContext:
class DjangoLookupsContext:
def resolve_lookup(self, model_cls: Type[Model], lookup: str) -> Any:
def __init__(self, django_context: 'DjangoContext'):
self.django_context = django_context
def resolve_lookup(self, model_cls: Type[Model], lookup: str) -> Optional[Field]:
query = Query(model_cls)
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
if lookup_parts:
@@ -111,8 +114,11 @@ class DjangoLookupsContext:
currently_observed_model = model_cls
current_field = None
for field_name in field_parts:
current_field = currently_observed_model._meta.get_field(field_name)
for field_part in field_parts:
if field_part == 'pk':
return self.django_context.get_primary_key_field(currently_observed_model)
current_field = currently_observed_model._meta.get_field(field_part)
if isinstance(current_field, RelatedField):
currently_observed_model = current_field.related_model
@@ -123,7 +129,7 @@ class DjangoContext:
def __init__(self, plugin_toml_config: Optional[Dict[str, Any]]) -> None:
self.config = DjangoPluginConfig()
self.fields_context = DjangoFieldsContext(self)
self.lookups_context = DjangoLookupsContext()
self.lookups_context = DjangoLookupsContext(self)
self.django_settings_module = None
if plugin_toml_config:

View File

@@ -1,14 +1,17 @@
from collections import OrderedDict
from typing import Dict, List, Optional, Set, Union, Any
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
from mypy import checker
from mypy.checker import TypeChecker
from mypy.mro import calculate_mro
from mypy.nodes import Block, ClassDef, Expression, GDEF, MDEF, MypyFile, NameExpr, SymbolNode, SymbolTable, SymbolTableNode, \
TypeInfo, Var
from mypy.nodes import Block, ClassDef, Expression, GDEF, MDEF, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, \
SymbolTableNode, TypeInfo, Var, MemberExpr
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext
from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType
if TYPE_CHECKING:
from mypy_django_plugin.django.context import DjangoContext
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
return model_info.metadata.setdefault('django', {})
@@ -202,3 +205,19 @@ def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyTy
object_type = api.named_generic_type('mypy_extensions._TypedDict', [])
typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type)
return typed_dict_type
def resolve_string_attribute_value(attr_expr: Expression, ctx: Union[FunctionContext, MethodContext],
django_context: 'DjangoContext') -> Optional[str]:
if isinstance(attr_expr, StrExpr):
return attr_expr.value
# support extracting from settings, in general case it's unresolvable yet
if isinstance(attr_expr, MemberExpr):
member_name = attr_expr.name
if isinstance(attr_expr.expr, NameExpr) and attr_expr.expr.fullname == 'django.conf.settings':
if hasattr(django_context.settings, member_name):
return getattr(django_context.settings, member_name)
ctx.api.fail(f'Expression of type {type(attr_expr).__name__!r} is not supported', ctx.context)
return None

View File

@@ -1,9 +1,9 @@
from typing import Optional, Tuple, cast
from mypy.checker import TypeChecker
from mypy.nodes import StrExpr, TypeInfo
from mypy.nodes import StrExpr, TypeInfo, Expression
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, Type as MypyType, UnionType
from mypy.types import AnyType, CallableType, Instance, Type as MypyType, UnionType, TypeOfAny
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
@@ -77,34 +77,55 @@ def convert_any_to_type(typ: MypyType, replacement_type: MypyType) -> MypyType:
return typ
def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoContext) -> str:
def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoContext) -> Optional[str]:
to_arg_type = helpers.get_call_argument_type_by_name(ctx, 'to')
if isinstance(to_arg_type, CallableType):
assert isinstance(to_arg_type.ret_type, Instance)
return to_arg_type.ret_type.type.fullname()
to_arg_expr = helpers.get_call_argument_by_name(ctx, 'to')
if not isinstance(to_arg_expr, StrExpr):
raise helpers.IncompleteDefnException(f'Not a string: {to_arg_expr}')
outer_model_info = ctx.api.tscope.classes[-1]
assert isinstance(outer_model_info, TypeInfo)
model_string = to_arg_expr.value
to_arg_expr = helpers.get_call_argument_by_name(ctx, 'to')
model_string = helpers.resolve_string_attribute_value(to_arg_expr, ctx, django_context)
if model_string is None:
# unresolvable
return None
if model_string == 'self':
return outer_model_info.fullname()
if '.' not in model_string:
# same file class
return outer_model_info.module_name + '.' + model_string
model_cls = django_context.apps_registry.get_model(model_string)
app_label, model_name = model_string.split('.')
if app_label not in django_context.apps_registry.app_configs:
ctx.api.fail(f'No installed app with label {app_label!r}', ctx.context)
return None
try:
model_cls = django_context.apps_registry.get_model(app_label, model_name)
except LookupError as exc:
# no model in app
ctx.api.fail(exc.args[0], ctx.context)
return None
model_fullname = helpers.get_class_fullname(model_cls)
return model_fullname
def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
referred_to_fullname = get_referred_to_model_fullname(ctx, django_context)
if referred_to_fullname is None:
return AnyType(TypeOfAny.from_error)
referred_to_typeinfo = helpers.lookup_fully_qualified_generic(referred_to_fullname, ctx.api.modules)
if referred_to_typeinfo is None:
ctx.api.fail(f'Cannot resolve {referred_to_fullname!r}. Please, report it to package developers.',
ctx.context)
return AnyType(TypeOfAny.from_error)
assert isinstance(referred_to_typeinfo, TypeInfo), f'Cannot resolve {referred_to_fullname!r}'
referred_to_type = Instance(referred_to_typeinfo, [])

View File

@@ -59,7 +59,9 @@ def redefine_and_typecheck_model_init(ctx: FunctionContext, django_context: Djan
def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
isinstance(ctx.default_return_type, Instance)
if not isinstance(ctx.default_return_type, Instance):
# only work with ctx.default_return_type = model Instance
return ctx.default_return_type
model_fullname = ctx.default_return_type.type.fullname()
model_cls = django_context.get_model_class_by_fullname(model_fullname)

View File

@@ -1,10 +1,10 @@
from collections import OrderedDict
from typing import Optional, Tuple, Type
from typing import Optional, Tuple, Type, Sequence, List, Union
from django.core.exceptions import FieldError
from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey
from mypy.nodes import NameExpr
from mypy.nodes import NameExpr, Expression
from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext
from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny
@@ -56,28 +56,31 @@ def get_lookup_field_get_type(ctx: MethodContext, django_context: DjangoContext,
return None
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, method)
return lookup_field.attname, field_get_type
return lookup, field_get_type
def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
flat: bool, named: bool) -> MypyType:
field_lookups = [expr.value for expr in ctx.args[0]]
field_lookups = resolve_field_lookups(ctx.args[0], ctx, django_context)
if field_lookups is None:
return AnyType(TypeOfAny.from_error)
if len(field_lookups) == 0:
if flat:
primary_key_field = django_context.get_primary_key_field(model_cls)
_, field_get_type = get_lookup_field_get_type(ctx, django_context, model_cls,
_, column_type = get_lookup_field_get_type(ctx, django_context, model_cls,
primary_key_field.attname, 'values_list')
return field_get_type
return column_type
elif named:
column_types = OrderedDict()
for field in django_context.get_model_fields(model_cls):
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, field, 'values_list')
column_types[field.attname] = field_get_type
column_type = django_context.fields_context.get_field_get_type(ctx.api, field, 'values_list')
column_types[field.attname] = column_type
return helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types)
else:
# flat=False, named=False, all fields
field_lookups = []
for field in model_cls._meta.get_fields():
for field in django_context.get_model_fields(model_cls):
field_lookups.append(field.attname)
if len(field_lookups) > 1 and flat:
@@ -89,8 +92,9 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values_list')
if result is None:
return AnyType(TypeOfAny.from_error)
field_name, field_get_type = result
column_types[field_name] = field_get_type
column_name, column_type = result
column_types[column_name] = column_type
if flat:
assert len(column_types) == 1
@@ -133,6 +137,17 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
def resolve_field_lookups(lookup_exprs: Sequence[Expression], ctx: Union[FunctionContext, MethodContext],
django_context: DjangoContext) -> Optional[List[str]]:
field_lookups = []
for field_lookup_expr in lookup_exprs:
field_lookup = helpers.resolve_string_attribute_value(field_lookup_expr, ctx, django_context)
if field_lookup is None:
return None
field_lookups.append(field_lookup)
return field_lookups
def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
assert isinstance(ctx.type, Instance)
assert isinstance(ctx.type.args[0], Instance)
@@ -142,25 +157,22 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
if model_cls is None:
return ctx.default_return_type
field_lookups = [expr.value for expr in ctx.args[0]]
field_lookups = resolve_field_lookups(ctx.args[0], ctx, django_context)
if field_lookups is None:
return AnyType(TypeOfAny.from_error)
if len(field_lookups) == 0:
for field in model_cls._meta.get_fields():
for field in django_context.get_model_fields(model_cls):
field_lookups.append(field.attname)
column_types = OrderedDict()
for field_lookup in field_lookups:
try:
lookup_field = django_context.lookups_context.resolve_lookup(model_cls, field_lookup)
except FieldError as exc:
ctx.api.fail(exc.args[0], ctx.context)
result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values')
if result is None:
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, 'values')
field_name = lookup_field.attname
if isinstance(lookup_field, ForeignKey) and field_lookup == lookup_field.name:
field_name = lookup_field.name
column_types[field_name] = field_get_type
column_name, column_type = result
column_types[column_name] = column_type
row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys()))
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])