mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-07 04:34:29 +08:00
allow to subclass queryset without loss of typing
This commit is contained in:
@@ -5,7 +5,7 @@ from mypy.types import Type as MypyType
|
||||
from mypy.types import TypeOfAny
|
||||
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
from mypy_django_plugin.lib import helpers
|
||||
|
||||
|
||||
def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType:
|
||||
@@ -20,21 +20,25 @@ def return_proper_field_type_from_get_field(ctx: MethodContext, django_context:
|
||||
# Options instance
|
||||
assert isinstance(ctx.type, Instance)
|
||||
|
||||
# bail if list of generic params is empty
|
||||
if len(ctx.type.args) == 0:
|
||||
return ctx.default_return_type
|
||||
|
||||
model_type = ctx.type.args[0]
|
||||
if not isinstance(model_type, Instance):
|
||||
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
||||
return ctx.default_return_type
|
||||
|
||||
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
|
||||
if model_cls is None:
|
||||
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
||||
return ctx.default_return_type
|
||||
|
||||
field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name')
|
||||
if field_name_expr is None:
|
||||
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
||||
return ctx.default_return_type
|
||||
|
||||
field_name = helpers.resolve_string_attribute_value(field_name_expr, ctx, django_context)
|
||||
if field_name is None:
|
||||
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
||||
return ctx.default_return_type
|
||||
|
||||
try:
|
||||
field = model_cls._meta.get_field(field_name)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import List, Optional, Sequence, Type, Union, cast
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.fields.related import RelatedField
|
||||
from mypy.newsemanal.typeanal import TypeAnalyser
|
||||
from mypy.nodes import Expression, NameExpr, TypeInfo
|
||||
from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext
|
||||
@@ -10,11 +11,19 @@ from mypy.types import AnyType, Instance
|
||||
from mypy.types import Type as MypyType
|
||||
from mypy.types import TypeOfAny
|
||||
|
||||
from django.db.models.fields.related import RelatedField
|
||||
from mypy_django_plugin.django.context import DjangoContext
|
||||
from mypy_django_plugin.lib import fullnames, helpers
|
||||
|
||||
|
||||
def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]:
|
||||
for base_type in [queryset_type, *queryset_type.type.bases]:
|
||||
if (len(base_type.args)
|
||||
and isinstance(base_type.args[0], Instance)
|
||||
and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME)):
|
||||
return base_type.args[0]
|
||||
return None
|
||||
|
||||
|
||||
def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
|
||||
default_return_type = ctx.default_return_type
|
||||
assert isinstance(default_return_type, Instance)
|
||||
@@ -98,11 +107,10 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
|
||||
assert isinstance(ctx.type, Instance)
|
||||
assert isinstance(ctx.default_return_type, Instance)
|
||||
|
||||
# bail if queryset of Any or other non-instances
|
||||
if not isinstance(ctx.type.args[0], Instance):
|
||||
model_type = _extract_model_type_from_queryset(ctx.type)
|
||||
if model_type is None:
|
||||
return AnyType(TypeOfAny.from_omitted_generics)
|
||||
|
||||
model_type = ctx.type.args[0]
|
||||
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
@@ -148,11 +156,10 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
|
||||
assert isinstance(ctx.type, Instance)
|
||||
assert isinstance(ctx.default_return_type, Instance)
|
||||
|
||||
# if queryset of non-instance type
|
||||
if not isinstance(ctx.type.args[0], Instance):
|
||||
model_type = _extract_model_type_from_queryset(ctx.type)
|
||||
if model_type is None:
|
||||
return AnyType(TypeOfAny.from_omitted_generics)
|
||||
|
||||
model_type = ctx.type.args[0]
|
||||
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
@@ -205,3 +205,21 @@
|
||||
pass
|
||||
class Blog(models.Model):
|
||||
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
|
||||
|
||||
- case: subclass_of_queryset_has_proper_typings_on_methods
|
||||
main: |
|
||||
from myapp.models import TransactionQuerySet
|
||||
reveal_type(TransactionQuerySet()) # N: Revealed type is 'myapp.models.TransactionQuerySet'
|
||||
reveal_type(TransactionQuerySet().values()) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.Transaction, TypedDict({'id': builtins.int, 'total': builtins.int})]'
|
||||
reveal_type(TransactionQuerySet().values_list()) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.Transaction, Tuple[builtins.int, builtins.int]]'
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
class TransactionQuerySet(models.QuerySet['Transaction']):
|
||||
pass
|
||||
class Transaction(models.Model):
|
||||
total = models.IntegerField()
|
||||
|
||||
Reference in New Issue
Block a user