more values(), values_list() cases

This commit is contained in:
Maxim Kurnikov
2019-07-18 02:29:36 +03:00
parent b81fbdeaa9
commit 0e72b2e6fc
8 changed files with 187 additions and 42 deletions

View File

@@ -103,7 +103,10 @@ class DjangoFieldsContext:
class DjangoLookupsContext: class DjangoLookupsContext:
def resolve_lookup(self, model_cls: Type[Model], lookup: str) -> Any: def __init__(self, django_context: 'DjangoContext'):
self.django_context = django_context
def resolve_lookup(self, model_cls: Type[Model], lookup: str) -> Optional[Field]:
query = Query(model_cls) query = Query(model_cls)
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup) lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
if lookup_parts: if lookup_parts:
@@ -111,8 +114,11 @@ class DjangoLookupsContext:
currently_observed_model = model_cls currently_observed_model = model_cls
current_field = None current_field = None
for field_name in field_parts: for field_part in field_parts:
current_field = currently_observed_model._meta.get_field(field_name) if field_part == 'pk':
return self.django_context.get_primary_key_field(currently_observed_model)
current_field = currently_observed_model._meta.get_field(field_part)
if isinstance(current_field, RelatedField): if isinstance(current_field, RelatedField):
currently_observed_model = current_field.related_model currently_observed_model = current_field.related_model
@@ -123,7 +129,7 @@ class DjangoContext:
def __init__(self, plugin_toml_config: Optional[Dict[str, Any]]) -> None: def __init__(self, plugin_toml_config: Optional[Dict[str, Any]]) -> None:
self.config = DjangoPluginConfig() self.config = DjangoPluginConfig()
self.fields_context = DjangoFieldsContext(self) self.fields_context = DjangoFieldsContext(self)
self.lookups_context = DjangoLookupsContext() self.lookups_context = DjangoLookupsContext(self)
self.django_settings_module = None self.django_settings_module = None
if plugin_toml_config: if plugin_toml_config:

View File

@@ -1,14 +1,17 @@
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List, Optional, Set, Union, Any from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
from mypy import checker from mypy import checker
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.mro import calculate_mro from mypy.mro import calculate_mro
from mypy.nodes import Block, ClassDef, Expression, GDEF, MDEF, MypyFile, NameExpr, SymbolNode, SymbolTable, SymbolTableNode, \ from mypy.nodes import Block, ClassDef, Expression, GDEF, MDEF, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, \
TypeInfo, Var SymbolTableNode, TypeInfo, Var, MemberExpr
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext
from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType
if TYPE_CHECKING:
from mypy_django_plugin.django.context import DjangoContext
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]: def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
return model_info.metadata.setdefault('django', {}) return model_info.metadata.setdefault('django', {})
@@ -202,3 +205,19 @@ def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyTy
object_type = api.named_generic_type('mypy_extensions._TypedDict', []) object_type = api.named_generic_type('mypy_extensions._TypedDict', [])
typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type) typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type)
return typed_dict_type return typed_dict_type
def resolve_string_attribute_value(attr_expr: Expression, ctx: Union[FunctionContext, MethodContext],
django_context: 'DjangoContext') -> Optional[str]:
if isinstance(attr_expr, StrExpr):
return attr_expr.value
# support extracting from settings, in general case it's unresolvable yet
if isinstance(attr_expr, MemberExpr):
member_name = attr_expr.name
if isinstance(attr_expr.expr, NameExpr) and attr_expr.expr.fullname == 'django.conf.settings':
if hasattr(django_context.settings, member_name):
return getattr(django_context.settings, member_name)
ctx.api.fail(f'Expression of type {type(attr_expr).__name__!r} is not supported', ctx.context)
return None

View File

