mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-06 20:24:31 +08:00
add GenericForeignKey support, remove some false-positives
This commit is contained in:
@@ -8,8 +8,8 @@ from django.db.models.base import Model
|
||||
from .config import AppConfig
|
||||
|
||||
class Apps:
|
||||
all_models: 'Dict[str, OrderedDict[str, Type[Model]]]' = ...
|
||||
app_configs: 'OrderedDict[str, AppConfig]' = ...
|
||||
all_models: "Dict[str, OrderedDict[str, Type[Model]]]" = ...
|
||||
app_configs: "OrderedDict[str, AppConfig]" = ...
|
||||
stored_app_configs: List[Any] = ...
|
||||
apps_ready: bool = ...
|
||||
ready_event: threading.Event = ...
|
||||
|
||||
@@ -7,6 +7,7 @@ from django.db.models.fields.related import ForeignObject
|
||||
from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor
|
||||
from django.db.models.fields.reverse_related import ForeignObjectRel
|
||||
|
||||
from django.db.models.expressions import Combinable
|
||||
from django.db.models.fields import Field, PositiveIntegerField
|
||||
from django.db.models.fields.mixins import FieldCacheMixin
|
||||
from django.db.models.query import QuerySet
|
||||
@@ -14,6 +15,10 @@ from django.db.models.query_utils import FilteredRelation, PathInfo
|
||||
from django.db.models.sql.where import WhereNode
|
||||
|
||||
class GenericForeignKey(FieldCacheMixin):
|
||||
# django-stubs implementation only fields
|
||||
_pyi_private_set_type: Union[Any, Combinable]
|
||||
_pyi_private_get_type: Any
|
||||
# attributes
|
||||
auto_created: bool = ...
|
||||
concrete: bool = ...
|
||||
editable: bool = ...
|
||||
@@ -44,10 +49,8 @@ class GenericForeignKey(FieldCacheMixin):
|
||||
def get_prefetch_queryset(
|
||||
self, instances: Union[List[Model], QuerySet], queryset: Optional[QuerySet] = ...
|
||||
) -> Tuple[List[Model], Callable, Callable, bool, str, bool]: ...
|
||||
def __get__(
|
||||
self, instance: Optional[Model], cls: Type[Model] = ...
|
||||
) -> Optional[Union[GenericForeignKey, Model]]: ...
|
||||
def __set__(self, instance: Model, value: Optional[Model]) -> None: ...
|
||||
def __get__(self, instance: Optional[Model], cls: Type[Model] = ...) -> Optional[Any]: ...
|
||||
def __set__(self, instance: Model, value: Optional[Any]) -> None: ...
|
||||
|
||||
class GenericRel(ForeignObjectRel):
|
||||
field: GenericRelation
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from django.core.management.base import BaseCommand as BaseCommand, CommandError as CommandError
|
||||
from .base import BaseCommand as BaseCommand, CommandError as CommandError
|
||||
|
||||
def find_commands(management_dir: str) -> List[str]: ...
|
||||
def load_command_class(app_name: str, name: str) -> BaseCommand: ...
|
||||
def get_commands() -> Dict[str, str]: ...
|
||||
def call_command(command_name: Union[Tuple[str], BaseCommand, str], *args: Any, **options: Any) -> Optional[str]: ...
|
||||
def call_command(command_name: Union[Tuple[str], BaseCommand, str], *args: Any, **options: Any) -> str: ...
|
||||
|
||||
class ManagementUtility:
|
||||
argv: List[str] = ...
|
||||
|
||||
@@ -11,6 +11,7 @@ from .base import (
|
||||
SerializationError as SerializationError,
|
||||
DeserializationError as DeserializationError,
|
||||
M2MDeserializationError as M2MDeserializationError,
|
||||
DeserializedObject,
|
||||
)
|
||||
|
||||
BUILTIN_SERIALIZERS: Any
|
||||
@@ -27,10 +28,8 @@ def get_serializer(format: str) -> Union[Type[Serializer], BadSerializer]: ...
|
||||
def get_serializer_formats() -> List[str]: ...
|
||||
def get_public_serializer_formats() -> List[str]: ...
|
||||
def get_deserializer(format: str) -> Union[Callable, Type[Deserializer]]: ...
|
||||
def serialize(
|
||||
format: str, queryset: Union[Iterator[Any], List[Model], QuerySet], **options: Any
|
||||
) -> Optional[Union[bytes, str]]: ...
|
||||
def deserialize(format: str, stream_or_string: Any, **options: Any) -> Union[Iterator[Any], Deserializer]: ...
|
||||
def serialize(format: str, queryset: Iterable[Model], **options: Any) -> Optional[Union[bytes, str]]: ...
|
||||
def deserialize(format: str, stream_or_string: Any, **options: Any) -> Iterator[DeserializedObject]: ...
|
||||
def sort_dependencies(
|
||||
app_list: Union[Iterable[Tuple[AppConfig, None]], Iterable[Tuple[str, Iterable[Type[Model]]]]]
|
||||
) -> List[Type[Model]]: ...
|
||||
|
||||
@@ -70,8 +70,8 @@ class Deserializer:
|
||||
def __next__(self) -> None: ...
|
||||
|
||||
class DeserializedObject:
|
||||
object: Model = ...
|
||||
m2m_data: Dict[Any, Any] = ...
|
||||
object: Any = ...
|
||||
m2m_data: Dict[str, List[int]] = ...
|
||||
def __init__(self, obj: Model, m2m_data: Optional[Dict[str, List[int]]] = ...) -> None: ...
|
||||
def save(self, save_m2m: bool = ..., using: Optional[str] = ..., **kwargs: Any) -> None: ...
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from django.core.checks.messages import CheckMessage
|
||||
|
||||
from django.db.models.options import Options
|
||||
|
||||
|
||||
class ModelBase(type): ...
|
||||
|
||||
_Self = TypeVar("_Self", bound="Model")
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import decimal
|
||||
import uuid
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from typing import Any, Callable, Dict, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union, Sequence
|
||||
from typing import Any, Callable, Dict, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union, Sequence, List
|
||||
|
||||
from django.core import checks
|
||||
|
||||
from django.db.models import Model
|
||||
from django.core.exceptions import FieldDoesNotExist as FieldDoesNotExist
|
||||
@@ -41,6 +43,8 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
|
||||
null: bool = ...
|
||||
editable: bool = ...
|
||||
choices: Optional[_FieldChoices] = ...
|
||||
db_column: Optional[str]
|
||||
column: str
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name: Optional[Union[str, bytes]] = ...,
|
||||
@@ -86,6 +90,7 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
|
||||
) -> Sequence[Union[_Choice, _ChoiceNamedGroup]]: ...
|
||||
def has_default(self) -> bool: ...
|
||||
def get_default(self) -> Any: ...
|
||||
def check(self, **kwargs: Any) -> List[checks.Error]: ...
|
||||
|
||||
class IntegerField(Field[_ST, _GT]):
|
||||
_pyi_private_set_type: Union[float, int, str, Combinable]
|
||||
|
||||
@@ -108,4 +108,6 @@ class Options:
|
||||
def get_ancestor_link(self, ancestor: Type[Model]) -> Optional[OneToOneField]: ...
|
||||
def get_path_to_parent(self, parent: Type[Model]) -> List[PathInfo]: ...
|
||||
def get_path_from_parent(self, parent: Type[Model]) -> List[PathInfo]: ...
|
||||
def get_fields(self, include_parents: bool = ..., include_hidden: bool = ...) -> List[Union[Field, ForeignObjectRel]]: ...
|
||||
def get_fields(
|
||||
self, include_parents: bool = ..., include_hidden: bool = ...
|
||||
) -> List[Union[Field, ForeignObjectRel]]: ...
|
||||
|
||||
@@ -121,9 +121,11 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
|
||||
def select_for_update(self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...) -> QuerySet[_T, _Row]: ...
|
||||
def select_related(self, *fields: Any) -> QuerySet[_T, _Row]: ...
|
||||
def prefetch_related(self, *lookups: Any) -> QuerySet[_T, _Row]: ...
|
||||
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T, _Row]: ...
|
||||
# TODO: return type
|
||||
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[Any, Any]: ...
|
||||
def order_by(self, *field_names: Any) -> QuerySet[_T, _Row]: ...
|
||||
def distinct(self, *field_names: Any) -> QuerySet[_T, _Row]: ...
|
||||
# extra() return type won't be supported any time soon
|
||||
def extra(
|
||||
self,
|
||||
select: Optional[Dict[str, Any]] = ...,
|
||||
@@ -132,7 +134,7 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized):
|
||||
tables: Optional[List[str]] = ...,
|
||||
order_by: Optional[Sequence[str]] = ...,
|
||||
select_params: Optional[Sequence[Any]] = ...,
|
||||
) -> QuerySet[_T, _Row]: ...
|
||||
) -> QuerySet[Any, Any]: ...
|
||||
def reverse(self) -> QuerySet[_T, _Row]: ...
|
||||
def defer(self, *fields: Any) -> QuerySet[_T, _Row]: ...
|
||||
def only(self, *fields: Any) -> QuerySet[_T, _Row]: ...
|
||||
|
||||
@@ -39,10 +39,10 @@ ungettext = ngettext
|
||||
def pgettext(context: str, message: str) -> str: ...
|
||||
def npgettext(context: str, singular: str, plural: str, number: int) -> str: ...
|
||||
|
||||
gettext_lazy: Any
|
||||
gettext_lazy: Callable[[str], str]
|
||||
|
||||
ugettext_lazy: Any
|
||||
pgettext_lazy: Any
|
||||
ugettext_lazy: Callable[[str], str]
|
||||
pgettext_lazy: Callable[[str], str]
|
||||
|
||||
def ngettext_lazy(singular: Any, plural: Any, number: Optional[Any] = ...): ...
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type, Sequence
|
||||
from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type
|
||||
|
||||
from django.core.exceptions import FieldError, FieldDoesNotExist
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.fields.related import ForeignKey, RelatedField
|
||||
from django.db.models.fields.reverse_related import ForeignObjectRel
|
||||
from django.db.models.sql.query import Query
|
||||
from django.utils.functional import cached_property
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.types import Instance, Type as MypyType
|
||||
@@ -13,9 +15,6 @@ from pytest_mypy.utils import temp_environ
|
||||
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.db.models.fields import CharField, Field
|
||||
from django.db.models.fields.reverse_related import ForeignObjectRel, ManyToOneRel, ManyToManyRel
|
||||
|
||||
from django.db.models.sql.query import Query
|
||||
from mypy_django_plugin.lib import helpers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -184,14 +183,16 @@ class DjangoContext:
|
||||
raise ValueError('No primary key defined')
|
||||
|
||||
def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], method: str) -> Dict[str, MypyType]:
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey
|
||||
|
||||
expected_types = {}
|
||||
if method == '__init__':
|
||||
# add pk
|
||||
primary_key_field = self.get_primary_key_field(model_cls)
|
||||
field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method)
|
||||
expected_types['pk'] = field_set_type
|
||||
|
||||
for field in self.get_model_fields(model_cls):
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, Field):
|
||||
field_name = field.attname
|
||||
field_set_type = self.fields_context.get_field_set_type(api, field, method)
|
||||
expected_types[field_name] = field_set_type
|
||||
@@ -207,4 +208,13 @@ class DjangoContext:
|
||||
model_set_type = helpers.convert_any_to_type(foreign_key_set_type,
|
||||
Instance(related_model_info, []))
|
||||
expected_types[field_name] = model_set_type
|
||||
|
||||
elif isinstance(field, GenericForeignKey):
|
||||
# it's generic, so cannot set specific model
|
||||
field_name = field.name
|
||||
gfk_info = helpers.lookup_class_typeinfo(api, field.__class__)
|
||||
gfk_set_type = helpers.get_private_descriptor_type(gfk_info, '_pyi_private_set_type',
|
||||
is_nullable=True)
|
||||
expected_types[field_name] = gfk_set_type
|
||||
|
||||
return expected_types
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional, Tuple, cast
|
||||
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.nodes import MypyFile, TypeInfo
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import AnyType, CallableType, Instance, Type as MypyType, TypeOfAny
|
||||
|
||||
@@ -14,11 +14,13 @@ def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoC
|
||||
assert isinstance(to_arg_type.ret_type, Instance)
|
||||
return to_arg_type.ret_type.type.fullname()
|
||||
|
||||
outer_model_info = ctx.api.tscope.classes[-1]
|
||||
outer_model_info = ctx.api.scope.active_class()
|
||||
if not outer_model_info or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
|
||||
# not inside models.Model class
|
||||
return None
|
||||
assert isinstance(outer_model_info, TypeInfo)
|
||||
|
||||
to_arg_expr = helpers.get_call_argument_by_name(ctx, 'to')
|
||||
|
||||
model_string = helpers.resolve_string_attribute_value(to_arg_expr, ctx, django_context)
|
||||
if model_string is None:
|
||||
# unresolvable
|
||||
@@ -28,10 +30,21 @@ def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoC
|
||||
return outer_model_info.fullname()
|
||||
if '.' not in model_string:
|
||||
# same file class
|
||||
current_module = ctx.api.tree
|
||||
if model_string not in current_module.names:
|
||||
model_cls_is_accessible = False
|
||||
for scope in ctx.api.scope.stack:
|
||||
if isinstance(scope, (MypyFile, TypeInfo)):
|
||||
model_class_candidate = scope.names.get(model_string)
|
||||
model_cls_is_accessible = (model_class_candidate is not None
|
||||
and isinstance(model_class_candidate.node, TypeInfo)
|
||||
and model_class_candidate.node.has_base(fullnames.MODEL_CLASS_FULLNAME))
|
||||
if model_cls_is_accessible:
|
||||
break
|
||||
# TODO: FuncItem
|
||||
|
||||
if not model_cls_is_accessible:
|
||||
ctx.api.fail(f'No model {model_string!r} defined in the current module', ctx.context)
|
||||
return None
|
||||
|
||||
return outer_model_info.module_name + '.' + model_string
|
||||
|
||||
app_label, model_name = model_string.split('.')
|
||||
@@ -88,12 +101,6 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan
|
||||
default_return_type = ctx.default_return_type
|
||||
assert isinstance(default_return_type, Instance)
|
||||
|
||||
# bail out if we're inside migration, not supported yet
|
||||
active_class = ctx.api.scope.active_class()
|
||||
if active_class is not None:
|
||||
if active_class.has_base(fullnames.MIGRATION_CLASS_FULLNAME):
|
||||
return ctx.default_return_type
|
||||
|
||||
if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES):
|
||||
return fill_descriptor_types_for_related_field(ctx, django_context)
|
||||
|
||||
|
||||
22
test-data/typecheck/fields/test_generic_foreign_key.yml
Normal file
22
test-data/typecheck/fields/test_generic_foreign_key.yml
Normal file
@@ -0,0 +1,22 @@
|
||||
- case: generic_foreign_key_could_point_to_any_model_and_is_always_optional
|
||||
main: |
|
||||
from myapp.models import Tag, User
|
||||
myuser = User()
|
||||
Tag(content_object=None)
|
||||
Tag(content_object=myuser)
|
||||
Tag.objects.create(content_object=None)
|
||||
Tag.objects.create(content_object=myuser)
|
||||
reveal_type(Tag().content_object) # N: Revealed type is 'Union[Any, None]'
|
||||
installed_apps:
|
||||
- django.contrib.contenttypes
|
||||
- myapp
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
from django.contrib.contenttypes import fields
|
||||
class User(models.Model):
|
||||
pass
|
||||
class Tag(models.Model):
|
||||
content_object = fields.GenericForeignKey()
|
||||
@@ -381,10 +381,10 @@
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
class Book(models.Model):
|
||||
publisher = models.ForeignKey(to='Publisher', on_delete=models.CASCADE)
|
||||
class Publisher(models.Model):
|
||||
pass
|
||||
class Book(models.Model):
|
||||
publisher = models.ForeignKey(to='Publisher', on_delete=models.CASCADE)
|
||||
|
||||
- case: test_foreign_key_field_without_backwards_relation
|
||||
main: |
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
- case: default_manager_create_is_typechecked
|
||||
main: |
|
||||
from myapp.models import User
|
||||
User.objects.create(name='Max', age=10)
|
||||
User.objects.create(pk=1, name='Max', age=10)
|
||||
User.objects.create(age=[]) # E: Incompatible type for "age" of "User" (got "List[Any]", expected "Union[float, int, str, Combinable]")
|
||||
installed_apps:
|
||||
- myapp
|
||||
|
||||
Reference in New Issue
Block a user