mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-07 20:54:29 +08:00
Merge pull request #119 from mkurnikov/subclass-queryset-proper-typing
Allow to subclass queryset without loss of typing
This commit is contained in:
@@ -5,7 +5,7 @@ from mypy.types import Type as MypyType
|
|||||||
from mypy.types import TypeOfAny
|
from mypy.types import TypeOfAny
|
||||||
|
|
||||||
from mypy_django_plugin.django.context import DjangoContext
|
from mypy_django_plugin.django.context import DjangoContext
|
||||||
from mypy_django_plugin.lib import fullnames, helpers
|
from mypy_django_plugin.lib import helpers
|
||||||
|
|
||||||
|
|
||||||
def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType:
|
def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType:
|
||||||
@@ -20,21 +20,25 @@ def return_proper_field_type_from_get_field(ctx: MethodContext, django_context:
|
|||||||
# Options instance
|
# Options instance
|
||||||
assert isinstance(ctx.type, Instance)
|
assert isinstance(ctx.type, Instance)
|
||||||
|
|
||||||
|
# bail if list of generic params is empty
|
||||||
|
if len(ctx.type.args) == 0:
|
||||||
|
return ctx.default_return_type
|
||||||
|
|
||||||
model_type = ctx.type.args[0]
|
model_type = ctx.type.args[0]
|
||||||
if not isinstance(model_type, Instance):
|
if not isinstance(model_type, Instance):
|
||||||
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
return ctx.default_return_type
|
||||||
|
|
||||||
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
|
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
|
||||||
if model_cls is None:
|
if model_cls is None:
|
||||||
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
return ctx.default_return_type
|
||||||
|
|
||||||
field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name')
|
field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name')
|
||||||
if field_name_expr is None:
|
if field_name_expr is None:
|
||||||
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
return ctx.default_return_type
|
||||||
|
|
||||||
field_name = helpers.resolve_string_attribute_value(field_name_expr, ctx, django_context)
|
field_name = helpers.resolve_string_attribute_value(field_name_expr, ctx, django_context)
|
||||||
if field_name is None:
|
if field_name is None:
|
||||||
return _get_field_instance(ctx, fullnames.FIELD_FULLNAME)
|
return ctx.default_return_type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
field = model_cls._meta.get_field(field_name)
|
field = model_cls._meta.get_field(field_name)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import List, Optional, Sequence, Type, Union, cast
|
|||||||
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db.models.base import Model
|
from django.db.models.base import Model
|
||||||
|
from django.db.models.fields.related import RelatedField
|
||||||
from mypy.newsemanal.typeanal import TypeAnalyser
|
from mypy.newsemanal.typeanal import TypeAnalyser
|
||||||
from mypy.nodes import Expression, NameExpr, TypeInfo
|
from mypy.nodes import Expression, NameExpr, TypeInfo
|
||||||
from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext
|
from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext
|
||||||
@@ -10,11 +11,19 @@ from mypy.types import AnyType, Instance
|
|||||||
from mypy.types import Type as MypyType
|
from mypy.types import Type as MypyType
|
||||||
from mypy.types import TypeOfAny
|
from mypy.types import TypeOfAny
|
||||||
|
|
||||||
from django.db.models.fields.related import RelatedField
|
|
||||||
from mypy_django_plugin.django.context import DjangoContext
|
from mypy_django_plugin.django.context import DjangoContext
|
||||||
from mypy_django_plugin.lib import fullnames, helpers
|
from mypy_django_plugin.lib import fullnames, helpers
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]:
|
||||||
|
for base_type in [queryset_type, *queryset_type.type.bases]:
|
||||||
|
if (len(base_type.args)
|
||||||
|
and isinstance(base_type.args[0], Instance)
|
||||||
|
and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME)):
|
||||||
|
return base_type.args[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
|
def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
|
||||||
default_return_type = ctx.default_return_type
|
default_return_type = ctx.default_return_type
|
||||||
assert isinstance(default_return_type, Instance)
|
assert isinstance(default_return_type, Instance)
|
||||||
@@ -98,11 +107,10 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
|
|||||||
assert isinstance(ctx.type, Instance)
|
assert isinstance(ctx.type, Instance)
|
||||||
assert isinstance(ctx.default_return_type, Instance)
|
assert isinstance(ctx.default_return_type, Instance)
|
||||||
|
|
||||||
# bail if queryset of Any or other non-instances
|
model_type = _extract_model_type_from_queryset(ctx.type)
|
||||||
if not isinstance(ctx.type.args[0], Instance):
|
if model_type is None:
|
||||||
return AnyType(TypeOfAny.from_omitted_generics)
|
return AnyType(TypeOfAny.from_omitted_generics)
|
||||||
|
|
||||||
model_type = ctx.type.args[0]
|
|
||||||
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
|
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
|
||||||
if model_cls is None:
|
if model_cls is None:
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
@@ -148,11 +156,10 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
|
|||||||
assert isinstance(ctx.type, Instance)
|
assert isinstance(ctx.type, Instance)
|
||||||
assert isinstance(ctx.default_return_type, Instance)
|
assert isinstance(ctx.default_return_type, Instance)
|
||||||
|
|
||||||
# if queryset of non-instance type
|
model_type = _extract_model_type_from_queryset(ctx.type)
|
||||||
if not isinstance(ctx.type.args[0], Instance):
|
if model_type is None:
|
||||||
return AnyType(TypeOfAny.from_omitted_generics)
|
return AnyType(TypeOfAny.from_omitted_generics)
|
||||||
|
|
||||||
model_type = ctx.type.args[0]
|
|
||||||
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
|
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
|
||||||
if model_cls is None:
|
if model_cls is None:
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -28,7 +28,7 @@ dependencies = [
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="django-stubs",
|
name="django-stubs",
|
||||||
version="1.0.1",
|
version="1.0.2",
|
||||||
description='Mypy stubs for Django',
|
description='Mypy stubs for Django',
|
||||||
long_description=readme,
|
long_description=readme,
|
||||||
long_description_content_type='text/markdown',
|
long_description_content_type='text/markdown',
|
||||||
|
|||||||
@@ -205,3 +205,21 @@
|
|||||||
pass
|
pass
|
||||||
class Blog(models.Model):
|
class Blog(models.Model):
|
||||||
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
|
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
- case: subclass_of_queryset_has_proper_typings_on_methods
|
||||||
|
main: |
|
||||||
|
from myapp.models import TransactionQuerySet
|
||||||
|
reveal_type(TransactionQuerySet()) # N: Revealed type is 'myapp.models.TransactionQuerySet'
|
||||||
|
reveal_type(TransactionQuerySet().values()) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.Transaction, TypedDict({'id': builtins.int, 'total': builtins.int})]'
|
||||||
|
reveal_type(TransactionQuerySet().values_list()) # N: Revealed type is 'django.db.models.query.QuerySet[myapp.models.Transaction, Tuple[builtins.int, builtins.int]]'
|
||||||
|
installed_apps:
|
||||||
|
- myapp
|
||||||
|
files:
|
||||||
|
- path: myapp/__init__.py
|
||||||
|
- path: myapp/models.py
|
||||||
|
content: |
|
||||||
|
from django.db import models
|
||||||
|
class TransactionQuerySet(models.QuerySet['Transaction']):
|
||||||
|
pass
|
||||||
|
class Transaction(models.Model):
|
||||||
|
total = models.IntegerField()
|
||||||
|
|||||||
Reference in New Issue
Block a user