@@ -1,9 +1,9 @@
from typing import Optional, Tuple, cast from typing import Optional, Tuple, cast
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.nodes import StrExpr, TypeInfo from mypy.nodes import StrExpr, TypeInfo, Expression
from mypy.plugin import FunctionContext from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, Type as MypyType, UnionType from mypy.types import AnyType, CallableType, Instance, Type as MypyType, UnionType, TypeOfAny
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib import fullnames, helpers
@@ -77,34 +77,55 @@ def convert_any_to_type(typ: MypyType, replacement_type: MypyType) -> MypyType:
return typ return typ
def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoContext) -> str: def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoContext) -> Optional[str]:
to_arg_type = helpers.get_call_argument_type_by_name(ctx, 'to') to_arg_type = helpers.get_call_argument_type_by_name(ctx, 'to')
if isinstance(to_arg_type, CallableType): if isinstance(to_arg_type, CallableType):
assert isinstance(to_arg_type.ret_type, Instance) assert isinstance(to_arg_type.ret_type, Instance)
return to_arg_type.ret_type.type.fullname() return to_arg_type.ret_type.type.fullname()
to_arg_expr = helpers.get_call_argument_by_name(ctx, 'to')
if not isinstance(to_arg_expr, StrExpr):
raise helpers.IncompleteDefnException(f'Not a string: {to_arg_expr}')
outer_model_info = ctx.api.tscope.classes[-1] outer_model_info = ctx.api.tscope.classes[-1]
assert isinstance(outer_model_info, TypeInfo) assert isinstance(outer_model_info, TypeInfo)
model_string = to_arg_expr.value to_arg_expr = helpers.get_call_argument_by_name(ctx, 'to')
model_string = helpers.resolve_string_attribute_value(to_arg_expr, ctx, django_context)
if model_string is None:
# unresolvable
return None
if model_string == 'self': if model_string == 'self':
return outer_model_info.fullname() return outer_model_info.fullname()
if '.' not in model_string: if '.' not in model_string:
# same file class # same file class
return outer_model_info.module_name + '.' + model_string return outer_model_info.module_name + '.' + model_string
model_cls = django_context.apps_registry.get_model(model_string) app_label, model_name = model_string.split('.')
if app_label not in django_context.apps_registry.app_configs:
ctx.api.fail(f'No installed app with label {app_label!r}', ctx.context)
return None
try:
model_cls = django_context.apps_registry.get_model(app_label, model_name)
except LookupError as exc:
# no model in app
ctx.api.fail(exc.args[0], ctx.context)
return None
model_fullname = helpers.get_class_fullname(model_cls) model_fullname = helpers.get_class_fullname(model_cls)
return model_fullname return model_fullname
def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
referred_to_fullname = get_referred_to_model_fullname(ctx, django_context) referred_to_fullname = get_referred_to_model_fullname(ctx, django_context)
if referred_to_fullname is None:
return AnyType(TypeOfAny.from_error)
referred_to_typeinfo = helpers.lookup_fully_qualified_generic(referred_to_fullname, ctx.api.modules) referred_to_typeinfo = helpers.lookup_fully_qualified_generic(referred_to_fullname, ctx.api.modules)
if referred_to_typeinfo is None:
ctx.api.fail(f'Cannot resolve {referred_to_fullname!r}. Please, report it to package developers.',
ctx.context)
return AnyType(TypeOfAny.from_error)
assert isinstance(referred_to_typeinfo, TypeInfo), f'Cannot resolve {referred_to_fullname!r}' assert isinstance(referred_to_typeinfo, TypeInfo), f'Cannot resolve {referred_to_fullname!r}'
referred_to_type = Instance(referred_to_typeinfo, []) referred_to_type = Instance(referred_to_typeinfo, [])

View File

@@ -59,7 +59,9 @@ def redefine_and_typecheck_model_init(ctx: FunctionContext, django_context: Djan
def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: DjangoContext) -> MypyType: def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
isinstance(ctx.default_return_type, Instance) if not isinstance(ctx.default_return_type, Instance):
# only work with ctx.default_return_type = model Instance
return ctx.default_return_type
model_fullname = ctx.default_return_type.type.fullname() model_fullname = ctx.default_return_type.type.fullname()
model_cls = django_context.get_model_class_by_fullname(model_fullname) model_cls = django_context.get_model_class_by_fullname(model_fullname)

View File

