add related_name manager to ForeignKey model

This commit is contained in:
Maxim Kurnikov
2018-11-11 02:08:05 +03:00
parent c32a7842d6
commit 2a16bb04e5
5 changed files with 239 additions and 6 deletions

View File

@@ -9,3 +9,4 @@ from .fields import (AutoField as AutoField,
TextField as TextField) TextField as TextField)
from .fields.related import (ForeignKey as ForeignKey) from .fields.related import (ForeignKey as ForeignKey)
from .deletion import CASCADE as CASCADE from .deletion import CASCADE as CASCADE
from .query import QuerySet as QuerySet

View File

@@ -0,0 +1,177 @@
from collections import OrderedDict
from datetime import date, datetime
from typing import Generic, TypeVar, Optional, Any, Type, Dict, Union, overload, List, Iterator, Tuple, Callable
from django.db import models
_T = TypeVar('_T', bound=models.Model)
class QuerySet(Generic[_T]):
def __init__(
self,
model: Optional[Type[models.Model]] = ...,
query: Optional[Any] = ...,
using: Optional[str] = ...,
hints: Optional[Dict[str, models.Model]] = ...,
) -> None: ...
@classmethod
def as_manager(cls): ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T]: ...
def __bool__(self) -> bool: ...
@overload
def __getitem__(self, k: slice) -> QuerySet[_T]: ...
@overload
def __getitem__(self, k: int) -> _T: ...
@overload
def __getitem__(self, k: str) -> Any: ...
def __and__(self, other: QuerySet) -> QuerySet: ...
def __or__(self, other: QuerySet) -> QuerySet: ...
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
def aggregate(
self, *args: Any, **kwargs: Any
) -> Dict[str, Optional[Union[datetime, float]]]: ...
def count(self) -> int: ...
def get(
self, *args: Any, **kwargs: Any
) -> _T: ...
def create(self, **kwargs: Any) -> _T: ...
def bulk_create(
self,
objs: Union[Iterator[Any], List[models.Model]],
batch_size: Optional[int] = ...,
) -> List[_T]: ...
def get_or_create(
self,
defaults: Optional[Union[Dict[str, date], Dict[str, models.Model]]] = ...,
**kwargs: Any
) -> Tuple[_T, bool]: ...
def update_or_create(
self,
defaults: Optional[
Union[
Dict[str, Callable],
Dict[str, date],
Dict[str, models.Model],
Dict[str, str],
]
] = ...,
**kwargs: Any
) -> Tuple[_T, bool]: ...
def earliest(
self, *fields: Any, field_name: Optional[Any] = ...
) -> _T: ...
def latest(
self, *fields: Any, field_name: Optional[Any] = ...
) -> _T: ...
def first(self) -> Optional[Union[Dict[str, int], _T]]: ...
def last(self) -> Optional[_T]: ...
def in_bulk(
self, id_list: Any = ..., *, field_name: str = ...
) -> Union[Dict[int, models.Model], Dict[str, models.Model]]: ...
def delete(self) -> Tuple[int, Dict[str, int]]: ...
def update(self, **kwargs: Any) -> int: ...
def exists(self) -> bool: ...
def explain(
self, *, format: Optional[Any] = ..., **options: Any
) -> str: ...
def raw(
self,
raw_query: str,
params: Any = ...,
translations: Optional[Dict[str, str]] = ...,
using: None = ...,
) -> RawQuerySet: ...
def values(self, *fields: Any, **expressions: Any) -> QuerySet: ...
def values_list(
self, *fields: Any, flat: bool = ..., named: bool = ...
) -> QuerySet: ...
def dates(
self, field_name: str, kind: str, order: str = ...
) -> QuerySet: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...
) -> QuerySet: ...
def none(self) -> QuerySet[_T]: ...
def all(self) -> QuerySet[_T]: ...
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def complex_filter(
self,
filter_obj: Any,
) -> QuerySet[_T]: ...
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
def select_for_update(
self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...
) -> QuerySet: ...
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
def extra(
self,
select: Optional[
Union[Dict[str, int], Dict[str, str], OrderedDict]
] = ...,
where: Optional[List[str]] = ...,
params: Optional[Union[List[int], List[str]]] = ...,
tables: Optional[List[str]] = ...,
order_by: Optional[Union[List[str], Tuple[str]]] = ...,
select_params: Optional[Union[List[int], List[str], Tuple[int]]] = ...,
) -> QuerySet[_T]: ...
def reverse(self) -> QuerySet[_T]: ...
def defer(self, *fields: Any) -> QuerySet[_T]: ...
def only(self, *fields: Any) -> QuerySet[_T]: ...
def using(self, alias: Optional[str]) -> QuerySet[_T]: ...
@property
def ordered(self) -> bool: ...
@property
def db(self) -> str: ...
def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ...
class RawQuerySet:
pass

View File

@@ -0,0 +1,40 @@
from typing import Optional, Callable
from mypy.nodes import SymbolTableNode, MDEF, Var
from mypy.plugin import Plugin, FunctionContext
from mypy.types import Type, CallableType, TypeOfAny, AnyType, Instance
def set_related_fields(ctx: FunctionContext) -> Type:
if 'related_name' not in ctx.context.arg_names:
return ctx.default_return_type
assert 'to' in ctx.context.arg_names
to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0]
if not isinstance(to_arg_value, CallableType):
return ctx.default_return_type
referred_to = to_arg_value.ret_type
related_name = ctx.context.args[ctx.context.arg_names.index('related_name')].value
outer_class_info = ctx.api.tscope.classes[-1]
queryset_type = ctx.api.named_generic_type('django.db.models.QuerySet',
args=[Instance(outer_class_info, [])])
related_var = Var(related_name,
queryset_type)
related_var.info = queryset_type.type
referred_to.type.names[related_name] = SymbolTableNode(MDEF, related_var)
return ctx.default_return_type
class RelatedFieldsPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == 'django.db.models.fields.related.ForeignKey':
return set_related_fields
return None
def plugin(version):
return RelatedFieldsPlugin

View File

@@ -1,2 +1,4 @@
[mypy] [mypy]
plugins = mypy_django_plugin.plugins.postgres_fields plugins =
mypy_django_plugin.plugins.postgres_fields,
mypy_django_plugin.plugins.related_fields

View File

@@ -1,14 +1,27 @@
[case testForeignKeyWithClass] [case testForeignKeyWithClass]
from django.db import models from django.db import models
class User(models.Model): class Publisher(models.Model):
pass pass
class Profile(models.Model): class Book(models.Model):
user = models.ForeignKey(to=User, on_delete=models.CASCADE) publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
profile = Profile() book = Book()
reveal_type(profile.user) # E: Revealed type is 'main.User*' reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*'
[out] [out]
[case testForeignKeyRelatedName]
from django.db import models
class Publisher(models.Model):
pass
class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE,
related_name='books')
publisher = Publisher()
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
[out]