From 0e72b2e6fc50c69c048e1bc4fa2db497f3432cc3 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Thu, 18 Jul 2019 02:29:36 +0300 Subject: [PATCH] more values(), values_list() cases --- mypy_django_plugin/django/context.py | 14 +++-- mypy_django_plugin/lib/helpers.py | 25 +++++++- mypy_django_plugin/transformers/fields.py | 39 ++++++++++--- .../transformers/init_create.py | 4 +- mypy_django_plugin/transformers/querysets.py | 58 +++++++++++-------- test-data/typecheck/fields/test_related.yml | 22 +++++++ .../managers/querysets/test_values.yml | 29 ++++++++++ .../managers/querysets/test_values_list.yml | 38 +++++++++++- 8 files changed, 187 insertions(+), 42 deletions(-) diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index c585558..9d4b14d 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -103,7 +103,10 @@ class DjangoFieldsContext: class DjangoLookupsContext: - def resolve_lookup(self, model_cls: Type[Model], lookup: str) -> Any: + def __init__(self, django_context: 'DjangoContext'): + self.django_context = django_context + + def resolve_lookup(self, model_cls: Type[Model], lookup: str) -> Optional[Field]: query = Query(model_cls) lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup) if lookup_parts: @@ -111,8 +114,11 @@ class DjangoLookupsContext: currently_observed_model = model_cls current_field = None - for field_name in field_parts: - current_field = currently_observed_model._meta.get_field(field_name) + for field_part in field_parts: + if field_part == 'pk': + return self.django_context.get_primary_key_field(currently_observed_model) + + current_field = currently_observed_model._meta.get_field(field_part) if isinstance(current_field, RelatedField): currently_observed_model = current_field.related_model @@ -123,7 +129,7 @@ class DjangoContext: def __init__(self, plugin_toml_config: Optional[Dict[str, Any]]) -> None: self.config = DjangoPluginConfig() self.fields_context = DjangoFieldsContext(self) - self.lookups_context = DjangoLookupsContext() + self.lookups_context = DjangoLookupsContext(self) self.django_settings_module = None if plugin_toml_config: diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index cc7d95a..39c8bc4 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -1,14 +1,17 @@ from collections import OrderedDict -from typing import Dict, List, Optional, Set, Union, Any +from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING from mypy import checker from mypy.checker import TypeChecker from mypy.mro import calculate_mro -from mypy.nodes import Block, ClassDef, Expression, GDEF, MDEF, MypyFile, NameExpr, SymbolNode, SymbolTable, SymbolTableNode, \ - TypeInfo, Var +from mypy.nodes import Block, ClassDef, Expression, GDEF, MDEF, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, \ + SymbolTableNode, TypeInfo, Var, MemberExpr from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType +if TYPE_CHECKING: + from mypy_django_plugin.django.context import DjangoContext + def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]: return model_info.metadata.setdefault('django', {}) @@ -202,3 +205,19 @@ def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyTy object_type = api.named_generic_type('mypy_extensions._TypedDict', []) typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type) return typed_dict_type + + +def resolve_string_attribute_value(attr_expr: Expression, ctx: Union[FunctionContext, MethodContext], + django_context: 'DjangoContext') -> Optional[str]: + if isinstance(attr_expr, StrExpr): + return attr_expr.value + + # support extracting from settings, in general case it's unresolvable yet + if isinstance(attr_expr, MemberExpr): + member_name = attr_expr.name + if isinstance(attr_expr.expr, NameExpr) and attr_expr.expr.fullname == 'django.conf.settings': + if hasattr(django_context.settings, member_name): + return getattr(django_context.settings, member_name) + + ctx.api.fail(f'Expression of type {type(attr_expr).__name__!r} is not supported', ctx.context) + return None diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index 0467f3f..c5320b6 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -1,9 +1,9 @@ from typing import Optional, Tuple, cast from mypy.checker import TypeChecker -from mypy.nodes import StrExpr, TypeInfo +from mypy.nodes import StrExpr, TypeInfo, Expression from mypy.plugin import FunctionContext -from mypy.types import AnyType, CallableType, Instance, Type as MypyType, UnionType +from mypy.types import AnyType, CallableType, Instance, Type as MypyType, UnionType, TypeOfAny from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.lib import fullnames, helpers @@ -77,34 +77,55 @@ def convert_any_to_type(typ: MypyType, replacement_type: MypyType) -> MypyType: return typ -def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoContext) -> str: +def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoContext) -> Optional[str]: to_arg_type = helpers.get_call_argument_type_by_name(ctx, 'to') if isinstance(to_arg_type, CallableType): assert isinstance(to_arg_type.ret_type, Instance) return to_arg_type.ret_type.type.fullname() - to_arg_expr = helpers.get_call_argument_by_name(ctx, 'to') - if not isinstance(to_arg_expr, StrExpr): - raise helpers.IncompleteDefnException(f'Not a string: {to_arg_expr}') - outer_model_info = ctx.api.tscope.classes[-1] assert isinstance(outer_model_info, TypeInfo) - model_string = to_arg_expr.value + to_arg_expr = helpers.get_call_argument_by_name(ctx, 'to') + + model_string = helpers.resolve_string_attribute_value(to_arg_expr, ctx, django_context) + if model_string is None: + # unresolvable + return None + if model_string == 'self': return outer_model_info.fullname() if '.' not in model_string: # same file class return outer_model_info.module_name + '.' + model_string - model_cls = django_context.apps_registry.get_model(model_string) + app_label, model_name = model_string.split('.') + if app_label not in django_context.apps_registry.app_configs: + ctx.api.fail(f'No installed app with label {app_label!r}', ctx.context) + return None + + try: + model_cls = django_context.apps_registry.get_model(app_label, model_name) + except LookupError as exc: + # no model in app + ctx.api.fail(exc.args[0], ctx.context) + return None + model_fullname = helpers.get_class_fullname(model_cls) return model_fullname def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: referred_to_fullname = get_referred_to_model_fullname(ctx, django_context) + if referred_to_fullname is None: + return AnyType(TypeOfAny.from_error) + referred_to_typeinfo = helpers.lookup_fully_qualified_generic(referred_to_fullname, ctx.api.modules) + if referred_to_typeinfo is None: + ctx.api.fail(f'Cannot resolve {referred_to_fullname!r}. Please, report it to package developers.', + ctx.context) + return AnyType(TypeOfAny.from_error) + assert isinstance(referred_to_typeinfo, TypeInfo), f'Cannot resolve {referred_to_fullname!r}' referred_to_type = Instance(referred_to_typeinfo, []) diff --git a/mypy_django_plugin/transformers/init_create.py b/mypy_django_plugin/transformers/init_create.py index 83bfa2a..0c37df9 100644 --- a/mypy_django_plugin/transformers/init_create.py +++ b/mypy_django_plugin/transformers/init_create.py @@ -59,7 +59,9 @@ def redefine_and_typecheck_model_init(ctx: FunctionContext, django_context: Djan def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: DjangoContext) -> MypyType: - isinstance(ctx.default_return_type, Instance) + if not isinstance(ctx.default_return_type, Instance): + # only work with ctx.default_return_type = model Instance + return ctx.default_return_type model_fullname = ctx.default_return_type.type.fullname() model_cls = django_context.get_model_class_by_fullname(model_fullname) diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index e0d8abd..d8119cc 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -1,10 +1,10 @@ from collections import OrderedDict -from typing import Optional, Tuple, Type +from typing import Optional, Tuple, Type, Sequence, List, Union from django.core.exceptions import FieldError from django.db.models.base import Model from django.db.models.fields.related import ForeignKey -from mypy.nodes import NameExpr +from mypy.nodes import NameExpr, Expression from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny @@ -56,28 +56,31 @@ def get_lookup_field_get_type(ctx: MethodContext, django_context: DjangoContext, return None field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, method) - return lookup_field.attname, field_get_type + return lookup, field_get_type def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], flat: bool, named: bool) -> MypyType: - field_lookups = [expr.value for expr in ctx.args[0]] + field_lookups = resolve_field_lookups(ctx.args[0], ctx, django_context) + if field_lookups is None: + return AnyType(TypeOfAny.from_error) + if len(field_lookups) == 0: if flat: primary_key_field = django_context.get_primary_key_field(model_cls) - _, field_get_type = get_lookup_field_get_type(ctx, django_context, model_cls, + _, column_type = get_lookup_field_get_type(ctx, django_context, model_cls, primary_key_field.attname, 'values_list') - return field_get_type + return column_type elif named: column_types = OrderedDict() for field in django_context.get_model_fields(model_cls): - field_get_type = django_context.fields_context.get_field_get_type(ctx.api, field, 'values_list') - column_types[field.attname] = field_get_type + column_type = django_context.fields_context.get_field_get_type(ctx.api, field, 'values_list') + column_types[field.attname] = column_type return helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types) else: # flat=False, named=False, all fields field_lookups = [] - for field in model_cls._meta.get_fields(): + for field in django_context.get_model_fields(model_cls): field_lookups.append(field.attname) if len(field_lookups) > 1 and flat: @@ -89,8 +92,9 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values_list') if result is None: return AnyType(TypeOfAny.from_error) - field_name, field_get_type = result - column_types[field_name] = field_get_type + + column_name, column_type = result + column_types[column_name] = column_type if flat: assert len(column_types) == 1 @@ -133,6 +137,17 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) +def resolve_field_lookups(lookup_exprs: Sequence[Expression], ctx: Union[FunctionContext, MethodContext], + django_context: DjangoContext) -> Optional[List[str]]: + field_lookups = [] + for field_lookup_expr in lookup_exprs: + field_lookup = helpers.resolve_string_attribute_value(field_lookup_expr, ctx, django_context) + if field_lookup is None: + return None + field_lookups.append(field_lookup) + return field_lookups + + def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType: assert isinstance(ctx.type, Instance) assert isinstance(ctx.type.args[0], Instance) @@ -142,25 +157,22 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan if model_cls is None: return ctx.default_return_type - field_lookups = [expr.value for expr in ctx.args[0]] + field_lookups = resolve_field_lookups(ctx.args[0], ctx, django_context) + if field_lookups is None: + return AnyType(TypeOfAny.from_error) + if len(field_lookups) == 0: - for field in model_cls._meta.get_fields(): + for field in django_context.get_model_fields(model_cls): field_lookups.append(field.attname) column_types = OrderedDict() for field_lookup in field_lookups: - try: - lookup_field = django_context.lookups_context.resolve_lookup(model_cls, field_lookup) - except FieldError as exc: - ctx.api.fail(exc.args[0], ctx.context) + result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values') + if result is None: return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)]) - field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, 'values') - field_name = lookup_field.attname - if isinstance(lookup_field, ForeignKey) and field_lookup == lookup_field.name: - field_name = lookup_field.name - - column_types[field_name] = field_get_type + column_name, column_type = result + column_types[column_name] = column_type row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys())) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) diff --git a/test-data/typecheck/fields/test_related.yml b/test-data/typecheck/fields/test_related.yml index 4dcdfcc..da574ae 100644 --- a/test-data/typecheck/fields/test_related.yml +++ b/test-data/typecheck/fields/test_related.yml @@ -412,3 +412,25 @@ related_name='+') publisher2 = models.ForeignKey(to=Publisher, on_delete=models.CASCADE, related_name='books2') + +- case: to_parameter_could_be_resolved_if_passed_from_settings + main: | + from myapp.models import Book + book = Book() + reveal_type(book.publisher) # N: Revealed type is 'myapp.models.Publisher*' + installed_apps: + - myapp + additional_settings: + - BOOK_RELATED_MODEL='myapp.Publisher' + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.conf import settings + from django.db import models + + class Publisher(models.Model): + pass + class Book(models.Model): + publisher = models.ForeignKey(to=settings.BOOK_RELATED_MODEL, on_delete=models.CASCADE, + related_name='books') diff --git a/test-data/typecheck/managers/querysets/test_values.yml b/test-data/typecheck/managers/querysets/test_values.yml index 2fa7f5d..1632abc 100644 --- a/test-data/typecheck/managers/querysets/test_values.yml +++ b/test-data/typecheck/managers/querysets/test_values.yml @@ -5,6 +5,10 @@ reveal_type(values) # N: Revealed type is 'TypedDict({'num_posts': builtins.int, 'text': builtins.str})' reveal_type(values["num_posts"]) # N: Revealed type is 'builtins.int' reveal_type(values["text"]) # N: Revealed type is 'builtins.str' + + values_pk = Blog.objects.values('pk').get() + reveal_type(values_pk) # N: Revealed type is 'TypedDict({'pk': builtins.int})' + reveal_type(values_pk["pk"]) # N: Revealed type is 'builtins.int' installed_apps: - myapp files: @@ -31,6 +35,8 @@ - path: myapp/models.py content: | from django.db import models + class Publisher(models.Model): + pass class Blog(models.Model): num_posts = models.IntegerField() text = models.CharField(max_length=100) @@ -60,3 +66,26 @@ pass class Blog(models.Model): publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) + +- case: values_with_related_model_fields + main: | + from myapp.models import Entry + values = Entry.objects.values('blog__num_articles', 'blog__publisher__name').get() + reveal_type(values) # N: Revealed type is 'TypedDict({'blog__num_articles': builtins.int, 'blog__publisher__name': builtins.str})' + + pk_values = Entry.objects.values('blog__pk', 'blog__publisher__pk').get() + reveal_type(pk_values) # N: Revealed type is 'TypedDict({'blog__pk': builtins.int, 'blog__publisher__pk': builtins.int})' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class Publisher(models.Model): + name = models.CharField(max_length=100) + class Blog(models.Model): + num_articles = models.IntegerField() + publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) + class Entry(models.Model): + blog = models.ForeignKey(Blog, on_delete=models.CASCADE) \ No newline at end of file diff --git a/test-data/typecheck/managers/querysets/test_values_list.yml b/test-data/typecheck/managers/querysets/test_values_list.yml index 5e31e1e..1196547 100644 --- a/test-data/typecheck/managers/querysets/test_values_list.yml +++ b/test-data/typecheck/managers/querysets/test_values_list.yml @@ -1,4 +1,4 @@ -- case: values_list_simple_field +- case: values_list_simple_field_returns_queryset_of_tuples main: | from myapp.models import MyUser reveal_type(MyUser.objects.values_list('name').get()) # N: Revealed type is 'Tuple[builtins.str]' @@ -11,6 +11,10 @@ # no fields specified return all fields all_values_tuple = MyUser.objects.values_list().get() reveal_type(all_values_tuple) # N: Revealed type is 'Tuple[builtins.int, builtins.str, builtins.int]' + + # pk as field + pk_values = MyUser.objects.values_list('pk').get() + reveal_type(pk_values) # N: # N: Revealed type is 'Tuple[builtins.int]' installed_apps: - myapp files: @@ -84,6 +88,11 @@ reveal_type(all_values_named_tuple.name) # N: Revealed type is 'builtins.str' reveal_type(all_values_named_tuple.age) # N: Revealed type is 'builtins.int' reveal_type(all_values_named_tuple.is_admin) # N: Revealed type is 'builtins.bool' + + # pk as field + pk_values = MyUser.objects.values_list('pk', named=True).get() + reveal_type(pk_values) # N: Revealed type is 'Tuple[builtins.int, fallback=main.Row2]' + reveal_type(pk_values.pk) # N: # N: Revealed type is 'builtins.int' installed_apps: - myapp files: @@ -139,4 +148,29 @@ class Publisher(models.Model): pass class Blog(models.Model): - publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) \ No newline at end of file + publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) + +- case: named_true_with_related_model_fields + main: | + from myapp.models import Entry + values = Entry.objects.values_list('blog__num_articles', 'blog__publisher__name', named=True).get() + reveal_type(values.blog__num_articles) # N: Revealed type is 'builtins.int' + reveal_type(values.blog__publisher__name) # N: Revealed type is 'builtins.str' + + pk_values = Entry.objects.values_list('blog__pk', 'blog__publisher__pk', named=True).get() + reveal_type(pk_values.blog__pk) # N: Revealed type is 'builtins.int' + reveal_type(pk_values.blog__publisher__pk) # N: Revealed type is 'builtins.int' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class Publisher(models.Model): + name = models.CharField(max_length=100) + class Blog(models.Model): + num_articles = models.IntegerField() + publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) + class Entry(models.Model): + blog = models.ForeignKey(Blog, on_delete=models.CASCADE)