@@ -1,10 +1,10 @@
from collections import OrderedDict from collections import OrderedDict
from typing import Optional, Tuple, Type from typing import Optional, Tuple, Type, Sequence, List, Union
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey from django.db.models.fields.related import ForeignKey
from mypy.nodes import NameExpr from mypy.nodes import NameExpr, Expression
from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext
from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny
@@ -56,28 +56,31 @@ def get_lookup_field_get_type(ctx: MethodContext, django_context: DjangoContext,
return None return None
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, method) field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, method)
return lookup_field.attname, field_get_type return lookup, field_get_type
def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
flat: bool, named: bool) -> MypyType: flat: bool, named: bool) -> MypyType:
field_lookups = [expr.value for expr in ctx.args[0]] field_lookups = resolve_field_lookups(ctx.args[0], ctx, django_context)
if field_lookups is None:
return AnyType(TypeOfAny.from_error)
if len(field_lookups) == 0: if len(field_lookups) == 0:
if flat: if flat:
primary_key_field = django_context.get_primary_key_field(model_cls) primary_key_field = django_context.get_primary_key_field(model_cls)
_, field_get_type = get_lookup_field_get_type(ctx, django_context, model_cls, _, column_type = get_lookup_field_get_type(ctx, django_context, model_cls,
primary_key_field.attname, 'values_list') primary_key_field.attname, 'values_list')
return field_get_type return column_type
elif named: elif named:
column_types = OrderedDict() column_types = OrderedDict()
for field in django_context.get_model_fields(model_cls): for field in django_context.get_model_fields(model_cls):
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, field, 'values_list') column_type = django_context.fields_context.get_field_get_type(ctx.api, field, 'values_list')
column_types[field.attname] = field_get_type column_types[field.attname] = column_type
return helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types) return helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types)
else: else:
# flat=False, named=False, all fields # flat=False, named=False, all fields
field_lookups = [] field_lookups = []
for field in model_cls._meta.get_fields(): for field in django_context.get_model_fields(model_cls):
field_lookups.append(field.attname) field_lookups.append(field.attname)
if len(field_lookups) > 1 and flat: if len(field_lookups) > 1 and flat:
@@ -89,8 +92,9 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values_list') result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values_list')
if result is None: if result is None:
return AnyType(TypeOfAny.from_error) return AnyType(TypeOfAny.from_error)
field_name, field_get_type = result
column_types[field_name] = field_get_type column_name, column_type = result
column_types[column_name] = column_type
if flat: if flat:
assert len(column_types) == 1 assert len(column_types) == 1
@@ -133,6 +137,17 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
def resolve_field_lookups(lookup_exprs: Sequence[Expression], ctx: Union[FunctionContext, MethodContext],
django_context: DjangoContext) -> Optional[List[str]]:
field_lookups = []
for field_lookup_expr in lookup_exprs:
field_lookup = helpers.resolve_string_attribute_value(field_lookup_expr, ctx, django_context)
if field_lookup is None:
return None
field_lookups.append(field_lookup)
return field_lookups
def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType: def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
assert isinstance(ctx.type, Instance) assert isinstance(ctx.type, Instance)
assert isinstance(ctx.type.args[0], Instance) assert isinstance(ctx.type.args[0], Instance)
@@ -142,25 +157,22 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
if model_cls is None: if model_cls is None:
return ctx.default_return_type return ctx.default_return_type
field_lookups = [expr.value for expr in ctx.args[0]] field_lookups = resolve_field_lookups(ctx.args[0], ctx, django_context)
if field_lookups is None:
return AnyType(TypeOfAny.from_error)
if len(field_lookups) == 0: if len(field_lookups) == 0:
for field in model_cls._meta.get_fields(): for field in django_context.get_model_fields(model_cls):
field_lookups.append(field.attname) field_lookups.append(field.attname)
column_types = OrderedDict() column_types = OrderedDict()
for field_lookup in field_lookups: for field_lookup in field_lookups:
try: result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values')
lookup_field = django_context.lookups_context.resolve_lookup(model_cls, field_lookup) if result is None:
except FieldError as exc:
ctx.api.fail(exc.args[0], ctx.context)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)]) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, 'values') column_name, column_type = result
field_name = lookup_field.attname column_types[column_name] = column_type
if isinstance(lookup_field, ForeignKey) and field_lookup == lookup_field.name:
field_name = lookup_field.name
column_types[field_name] = field_get_type
row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys())) row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys()))
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])

View File

@@ -412,3 +412,25 @@
related_name='+') related_name='+')
publisher2 = models.ForeignKey(to=Publisher, on_delete=models.CASCADE, publisher2 = models.ForeignKey(to=Publisher, on_delete=models.CASCADE,
related_name='books2') related_name='books2')
- case: to_parameter_could_be_resolved_if_passed_from_settings
main: |
from myapp.models import Book
book = Book()
reveal_type(book.publisher) # N: Revealed type is 'myapp.models.Publisher*'
installed_apps:
- myapp
additional_settings:
- BOOK_RELATED_MODEL='myapp.Publisher'
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.conf import settings
from django.db import models
class Publisher(models.Model):
pass
class Book(models.Model):
publisher = models.ForeignKey(to=settings.BOOK_RELATED_MODEL, on_delete=models.CASCADE,
related_name='books')

View File

