From 409c01eb2489b1f2dda646786584a7637ee2bead Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Thu, 25 Jul 2019 19:22:59 +0300 Subject: [PATCH] allow to specify QuerySet with one parameter --- mypy_django_plugin/django/context.py | 14 +++++------ mypy_django_plugin/main.py | 13 +++++++--- mypy_django_plugin/transformers/querysets.py | 24 ++++++++++++++++--- mypy_django_plugin/transformers/settings.py | 7 +----- .../managers/querysets/test_basic_methods.yml | 9 ++++++- 5 files changed, 46 insertions(+), 21 deletions(-) diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index da2618a..4b7500f 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -1,27 +1,25 @@ import os from collections import defaultdict from contextlib import contextmanager -from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Type +from typing import Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type -from mypy.nodes import TypeInfo - -from django.contrib.postgres.fields import ArrayField from django.core.exceptions import FieldError from django.db.models.base import Model -from django.db.models.fields import AutoField, CharField, Field from django.db.models.fields.related import ForeignKey, RelatedField from django.db.models.fields.reverse_related import ForeignObjectRel from django.db.models.sql.query import Query from django.utils.functional import cached_property from mypy.checker import TypeChecker -from mypy.types import Instance, AnyType, TypeOfAny -from mypy.types import Type as MypyType +from mypy.nodes import TypeInfo +from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny +from django.contrib.postgres.fields import ArrayField +from django.db.models.fields import AutoField, CharField, Field from mypy_django_plugin.lib import helpers if TYPE_CHECKING: from django.apps.registry import Apps # noqa: F401 - from django.conf import LazySettings + from django.conf import LazySettings # noqa: F401 @contextmanager diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 2c76bf1..4f2fd49 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -6,9 +6,7 @@ from django.db.models.fields.related import RelatedField from mypy.errors import Errors from mypy.nodes import MypyFile, TypeInfo from mypy.options import Options -from mypy.plugin import ( - AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin, -) +from mypy.plugin import AnalyzeTypeContext, AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin from mypy.types import Type as MypyType from mypy_django_plugin.django.context import DjangoContext @@ -233,6 +231,15 @@ class NewSemanalDjangoPlugin(Plugin): return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context) return None + def get_type_analyze_hook(self, fullname: str + ) -> Optional[Callable[[AnalyzeTypeContext], MypyType]]: + info = self._get_typeinfo_or_none(fullname) + if (info + and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) + and not info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME)): + return partial(querysets.set_first_generic_param_as_default_for_second, fullname=fullname) + return None + def plugin(version): return NewSemanalDjangoPlugin diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index 5b6d1b9..b1bdce3 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -1,10 +1,11 @@ from collections import OrderedDict -from typing import List, Optional, Sequence, Type, Union +from typing import List, Optional, Sequence, Type, Union, cast from django.core.exceptions import FieldError from django.db.models.base import Model -from mypy.nodes import Expression, NameExpr -from mypy.plugin import FunctionContext, MethodContext +from mypy.newsemanal.typeanal import TypeAnalyser +from mypy.nodes import Expression, NameExpr, TypeInfo +from mypy.plugin import FunctionContext, MethodContext, AnalyzeTypeContext from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny from mypy_django_plugin.django.context import DjangoContext @@ -169,3 +170,20 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys())) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) + + +def set_first_generic_param_as_default_for_second(ctx: AnalyzeTypeContext, fullname: str) -> MypyType: + type_analyser_api = cast(TypeAnalyser, ctx.api) + + info = helpers.lookup_fully_qualified_typeinfo(type_analyser_api.api, fullname) # type: ignore + assert isinstance(info, TypeInfo) + + if not ctx.type.args: + return Instance(info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)]) + + args = ctx.type.args + if len(args) == 1: + args = [args[0], args[0]] + + analyzed_args = [type_analyser_api.analyze_type(arg) for arg in args] + return Instance(info, analyzed_args) diff --git a/mypy_django_plugin/transformers/settings.py b/mypy_django_plugin/transformers/settings.py index f74f213..e8ec089 100644 --- a/mypy_django_plugin/transformers/settings.py +++ b/mypy_django_plugin/transformers/settings.py @@ -1,11 +1,6 @@ -from typing import cast - -from mypy.checker import TypeChecker from mypy.nodes import MemberExpr, TypeInfo from mypy.plugin import AttributeContext, FunctionContext -from mypy.types import Instance -from mypy.types import Type as MypyType -from mypy.types import TypeType +from mypy.types import Instance, Type as MypyType, TypeType from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.lib import helpers diff --git a/test-data/typecheck/managers/querysets/test_basic_methods.yml b/test-data/typecheck/managers/querysets/test_basic_methods.yml index a6741f4..64b4b49 100644 --- a/test-data/typecheck/managers/querysets/test_basic_methods.yml +++ b/test-data/typecheck/managers/querysets/test_basic_methods.yml @@ -28,4 +28,11 @@ from django.db import models class Blog(models.Model): - created_at = models.DateTimeField() \ No newline at end of file + created_at = models.DateTimeField() + +- case: queryset_could_be_specified_with_one_type + main: | + from typing import Optional + from django.db import models + queryset: models.QuerySet[models.Model] = models.QuerySet() + reveal_type(queryset) # N: Revealed type is 'django.db.models.query.QuerySet[django.db.models.base.Model, django.db.models.base.Model]'