mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-08 21:14:49 +08:00
add related_name manager to ForeignKey model
This commit is contained in:
@@ -9,3 +9,4 @@ from .fields import (AutoField as AutoField,
|
||||
TextField as TextField)
|
||||
from .fields.related import (ForeignKey as ForeignKey)
|
||||
from .deletion import CASCADE as CASCADE
|
||||
from .query import QuerySet as QuerySet
|
||||
177
django-stubs/db/models/query.pyi
Normal file
177
django-stubs/db/models/query.pyi
Normal file
@@ -0,0 +1,177 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import date, datetime
|
||||
from typing import Generic, TypeVar, Optional, Any, Type, Dict, Union, overload, List, Iterator, Tuple, Callable
|
||||
|
||||
from django.db import models
|
||||
|
||||
_T = TypeVar('_T', bound=models.Model)
|
||||
|
||||
|
||||
class QuerySet(Generic[_T]):
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Type[models.Model]] = ...,
|
||||
query: Optional[Any] = ...,
|
||||
using: Optional[str] = ...,
|
||||
hints: Optional[Dict[str, models.Model]] = ...,
|
||||
) -> None: ...
|
||||
@classmethod
|
||||
def as_manager(cls): ...
|
||||
def __len__(self) -> int: ...
|
||||
def __iter__(self) -> Iterator[_T]: ...
|
||||
def __bool__(self) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self, k: slice) -> QuerySet[_T]: ...
|
||||
@overload
|
||||
def __getitem__(self, k: int) -> _T: ...
|
||||
@overload
|
||||
def __getitem__(self, k: str) -> Any: ...
|
||||
def __and__(self, other: QuerySet) -> QuerySet: ...
|
||||
def __or__(self, other: QuerySet) -> QuerySet: ...
|
||||
|
||||
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
|
||||
def aggregate(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> Dict[str, Optional[Union[datetime, float]]]: ...
|
||||
def count(self) -> int: ...
|
||||
def get(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> _T: ...
|
||||
def create(self, **kwargs: Any) -> _T: ...
|
||||
def bulk_create(
|
||||
self,
|
||||
objs: Union[Iterator[Any], List[models.Model]],
|
||||
batch_size: Optional[int] = ...,
|
||||
) -> List[_T]: ...
|
||||
|
||||
def get_or_create(
|
||||
self,
|
||||
defaults: Optional[Union[Dict[str, date], Dict[str, models.Model]]] = ...,
|
||||
**kwargs: Any
|
||||
) -> Tuple[_T, bool]: ...
|
||||
|
||||
def update_or_create(
|
||||
self,
|
||||
defaults: Optional[
|
||||
Union[
|
||||
Dict[str, Callable],
|
||||
Dict[str, date],
|
||||
Dict[str, models.Model],
|
||||
Dict[str, str],
|
||||
]
|
||||
] = ...,
|
||||
**kwargs: Any
|
||||
) -> Tuple[_T, bool]: ...
|
||||
|
||||
def earliest(
|
||||
self, *fields: Any, field_name: Optional[Any] = ...
|
||||
) -> _T: ...
|
||||
|
||||
def latest(
|
||||
self, *fields: Any, field_name: Optional[Any] = ...
|
||||
) -> _T: ...
|
||||
|
||||
def first(self) -> Optional[Union[Dict[str, int], _T]]: ...
|
||||
|
||||
def last(self) -> Optional[_T]: ...
|
||||
|
||||
def in_bulk(
|
||||
self, id_list: Any = ..., *, field_name: str = ...
|
||||
) -> Union[Dict[int, models.Model], Dict[str, models.Model]]: ...
|
||||
|
||||
def delete(self) -> Tuple[int, Dict[str, int]]: ...
|
||||
|
||||
def update(self, **kwargs: Any) -> int: ...
|
||||
|
||||
def exists(self) -> bool: ...
|
||||
|
||||
def explain(
|
||||
self, *, format: Optional[Any] = ..., **options: Any
|
||||
) -> str: ...
|
||||
|
||||
def raw(
|
||||
self,
|
||||
raw_query: str,
|
||||
params: Any = ...,
|
||||
translations: Optional[Dict[str, str]] = ...,
|
||||
using: None = ...,
|
||||
) -> RawQuerySet: ...
|
||||
|
||||
def values(self, *fields: Any, **expressions: Any) -> QuerySet: ...
|
||||
|
||||
def values_list(
|
||||
self, *fields: Any, flat: bool = ..., named: bool = ...
|
||||
) -> QuerySet: ...
|
||||
|
||||
def dates(
|
||||
self, field_name: str, kind: str, order: str = ...
|
||||
) -> QuerySet: ...
|
||||
|
||||
def datetimes(
|
||||
self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...
|
||||
) -> QuerySet: ...
|
||||
|
||||
def none(self) -> QuerySet[_T]: ...
|
||||
|
||||
def all(self) -> QuerySet[_T]: ...
|
||||
|
||||
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def complex_filter(
|
||||
self,
|
||||
filter_obj: Any,
|
||||
) -> QuerySet[_T]: ...
|
||||
|
||||
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
|
||||
|
||||
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def select_for_update(
|
||||
self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...
|
||||
) -> QuerySet: ...
|
||||
|
||||
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def extra(
|
||||
self,
|
||||
select: Optional[
|
||||
Union[Dict[str, int], Dict[str, str], OrderedDict]
|
||||
] = ...,
|
||||
where: Optional[List[str]] = ...,
|
||||
params: Optional[Union[List[int], List[str]]] = ...,
|
||||
tables: Optional[List[str]] = ...,
|
||||
order_by: Optional[Union[List[str], Tuple[str]]] = ...,
|
||||
select_params: Optional[Union[List[int], List[str], Tuple[int]]] = ...,
|
||||
) -> QuerySet[_T]: ...
|
||||
|
||||
def reverse(self) -> QuerySet[_T]: ...
|
||||
|
||||
def defer(self, *fields: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def only(self, *fields: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def using(self, alias: Optional[str]) -> QuerySet[_T]: ...
|
||||
|
||||
@property
|
||||
def ordered(self) -> bool: ...
|
||||
|
||||
@property
|
||||
def db(self) -> str: ...
|
||||
|
||||
def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
|
||||
class RawQuerySet:
|
||||
pass
|
||||
40
mypy_django_plugin/plugins/related_fields.py
Normal file
40
mypy_django_plugin/plugins/related_fields.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Optional, Callable
|
||||
|
||||
from mypy.nodes import SymbolTableNode, MDEF, Var
|
||||
from mypy.plugin import Plugin, FunctionContext
|
||||
from mypy.types import Type, CallableType, TypeOfAny, AnyType, Instance
|
||||
|
||||
|
||||
def set_related_fields(ctx: FunctionContext) -> Type:
|
||||
if 'related_name' not in ctx.context.arg_names:
|
||||
return ctx.default_return_type
|
||||
|
||||
assert 'to' in ctx.context.arg_names
|
||||
to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0]
|
||||
if not isinstance(to_arg_value, CallableType):
|
||||
return ctx.default_return_type
|
||||
|
||||
referred_to = to_arg_value.ret_type
|
||||
related_name = ctx.context.args[ctx.context.arg_names.index('related_name')].value
|
||||
outer_class_info = ctx.api.tscope.classes[-1]
|
||||
|
||||
queryset_type = ctx.api.named_generic_type('django.db.models.QuerySet',
|
||||
args=[Instance(outer_class_info, [])])
|
||||
related_var = Var(related_name,
|
||||
queryset_type)
|
||||
related_var.info = queryset_type.type
|
||||
|
||||
referred_to.type.names[related_name] = SymbolTableNode(MDEF, related_var)
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
class RelatedFieldsPlugin(Plugin):
|
||||
def get_function_hook(self, fullname: str
|
||||
) -> Optional[Callable[[FunctionContext], Type]]:
|
||||
if fullname == 'django.db.models.fields.related.ForeignKey':
|
||||
return set_related_fields
|
||||
return None
|
||||
|
||||
|
||||
def plugin(version):
|
||||
return RelatedFieldsPlugin
|
||||
@@ -1,2 +1,4 @@
|
||||
[mypy]
|
||||
plugins = mypy_django_plugin.plugins.postgres_fields
|
||||
plugins =
|
||||
mypy_django_plugin.plugins.postgres_fields,
|
||||
mypy_django_plugin.plugins.related_fields
|
||||
@@ -1,14 +1,27 @@
|
||||
[case testForeignKeyWithClass]
|
||||
from django.db import models
|
||||
|
||||
class User(models.Model):
|
||||
class Publisher(models.Model):
|
||||
pass
|
||||
|
||||
class Profile(models.Model):
|
||||
user = models.ForeignKey(to=User, on_delete=models.CASCADE)
|
||||
class Book(models.Model):
|
||||
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
|
||||
|
||||
profile = Profile()
|
||||
reveal_type(profile.user) # E: Revealed type is 'main.User*'
|
||||
book = Book()
|
||||
reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*'
|
||||
[out]
|
||||
|
||||
|
||||
[case testForeignKeyRelatedName]
|
||||
from django.db import models
|
||||
|
||||
class Publisher(models.Model):
|
||||
pass
|
||||
|
||||
class Book(models.Model):
|
||||
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE,
|
||||
related_name='books')
|
||||
|
||||
publisher = Publisher()
|
||||
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
|
||||
[out]
|
||||
|
||||
Reference in New Issue
Block a user