From 2a16bb04e532722b70ecd1426f2bbe950d77c947 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Sun, 11 Nov 2018 02:08:05 +0300 Subject: [PATCH] add related_name manager to ForeignKey model --- django-stubs/db/models/__init__.pyi | 1 + django-stubs/db/models/query.pyi | 177 +++++++++++++++++++ mypy_django_plugin/plugins/related_fields.py | 40 +++++ test/plugins.ini | 4 +- test/test-data/check-model-relations.test | 23 ++- 5 files changed, 239 insertions(+), 6 deletions(-) create mode 100644 django-stubs/db/models/query.pyi create mode 100644 mypy_django_plugin/plugins/related_fields.py diff --git a/django-stubs/db/models/__init__.pyi b/django-stubs/db/models/__init__.pyi index d167319..4f69bb8 100644 --- a/django-stubs/db/models/__init__.pyi +++ b/django-stubs/db/models/__init__.pyi @@ -9,3 +9,4 @@ from .fields import (AutoField as AutoField, TextField as TextField) from .fields.related import (ForeignKey as ForeignKey) from .deletion import CASCADE as CASCADE +from .query import QuerySet as QuerySet \ No newline at end of file diff --git a/django-stubs/db/models/query.pyi b/django-stubs/db/models/query.pyi new file mode 100644 index 0000000..b26ad45 --- /dev/null +++ b/django-stubs/db/models/query.pyi @@ -0,0 +1,177 @@ +from collections import OrderedDict +from datetime import date, datetime +from typing import Generic, TypeVar, Optional, Any, Type, Dict, Union, overload, List, Iterator, Tuple, Callable + +from django.db import models + +_T = TypeVar('_T', bound=models.Model) + + +class QuerySet(Generic[_T]): + def __init__( + self, + model: Optional[Type[models.Model]] = ..., + query: Optional[Any] = ..., + using: Optional[str] = ..., + hints: Optional[Dict[str, models.Model]] = ..., + ) -> None: ... + @classmethod + def as_manager(cls): ... + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[_T]: ... + def __bool__(self) -> bool: ... + @overload + def __getitem__(self, k: slice) -> QuerySet[_T]: ... + @overload + def __getitem__(self, k: int) -> _T: ... + @overload + def __getitem__(self, k: str) -> Any: ... + def __and__(self, other: QuerySet) -> QuerySet: ... + def __or__(self, other: QuerySet) -> QuerySet: ... + + def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ... + def aggregate( + self, *args: Any, **kwargs: Any + ) -> Dict[str, Optional[Union[datetime, float]]]: ... + def count(self) -> int: ... + def get( + self, *args: Any, **kwargs: Any + ) -> _T: ... + def create(self, **kwargs: Any) -> _T: ... + def bulk_create( + self, + objs: Union[Iterator[Any], List[models.Model]], + batch_size: Optional[int] = ..., + ) -> List[_T]: ... + + def get_or_create( + self, + defaults: Optional[Union[Dict[str, date], Dict[str, models.Model]]] = ..., + **kwargs: Any + ) -> Tuple[_T, bool]: ... + + def update_or_create( + self, + defaults: Optional[ + Union[ + Dict[str, Callable], + Dict[str, date], + Dict[str, models.Model], + Dict[str, str], + ] + ] = ..., + **kwargs: Any + ) -> Tuple[_T, bool]: ... + + def earliest( + self, *fields: Any, field_name: Optional[Any] = ... + ) -> _T: ... + + def latest( + self, *fields: Any, field_name: Optional[Any] = ... + ) -> _T: ... + + def first(self) -> Optional[Union[Dict[str, int], _T]]: ... + + def last(self) -> Optional[_T]: ... + + def in_bulk( + self, id_list: Any = ..., *, field_name: str = ... + ) -> Union[Dict[int, models.Model], Dict[str, models.Model]]: ... + + def delete(self) -> Tuple[int, Dict[str, int]]: ... + + def update(self, **kwargs: Any) -> int: ... + + def exists(self) -> bool: ... + + def explain( + self, *, format: Optional[Any] = ..., **options: Any + ) -> str: ... + + def raw( + self, + raw_query: str, + params: Any = ..., + translations: Optional[Dict[str, str]] = ..., + using: None = ..., + ) -> RawQuerySet: ... + + def values(self, *fields: Any, **expressions: Any) -> QuerySet: ... + + def values_list( + self, *fields: Any, flat: bool = ..., named: bool = ... + ) -> QuerySet: ... + + def dates( + self, field_name: str, kind: str, order: str = ... + ) -> QuerySet: ... + + def datetimes( + self, field_name: str, kind: str, order: str = ..., tzinfo: None = ... + ) -> QuerySet: ... + + def none(self) -> QuerySet[_T]: ... + + def all(self) -> QuerySet[_T]: ... + + def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ... + + def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ... + + def complex_filter( + self, + filter_obj: Any, + ) -> QuerySet[_T]: ... + + def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ... + + def intersection(self, *other_qs: Any) -> QuerySet[_T]: ... + + def difference(self, *other_qs: Any) -> QuerySet[_T]: ... + + def select_for_update( + self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ... + ) -> QuerySet: ... + + def select_related(self, *fields: Any) -> QuerySet[_T]: ... + + def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ... + + def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ... + + def order_by(self, *field_names: Any) -> QuerySet[_T]: ... + + def distinct(self, *field_names: Any) -> QuerySet[_T]: ... + + def extra( + self, + select: Optional[ + Union[Dict[str, int], Dict[str, str], OrderedDict] + ] = ..., + where: Optional[List[str]] = ..., + params: Optional[Union[List[int], List[str]]] = ..., + tables: Optional[List[str]] = ..., + order_by: Optional[Union[List[str], Tuple[str]]] = ..., + select_params: Optional[Union[List[int], List[str], Tuple[int]]] = ..., + ) -> QuerySet[_T]: ... + + def reverse(self) -> QuerySet[_T]: ... + + def defer(self, *fields: Any) -> QuerySet[_T]: ... + + def only(self, *fields: Any) -> QuerySet[_T]: ... + + def using(self, alias: Optional[str]) -> QuerySet[_T]: ... + + @property + def ordered(self) -> bool: ... + + @property + def db(self) -> str: ... + + def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ... + + +class RawQuerySet: + pass \ No newline at end of file diff --git a/mypy_django_plugin/plugins/related_fields.py b/mypy_django_plugin/plugins/related_fields.py new file mode 100644 index 0000000..a62eea6 --- /dev/null +++ b/mypy_django_plugin/plugins/related_fields.py @@ -0,0 +1,40 @@ +from typing import Optional, Callable + +from mypy.nodes import SymbolTableNode, MDEF, Var +from mypy.plugin import Plugin, FunctionContext +from mypy.types import Type, CallableType, TypeOfAny, AnyType, Instance + + +def set_related_fields(ctx: FunctionContext) -> Type: + if 'related_name' not in ctx.context.arg_names: + return ctx.default_return_type + + assert 'to' in ctx.context.arg_names + to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0] + if not isinstance(to_arg_value, CallableType): + return ctx.default_return_type + + referred_to = to_arg_value.ret_type + related_name = ctx.context.args[ctx.context.arg_names.index('related_name')].value + outer_class_info = ctx.api.tscope.classes[-1] + + queryset_type = ctx.api.named_generic_type('django.db.models.QuerySet', + args=[Instance(outer_class_info, [])]) + related_var = Var(related_name, + queryset_type) + related_var.info = queryset_type.type + + referred_to.type.names[related_name] = SymbolTableNode(MDEF, related_var) + return ctx.default_return_type + + +class RelatedFieldsPlugin(Plugin): + def get_function_hook(self, fullname: str + ) -> Optional[Callable[[FunctionContext], Type]]: + if fullname == 'django.db.models.fields.related.ForeignKey': + return set_related_fields + return None + + +def plugin(version): + return RelatedFieldsPlugin diff --git a/test/plugins.ini b/test/plugins.ini index f8a5880..353143d 100644 --- a/test/plugins.ini +++ b/test/plugins.ini @@ -1,2 +1,4 @@ [mypy] -plugins = mypy_django_plugin.plugins.postgres_fields \ No newline at end of file +plugins = + mypy_django_plugin.plugins.postgres_fields, + mypy_django_plugin.plugins.related_fields \ No newline at end of file diff --git a/test/test-data/check-model-relations.test b/test/test-data/check-model-relations.test index 33ca653..167458a 100644 --- a/test/test-data/check-model-relations.test +++ b/test/test-data/check-model-relations.test @@ -1,14 +1,27 @@ [case testForeignKeyWithClass] from django.db import models -class User(models.Model): +class Publisher(models.Model): pass -class Profile(models.Model): - user = models.ForeignKey(to=User, on_delete=models.CASCADE) +class Book(models.Model): + publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) -profile = Profile() -reveal_type(profile.user) # E: Revealed type is 'main.User*' +book = Book() +reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*' [out] +[case testForeignKeyRelatedName] +from django.db import models + +class Publisher(models.Model): + pass + +class Book(models.Model): + publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE, + related_name='books') + +publisher = Publisher() +reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]' +[out]