add related_name manager to ForeignKey model

This commit is contained in:
Maxim Kurnikov
2018-11-11 02:08:05 +03:00
parent c32a7842d6
commit 2a16bb04e5
5 changed files with 239 additions and 6 deletions

View File

@@ -9,3 +9,4 @@ from .fields import (AutoField as AutoField,
TextField as TextField)
from .fields.related import (ForeignKey as ForeignKey)
from .deletion import CASCADE as CASCADE
from .query import QuerySet as QuerySet

View File

@@ -0,0 +1,177 @@
from collections import OrderedDict
from datetime import date, datetime
from typing import Generic, TypeVar, Optional, Any, Type, Dict, Union, overload, List, Iterator, Tuple, Callable
from django.db import models
_T = TypeVar('_T', bound=models.Model)
class QuerySet(Generic[_T]):
def __init__(
self,
model: Optional[Type[models.Model]] = ...,
query: Optional[Any] = ...,
using: Optional[str] = ...,
hints: Optional[Dict[str, models.Model]] = ...,
) -> None: ...
@classmethod
def as_manager(cls): ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T]: ...
def __bool__(self) -> bool: ...
@overload
def __getitem__(self, k: slice) -> QuerySet[_T]: ...
@overload
def __getitem__(self, k: int) -> _T: ...
@overload
def __getitem__(self, k: str) -> Any: ...
def __and__(self, other: QuerySet) -> QuerySet: ...
def __or__(self, other: QuerySet) -> QuerySet: ...
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
def aggregate(
self, *args: Any, **kwargs: Any
) -> Dict[str, Optional[Union[datetime, float]]]: ...
def count(self) -> int: ...
def get(
self, *args: Any, **kwargs: Any
) -> _T: ...
def create(self, **kwargs: Any) -> _T: ...
def bulk_create(
self,
objs: Union[Iterator[Any], List[models.Model]],
batch_size: Optional[int] = ...,
) -> List[_T]: ...
def get_or_create(
self,
defaults: Optional[Union[Dict[str, date], Dict[str, models.Model]]] = ...,
**kwargs: Any
) -> Tuple[_T, bool]: ...
def update_or_create(
self,
defaults: Optional[
Union[
Dict[str, Callable],
Dict[str, date],
Dict[str, models.Model],
Dict[str, str],
]
] = ...,
**kwargs: Any
) -> Tuple[_T, bool]: ...
def earliest(
self, *fields: Any, field_name: Optional[Any] = ...
) -> _T: ...
def latest(
self, *fields: Any, field_name: Optional[Any] = ...
) -> _T: ...
def first(self) -> Optional[Union[Dict[str, int], _T]]: ...
def last(self) -> Optional[_T]: ...
def in_bulk(
self, id_list: Any = ..., *, field_name: str = ...
) -> Union[Dict[int, models.Model], Dict[str, models.Model]]: ...
def delete(self) -> Tuple[int, Dict[str, int]]: ...
def update(self, **kwargs: Any) -> int: ...
def exists(self) -> bool: ...
def explain(
self, *, format: Optional[Any] = ..., **options: Any
) -> str: ...
def raw(
self,
raw_query: str,
params: Any = ...,
translations: Optional[Dict[str, str]] = ...,
using: None = ...,
) -> RawQuerySet: ...
def values(self, *fields: Any, **expressions: Any) -> QuerySet: ...
def values_list(
self, *fields: Any, flat: bool = ..., named: bool = ...
) -> QuerySet: ...
def dates(
self, field_name: str, kind: str, order: str = ...
) -> QuerySet: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...
) -> QuerySet: ...
def none(self) -> QuerySet[_T]: ...
def all(self) -> QuerySet[_T]: ...
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def complex_filter(
self,
filter_obj: Any,
) -> QuerySet[_T]: ...
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
def select_for_update(
self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...
) -> QuerySet: ...
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
def extra(
self,
select: Optional[
Union[Dict[str, int], Dict[str, str], OrderedDict]
] = ...,
where: Optional[List[str]] = ...,
params: Optional[Union[List[int], List[str]]] = ...,
tables: Optional[List[str]] = ...,
order_by: Optional[Union[List[str], Tuple[str]]] = ...,
select_params: Optional[Union[List[int], List[str], Tuple[int]]] = ...,
) -> QuerySet[_T]: ...
def reverse(self) -> QuerySet[_T]: ...
def defer(self, *fields: Any) -> QuerySet[_T]: ...
def only(self, *fields: Any) -> QuerySet[_T]: ...
def using(self, alias: Optional[str]) -> QuerySet[_T]: ...
@property
def ordered(self) -> bool: ...
@property
def db(self) -> str: ...
def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ...
class RawQuerySet:
pass

View File

@@ -0,0 +1,40 @@
from typing import Optional, Callable
from mypy.nodes import SymbolTableNode, MDEF, Var
from mypy.plugin import Plugin, FunctionContext
from mypy.types import Type, CallableType, TypeOfAny, AnyType, Instance
def set_related_fields(ctx: FunctionContext) -> Type:
if 'related_name' not in ctx.context.arg_names:
return ctx.default_return_type
assert 'to' in ctx.context.arg_names
to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0]
if not isinstance(to_arg_value, CallableType):
return ctx.default_return_type
referred_to = to_arg_value.ret_type
related_name = ctx.context.args[ctx.context.arg_names.index('related_name')].value
outer_class_info = ctx.api.tscope.classes[-1]
queryset_type = ctx.api.named_generic_type('django.db.models.QuerySet',
args=[Instance(outer_class_info, [])])
related_var = Var(related_name,
queryset_type)
related_var.info = queryset_type.type
referred_to.type.names[related_name] = SymbolTableNode(MDEF, related_var)
return ctx.default_return_type
class RelatedFieldsPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == 'django.db.models.fields.related.ForeignKey':
return set_related_fields
return None
def plugin(version):
return RelatedFieldsPlugin

View File

@@ -1,2 +1,4 @@
[mypy]
plugins = mypy_django_plugin.plugins.postgres_fields
plugins =
mypy_django_plugin.plugins.postgres_fields,
mypy_django_plugin.plugins.related_fields

View File

@@ -1,14 +1,27 @@
[case testForeignKeyWithClass]
from django.db import models
class User(models.Model):
class Publisher(models.Model):
pass
class Profile(models.Model):
user = models.ForeignKey(to=User, on_delete=models.CASCADE)
class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
profile = Profile()
reveal_type(profile.user) # E: Revealed type is 'main.User*'
book = Book()
reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*'
[out]
[case testForeignKeyRelatedName]
from django.db import models
class Publisher(models.Model):
pass
class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE,
related_name='books')
publisher = Publisher()
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
[out]