@@ -5,6 +5,10 @@
reveal_type(values) # N: Revealed type is 'TypedDict({'num_posts': builtins.int, 'text': builtins.str})' reveal_type(values) # N: Revealed type is 'TypedDict({'num_posts': builtins.int, 'text': builtins.str})'
reveal_type(values["num_posts"]) # N: Revealed type is 'builtins.int' reveal_type(values["num_posts"]) # N: Revealed type is 'builtins.int'
reveal_type(values["text"]) # N: Revealed type is 'builtins.str' reveal_type(values["text"]) # N: Revealed type is 'builtins.str'
values_pk = Blog.objects.values('pk').get()
reveal_type(values_pk) # N: Revealed type is 'TypedDict({'pk': builtins.int})'
reveal_type(values_pk["pk"]) # N: Revealed type is 'builtins.int'
installed_apps: installed_apps:
- myapp - myapp
files: files:
@@ -31,6 +35,8 @@
- path: myapp/models.py - path: myapp/models.py
content: | content: |
from django.db import models from django.db import models
class Publisher(models.Model):
pass
class Blog(models.Model): class Blog(models.Model):
num_posts = models.IntegerField() num_posts = models.IntegerField()
text = models.CharField(max_length=100) text = models.CharField(max_length=100)
@@ -60,3 +66,26 @@
pass pass
class Blog(models.Model): class Blog(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
- case: values_with_related_model_fields
main: |
from myapp.models import Entry
values = Entry.objects.values('blog__num_articles', 'blog__publisher__name').get()
reveal_type(values) # N: Revealed type is 'TypedDict({'blog__num_articles': builtins.int, 'blog__publisher__name': builtins.str})'
pk_values = Entry.objects.values('blog__pk', 'blog__publisher__pk').get()
reveal_type(pk_values) # N: Revealed type is 'TypedDict({'blog__pk': builtins.int, 'blog__publisher__pk': builtins.int})'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class Publisher(models.Model):
name = models.CharField(max_length=100)
class Blog(models.Model):
num_articles = models.IntegerField()
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
class Entry(models.Model):
blog = models.ForeignKey(Blog, on_delete=models.CASCADE)

View File

@@ -1,4 +1,4 @@
- case: values_list_simple_field - case: values_list_simple_field_returns_queryset_of_tuples
main: | main: |
from myapp.models import MyUser from myapp.models import MyUser
reveal_type(MyUser.objects.values_list('name').get()) # N: Revealed type is 'Tuple[builtins.str]' reveal_type(MyUser.objects.values_list('name').get()) # N: Revealed type is 'Tuple[builtins.str]'
@@ -11,6 +11,10 @@
# no fields specified return all fields # no fields specified return all fields
all_values_tuple = MyUser.objects.values_list().get() all_values_tuple = MyUser.objects.values_list().get()
reveal_type(all_values_tuple) # N: Revealed type is 'Tuple[builtins.int, builtins.str, builtins.int]' reveal_type(all_values_tuple) # N: Revealed type is 'Tuple[builtins.int, builtins.str, builtins.int]'
# pk as field
pk_values = MyUser.objects.values_list('pk').get()
reveal_type(pk_values) # N: # N: Revealed type is 'Tuple[builtins.int]'
installed_apps: installed_apps:
- myapp - myapp
files: files:
@@ -84,6 +88,11 @@
reveal_type(all_values_named_tuple.name) # N: Revealed type is 'builtins.str' reveal_type(all_values_named_tuple.name) # N: Revealed type is 'builtins.str'
reveal_type(all_values_named_tuple.age) # N: Revealed type is 'builtins.int' reveal_type(all_values_named_tuple.age) # N: Revealed type is 'builtins.int'
reveal_type(all_values_named_tuple.is_admin) # N: Revealed type is 'builtins.bool' reveal_type(all_values_named_tuple.is_admin) # N: Revealed type is 'builtins.bool'
# pk as field
pk_values = MyUser.objects.values_list('pk', named=True).get()
reveal_type(pk_values) # N: Revealed type is 'Tuple[builtins.int, fallback=main.Row2]'
reveal_type(pk_values.pk) # N: # N: Revealed type is 'builtins.int'
installed_apps: installed_apps:
- myapp - myapp
files: files:
@@ -139,4 +148,29 @@
class Publisher(models.Model): class Publisher(models.Model):
pass pass
class Blog(models.Model): class Blog(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
- case: named_true_with_related_model_fields
main: |
from myapp.models import Entry
values = Entry.objects.values_list('blog__num_articles', 'blog__publisher__name', named=True).get()
reveal_type(values.blog__num_articles) # N: Revealed type is 'builtins.int'
reveal_type(values.blog__publisher__name) # N: Revealed type is 'builtins.str'
pk_values = Entry.objects.values_list('blog__pk', 'blog__publisher__pk', named=True).get()
reveal_type(pk_values.blog__pk) # N: Revealed type is 'builtins.int'
reveal_type(pk_values.blog__publisher__pk) # N: Revealed type is 'builtins.int'
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class Publisher(models.Model):
name = models.CharField(max_length=100)
class Blog(models.Model):
num_articles = models.IntegerField()
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
class Entry(models.Model):
blog = models.ForeignKey(Blog, on_delete=models.CASCADE)