diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index 1e68763..36ac0f8 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -124,7 +124,7 @@ class DjangoContext: if isinstance(field, ForeignObjectRel): yield field - def get_field_lookup_exact_type(self, api: TypeChecker, field: Field) -> MypyType: + def get_field_lookup_exact_type(self, api: TypeChecker, field: Union[Field, ForeignObjectRel]) -> MypyType: if isinstance(field, (RelatedField, ForeignObjectRel)): related_model_cls = field.related_model primary_key_field = self.get_primary_key_field(related_model_cls) @@ -134,10 +134,8 @@ class DjangoContext: if rel_model_info is None: return AnyType(TypeOfAny.explicit) - model_and_primary_key_type = UnionType.make_union([Instance(rel_model_info, []), - primary_key_type]) + model_and_primary_key_type = UnionType.make_union([Instance(rel_model_info, []), primary_key_type]) return helpers.make_optional(model_and_primary_key_type) - # return helpers.make_optional(Instance(rel_model_info, [])) field_info = helpers.lookup_class_typeinfo(api, field.__class__) if field_info is None: @@ -228,21 +226,22 @@ class DjangoContext: attname = field.attname return attname - def get_field_nullability(self, field: Field, method: Optional[str]) -> bool: + def get_field_nullability(self, field: Union[Field, ForeignObjectRel], method: Optional[str]) -> bool: nullable = field.null if not nullable and isinstance(field, CharField) and field.blank: return True if method == '__init__': - if field.primary_key or isinstance(field, ForeignKey): + if ((isinstance(field, Field) and field.primary_key) + or isinstance(field, ForeignKey)): return True if method == 'create': if isinstance(field, AutoField): return True - if field.has_default(): + if isinstance(field, Field) and field.has_default(): return True return nullable - def get_field_set_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType: + def get_field_set_type(self, api: TypeChecker, field: Union[Field, ForeignObjectRel], *, method: str) -> MypyType: """ Get a type of __set__ for this specific Django field. """ target_field = field if isinstance(field, ForeignKey): @@ -259,7 +258,7 @@ class DjangoContext: field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type) return field_set_type - def get_field_get_type(self, api: TypeChecker, field: Field, *, method: str) -> MypyType: + def get_field_get_type(self, api: TypeChecker, field: Union[Field, ForeignObjectRel], *, method: str) -> MypyType: """ Get a type of __get__ for this specific Django field. """ field_info = helpers.lookup_class_typeinfo(api, field.__class__) if field_info is None: @@ -303,7 +302,10 @@ class DjangoContext: return related_model_cls - def _resolve_field_from_parts(self, field_parts: Iterable[str], model_cls: Type[Model]) -> Field: + def _resolve_field_from_parts(self, + field_parts: Iterable[str], + model_cls: Type[Model] + ) -> Union[Field, ForeignObjectRel]: currently_observed_model = model_cls field = None for field_part in field_parts: @@ -325,7 +327,7 @@ class DjangoContext: assert field is not None return field - def resolve_lookup_into_field(self, model_cls: Type[Model], lookup: str) -> Field: + def resolve_lookup_into_field(self, model_cls: Type[Model], lookup: str) -> Union[Field, ForeignObjectRel]: query = Query(model_cls) lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup) if lookup_parts: diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index 5c086ef..bffe38b 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -4,6 +4,7 @@ from typing import List, Optional, Sequence, Type, Union from django.core.exceptions import FieldError from django.db.models.base import Model from django.db.models.fields.related import RelatedField +from django.db.models.fields.reverse_related import ForeignObjectRel from mypy.nodes import Expression, NameExpr from mypy.plugin import FunctionContext, MethodContext from mypy.types import AnyType, Instance @@ -47,7 +48,8 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext except LookupsAreUnsupported: return AnyType(TypeOfAny.explicit) - if isinstance(lookup_field, RelatedField) and lookup_field.column == lookup: + if ((isinstance(lookup_field, RelatedField) and lookup_field.column == lookup) + or isinstance(lookup_field, ForeignObjectRel)): related_model_cls = django_context.get_field_related_model_cls(lookup_field) if related_model_cls is None: return AnyType(TypeOfAny.from_error) diff --git a/test-data/typecheck/managers/querysets/test_values.yml b/test-data/typecheck/managers/querysets/test_values.yml index 420fcd0..36461c1 100644 --- a/test-data/typecheck/managers/querysets/test_values.yml +++ b/test-data/typecheck/managers/querysets/test_values.yml @@ -107,3 +107,20 @@ class Blog(models.Model): name = models.CharField(max_length=100) publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE) + +- case: values_of_many_to_many_field + main: | + from myapp.models import Author, Book + reveal_type(Book.objects.values('authors')) # N: Revealed type is 'django.db.models.query.ValuesQuerySet[myapp.models.Book, TypedDict({'authors': builtins.int})]' + reveal_type(Author.objects.values('books')) # N: Revealed type is 'django.db.models.query.ValuesQuerySet[myapp.models.Author, TypedDict({'books': builtins.int})]' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class Author(models.Model): + pass + class Book(models.Model): + authors = models.ManyToManyField(Author, related_name='books') diff --git a/test-data/typecheck/managers/querysets/test_values_list.yml b/test-data/typecheck/managers/querysets/test_values_list.yml index b7d5db4..c67f265 100644 --- a/test-data/typecheck/managers/querysets/test_values_list.yml +++ b/test-data/typecheck/managers/querysets/test_values_list.yml @@ -224,3 +224,20 @@ pass class Transaction(models.Model): total = models.IntegerField() + +- case: values_list_of_many_to_many_field + main: | + from myapp.models import Author, Book + reveal_type(Book.objects.values_list('authors')) # N: Revealed type is 'django.db.models.query.ValuesQuerySet[myapp.models.Book, Tuple[builtins.int]]' + reveal_type(Author.objects.values_list('books')) # N: Revealed type is 'django.db.models.query.ValuesQuerySet[myapp.models.Author, Tuple[builtins.int]]' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class Author(models.Model): + pass + class Book(models.Model): + authors = models.ManyToManyField(Author, related_name='books') \ No newline at end of file