mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-09 05:24:53 +08:00
more fixes for django stubs, first attempt for a plugin
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from typing import Callable, Optional, Type
|
from typing import Callable, Optional, Type
|
||||||
|
|
||||||
|
from django.contrib.sessions.backends.base import SessionBase
|
||||||
from django.core.handlers.wsgi import WSGIRequest
|
from django.core.handlers.wsgi import WSGIRequest
|
||||||
from django.http.request import HttpRequest
|
from django.http.request import HttpRequest
|
||||||
from django.http.response import HttpResponseBase
|
from django.http.response import HttpResponseBase
|
||||||
@@ -8,7 +9,7 @@ from django.utils.deprecation import MiddlewareMixin
|
|||||||
|
|
||||||
class SessionMiddleware(MiddlewareMixin):
|
class SessionMiddleware(MiddlewareMixin):
|
||||||
get_response: Callable[[WSGIRequest], HttpResponseBase] = ...
|
get_response: Callable[[WSGIRequest], HttpResponseBase] = ...
|
||||||
SessionStore: Type[SessionStore] = ...
|
SessionStore: Type[SessionBase] = ...
|
||||||
|
|
||||||
def __init__(self, get_response: Optional[Callable] = ...) -> None: ...
|
def __init__(self, get_response: Optional[Callable] = ...) -> None: ...
|
||||||
|
|
||||||
|
|||||||
17
django-stubs/core/cache/__init__.pyi
vendored
17
django-stubs/core/cache/__init__.pyi
vendored
@@ -1,25 +1,32 @@
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Callable, Dict, Optional, Union
|
from typing import Any, Callable, Dict, Union
|
||||||
|
|
||||||
from django.core.cache.backends.base import BaseCache as BaseCache
|
from django.core.cache.backends.base import BaseCache as BaseCache
|
||||||
from django.core.cache.backends.base import CacheKeyWarning as CacheKeyWarning
|
|
||||||
from django.core.cache.backends.base import \
|
|
||||||
InvalidCacheBackendError as InvalidCacheBackendError
|
|
||||||
|
|
||||||
DEFAULT_CACHE_ALIAS: str
|
DEFAULT_CACHE_ALIAS: str
|
||||||
|
|
||||||
|
|
||||||
class CacheHandler:
|
class CacheHandler:
|
||||||
def __init__(self) -> None: ...
|
def __init__(self) -> None: ...
|
||||||
|
|
||||||
def __getitem__(self, alias: str) -> BaseCache: ...
|
def __getitem__(self, alias: str) -> BaseCache: ...
|
||||||
|
|
||||||
def all(self): ...
|
def all(self): ...
|
||||||
|
|
||||||
|
|
||||||
class DefaultCacheProxy:
|
class DefaultCacheProxy:
|
||||||
def __getattr__(
|
def __getattr__(
|
||||||
self, name: str
|
self, name: str
|
||||||
) -> Union[Callable, Dict[str, float], OrderedDict, int]: ...
|
) -> Union[Callable, Dict[str, float], OrderedDict, int]: ...
|
||||||
|
|
||||||
def __setattr__(self, name: str, value: Callable) -> None: ...
|
def __setattr__(self, name: str, value: Callable) -> None: ...
|
||||||
|
|
||||||
def __delattr__(self, name: Any): ...
|
def __delattr__(self, name: Any): ...
|
||||||
|
|
||||||
def __contains__(self, key: str) -> bool: ...
|
def __contains__(self, key: str) -> bool: ...
|
||||||
|
|
||||||
def __eq__(self, other: Any): ...
|
def __eq__(self, other: Any): ...
|
||||||
|
|
||||||
|
|
||||||
cache: Any
|
cache: Any
|
||||||
|
caches: CacheHandler
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from io import BufferedReader, StringIO
|
||||||
from typing import Any, Iterator, Optional, Union
|
from typing import Any, Iterator, Optional, Union
|
||||||
|
|
||||||
from django.core.files.utils import FileProxyMixin
|
from django.core.files.utils import FileProxyMixin
|
||||||
@@ -5,7 +6,7 @@ from django.core.files.utils import FileProxyMixin
|
|||||||
|
|
||||||
class File(FileProxyMixin):
|
class File(FileProxyMixin):
|
||||||
DEFAULT_CHUNK_SIZE: Any = ...
|
DEFAULT_CHUNK_SIZE: Any = ...
|
||||||
file: _io.BufferedReader = ...
|
file: BufferedReader = ...
|
||||||
name: str = ...
|
name: str = ...
|
||||||
mode: str = ...
|
mode: str = ...
|
||||||
def __init__(self, file: Any, name: Optional[str] = ...) -> None: ...
|
def __init__(self, file: Any, name: Optional[str] = ...) -> None: ...
|
||||||
@@ -14,7 +15,7 @@ class File(FileProxyMixin):
|
|||||||
def size(self) -> int: ...
|
def size(self) -> int: ...
|
||||||
def chunks(
|
def chunks(
|
||||||
self, chunk_size: Optional[int] = ...
|
self, chunk_size: Optional[int] = ...
|
||||||
) -> Iterator[Union[bytes, str]]: ...
|
) -> Iterator[Union[bytes, bytearray]]: ...
|
||||||
def multiple_chunks(self, chunk_size: Optional[Any] = ...): ...
|
def multiple_chunks(self, chunk_size: Optional[Any] = ...): ...
|
||||||
def __iter__(self) -> Iterator[Union[bytes, str]]: ...
|
def __iter__(self) -> Iterator[Union[bytes, str]]: ...
|
||||||
def __enter__(self) -> File: ...
|
def __enter__(self) -> File: ...
|
||||||
@@ -23,7 +24,7 @@ class File(FileProxyMixin):
|
|||||||
def close(self) -> None: ...
|
def close(self) -> None: ...
|
||||||
|
|
||||||
class ContentFile(File):
|
class ContentFile(File):
|
||||||
file: _io.StringIO
|
file: StringIO
|
||||||
name: None
|
name: None
|
||||||
size: Any = ...
|
size: Any = ...
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -1,76 +1,110 @@
|
|||||||
from typing import Any, Callable, List, Optional, Tuple, Type, Union
|
from typing import Any, Callable, List, Optional, Tuple, Type, Union, Generic, TypeVar
|
||||||
|
|
||||||
from django.core.exceptions import ObjectDoesNotExist
|
from django.core.exceptions import ObjectDoesNotExist
|
||||||
from django.db.models.base import Model
|
from django.db.models.base import Model
|
||||||
from django.db.models.expressions import F
|
from django.db.models.expressions import F
|
||||||
from django.db.models.fields.mixins import FieldCacheMixin
|
from django.db.models.fields.mixins import FieldCacheMixin
|
||||||
from django.db.models.fields.related import ForeignObject
|
from django.db.models.fields.related import ForeignObject, RelatedField, OneToOneField
|
||||||
from django.db.models.fields.reverse_related import ManyToManyRel, OneToOneRel
|
from django.db.models.fields.reverse_related import ManyToManyRel, OneToOneRel
|
||||||
from django.db.models.query import QuerySet
|
from django.db.models.query import QuerySet
|
||||||
|
|
||||||
|
_T = TypeVar('_T')
|
||||||
|
|
||||||
|
|
||||||
class ForwardManyToOneDescriptor:
|
class ForwardManyToOneDescriptor:
|
||||||
RelatedObjectDoesNotExist: Type[django.core.exceptions.ObjectDoesNotExist]
|
RelatedObjectDoesNotExist: Type[ObjectDoesNotExist]
|
||||||
field: django.db.models.fields.related.ForeignObject = ...
|
field: ForeignObject = ...
|
||||||
|
|
||||||
def __init__(self, field_with_rel: ForeignObject) -> None: ...
|
def __init__(self, field_with_rel: ForeignObject) -> None: ...
|
||||||
def RelatedObjectDoesNotExist(self) -> Type[ObjectDoesNotExist]: ...
|
|
||||||
def is_cached(self, instance: Model) -> bool: ...
|
def is_cached(self, instance: Model) -> bool: ...
|
||||||
|
|
||||||
def get_queryset(self, **hints: Any) -> QuerySet: ...
|
def get_queryset(self, **hints: Any) -> QuerySet: ...
|
||||||
|
|
||||||
def get_prefetch_queryset(
|
def get_prefetch_queryset(
|
||||||
self, instances: List[Model], queryset: Optional[QuerySet] = ...
|
self, instances: List[Model], queryset: Optional[QuerySet] = ...
|
||||||
) -> Tuple[QuerySet, Callable, Callable, bool, str, bool]: ...
|
) -> Tuple[QuerySet, Callable, Callable, bool, str, bool]: ...
|
||||||
|
|
||||||
def get_object(self, instance: Model) -> Model: ...
|
def get_object(self, instance: Model) -> Model: ...
|
||||||
|
|
||||||
def __get__(
|
def __get__(
|
||||||
self, instance: Optional[Model], cls: Type[Model] = ...
|
self, instance: Optional[Model], cls: Type[Model] = ...
|
||||||
) -> Optional[Union[Model, ForwardManyToOneDescriptor]]: ...
|
) -> Optional[Union[Model, ForwardManyToOneDescriptor]]: ...
|
||||||
|
|
||||||
def __set__(
|
def __set__(
|
||||||
self, instance: Model, value: Optional[Union[Model, F]]
|
self, instance: Model, value: Optional[Union[Model, F]]
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
def __reduce__(self) -> Tuple[Callable, Tuple[Type[Model], str]]: ...
|
def __reduce__(self) -> Tuple[Callable, Tuple[Type[Model], str]]: ...
|
||||||
|
|
||||||
|
|
||||||
class ForwardOneToOneDescriptor(ForwardManyToOneDescriptor):
|
class ForwardOneToOneDescriptor(ForwardManyToOneDescriptor):
|
||||||
RelatedObjectDoesNotExist: Type[django.core.exceptions.ObjectDoesNotExist]
|
RelatedObjectDoesNotExist: Type[ObjectDoesNotExist]
|
||||||
field: django.db.models.fields.related.OneToOneField
|
field: OneToOneField
|
||||||
|
|
||||||
def get_object(self, instance: Model) -> Model: ...
|
def get_object(self, instance: Model) -> Model: ...
|
||||||
|
|
||||||
def __set__(self, instance: Model, value: Optional[Model]) -> None: ...
|
def __set__(self, instance: Model, value: Optional[Model]) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class ReverseOneToOneDescriptor:
|
class ReverseOneToOneDescriptor:
|
||||||
RelatedObjectDoesNotExist: Type[django.core.exceptions.ObjectDoesNotExist]
|
RelatedObjectDoesNotExist: Type[ObjectDoesNotExist]
|
||||||
related: django.db.models.fields.reverse_related.OneToOneRel = ...
|
related: OneToOneRel = ...
|
||||||
|
|
||||||
def __init__(self, related: OneToOneRel) -> None: ...
|
def __init__(self, related: OneToOneRel) -> None: ...
|
||||||
def RelatedObjectDoesNotExist(self) -> Type[ObjectDoesNotExist]: ...
|
|
||||||
def is_cached(self, instance: Model) -> bool: ...
|
def is_cached(self, instance: Model) -> bool: ...
|
||||||
|
|
||||||
def get_queryset(self, **hints: Any) -> QuerySet: ...
|
def get_queryset(self, **hints: Any) -> QuerySet: ...
|
||||||
|
|
||||||
def get_prefetch_queryset(
|
def get_prefetch_queryset(
|
||||||
self, instances: List[Model], queryset: Optional[QuerySet] = ...
|
self, instances: List[Model], queryset: Optional[QuerySet] = ...
|
||||||
) -> Tuple[QuerySet, Callable, Callable, bool, str, bool]: ...
|
) -> Tuple[QuerySet, Callable, Callable, bool, str, bool]: ...
|
||||||
|
|
||||||
def __get__(
|
def __get__(
|
||||||
self, instance: Optional[Model], cls: Type[Model] = ...
|
self, instance: Optional[Model], cls: Type[Model] = ...
|
||||||
) -> Union[Model, ReverseOneToOneDescriptor]: ...
|
) -> Union[Model, ReverseOneToOneDescriptor]: ...
|
||||||
|
|
||||||
def __set__(self, instance: Model, value: Optional[Model]) -> None: ...
|
def __set__(self, instance: Model, value: Optional[Model]) -> None: ...
|
||||||
|
|
||||||
def __reduce__(self) -> Tuple[Callable, Tuple[Type[Model], str]]: ...
|
def __reduce__(self) -> Tuple[Callable, Tuple[Type[Model], str]]: ...
|
||||||
|
|
||||||
|
|
||||||
class ReverseManyToOneDescriptor:
|
class ReverseManyToOneDescriptor:
|
||||||
rel: django.db.models.fields.mixins.FieldCacheMixin = ...
|
rel: FieldCacheMixin = ...
|
||||||
field: django.db.models.fields.mixins.FieldCacheMixin = ...
|
field: FieldCacheMixin = ...
|
||||||
|
|
||||||
def __init__(self, rel: FieldCacheMixin) -> None: ...
|
def __init__(self, rel: FieldCacheMixin) -> None: ...
|
||||||
|
|
||||||
def related_manager_cls(self): ...
|
def related_manager_cls(self): ...
|
||||||
|
|
||||||
def __get__(
|
def __get__(
|
||||||
self, instance: Optional[Model], cls: Type[Model] = ...
|
self, instance: Optional[Model], cls: Type[Model] = ...
|
||||||
) -> ReverseManyToOneDescriptor: ...
|
) -> ReverseManyToOneDescriptor: ...
|
||||||
|
|
||||||
def __set__(self, instance: Model, value: List[Model]) -> Any: ...
|
def __set__(self, instance: Model, value: List[Model]) -> Any: ...
|
||||||
|
|
||||||
|
|
||||||
def create_reverse_many_to_one_manager(superclass: Any, rel: Any): ...
|
def create_reverse_many_to_one_manager(superclass: Any, rel: Any): ...
|
||||||
|
|
||||||
|
|
||||||
class ManyToManyDescriptor(ReverseManyToOneDescriptor):
|
class ManyToManyDescriptor(ReverseManyToOneDescriptor):
|
||||||
field: django.db.models.fields.related.RelatedField
|
field: RelatedField
|
||||||
rel: django.db.models.fields.reverse_related.ManyToManyRel
|
rel: ManyToManyRel
|
||||||
reverse: bool = ...
|
reverse: bool = ...
|
||||||
|
|
||||||
def __init__(self, rel: ManyToManyRel, reverse: bool = ...) -> None: ...
|
def __init__(self, rel: ManyToManyRel, reverse: bool = ...) -> None: ...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def through(self) -> Type[Model]: ...
|
def through(self) -> Type[Model]: ...
|
||||||
|
|
||||||
def related_manager_cls(self): ...
|
def related_manager_cls(self): ...
|
||||||
|
|
||||||
|
|
||||||
|
class _ForwardManyToManyManager(Generic[_T]):
|
||||||
|
def all(self) -> QuerySet: ...
|
||||||
|
|
||||||
|
|
||||||
def create_forward_many_to_many_manager(
|
def create_forward_many_to_many_manager(
|
||||||
superclass: Any, rel: Any, reverse: Any
|
superclass: Any, rel: Any, reverse: Any
|
||||||
): ...
|
) -> _ForwardManyToManyManager: ...
|
||||||
|
|||||||
@@ -138,12 +138,7 @@ class QuerySet(Generic[_T]):
|
|||||||
|
|
||||||
def get(
|
def get(
|
||||||
self, *args: Any, **kwargs: Any
|
self, *args: Any, **kwargs: Any
|
||||||
) -> Union[
|
) -> _T: ...
|
||||||
Dict[str, Union[date, Decimal, float, str]],
|
|
||||||
Tuple[Union[Decimal, str]],
|
|
||||||
Model,
|
|
||||||
str,
|
|
||||||
]: ...
|
|
||||||
|
|
||||||
def create(self, **kwargs: Any) -> _T: ...
|
def create(self, **kwargs: Any) -> _T: ...
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
from django.core.cache import BaseCache
|
||||||
from django.http.request import HttpRequest
|
from django.http.request import HttpRequest
|
||||||
from django.http.response import HttpResponse, HttpResponseBase
|
from django.http.response import HttpResponse, HttpResponseBase
|
||||||
from django.utils.deprecation import MiddlewareMixin
|
from django.utils.deprecation import MiddlewareMixin
|
||||||
@@ -9,32 +10,39 @@ class UpdateCacheMiddleware(MiddlewareMixin):
|
|||||||
cache_timeout: float = ...
|
cache_timeout: float = ...
|
||||||
key_prefix: str = ...
|
key_prefix: str = ...
|
||||||
cache_alias: str = ...
|
cache_alias: str = ...
|
||||||
cache: django.core.cache.backends.base.BaseCache = ...
|
cache: BaseCache = ...
|
||||||
get_response: Optional[Callable] = ...
|
get_response: Optional[Callable] = ...
|
||||||
|
|
||||||
def __init__(self, get_response: Optional[Callable] = ...) -> None: ...
|
def __init__(self, get_response: Optional[Callable] = ...) -> None: ...
|
||||||
|
|
||||||
def process_response(
|
def process_response(
|
||||||
self, request: HttpRequest, response: Union[HttpResponseBase, str]
|
self, request: HttpRequest, response: Union[HttpResponseBase, str]
|
||||||
) -> Union[HttpResponseBase, str]: ...
|
) -> Union[HttpResponseBase, str]: ...
|
||||||
|
|
||||||
|
|
||||||
class FetchFromCacheMiddleware(MiddlewareMixin):
|
class FetchFromCacheMiddleware(MiddlewareMixin):
|
||||||
key_prefix: str = ...
|
key_prefix: str = ...
|
||||||
cache_alias: str = ...
|
cache_alias: str = ...
|
||||||
cache: django.core.cache.backends.base.BaseCache = ...
|
cache: BaseCache = ...
|
||||||
get_response: Optional[Callable] = ...
|
get_response: Optional[Callable] = ...
|
||||||
|
|
||||||
def __init__(self, get_response: Optional[Callable] = ...) -> None: ...
|
def __init__(self, get_response: Optional[Callable] = ...) -> None: ...
|
||||||
|
|
||||||
def process_request(
|
def process_request(
|
||||||
self, request: HttpRequest
|
self, request: HttpRequest
|
||||||
) -> Optional[HttpResponse]: ...
|
) -> Optional[HttpResponse]: ...
|
||||||
|
|
||||||
|
|
||||||
class CacheMiddleware(UpdateCacheMiddleware, FetchFromCacheMiddleware):
|
class CacheMiddleware(UpdateCacheMiddleware, FetchFromCacheMiddleware):
|
||||||
get_response: None = ...
|
get_response: None = ...
|
||||||
key_prefix: str = ...
|
key_prefix: str = ...
|
||||||
cache_alias: str = ...
|
cache_alias: str = ...
|
||||||
cache_timeout: float = ...
|
cache_timeout: float = ...
|
||||||
cache: django.core.cache.backends.locmem.LocMemCache = ...
|
cache: BaseCache = ...
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
get_response: None = ...,
|
get_response: None = ...,
|
||||||
cache_timeout: Optional[float] = ...,
|
cache_timeout: Optional[float] = ...,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|||||||
@@ -3,3 +3,7 @@ from .testcases import (
|
|||||||
TransactionTestCase as TransactionTestCase,
|
TransactionTestCase as TransactionTestCase,
|
||||||
SimpleTestCase as SimpleTestCase
|
SimpleTestCase as SimpleTestCase
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
override_settings as override_settings
|
||||||
|
)
|
||||||
@@ -16,7 +16,8 @@ def lookup_django_model(mypy_api: TypeChecker, fullname: str) -> SymbolTableNode
|
|||||||
try:
|
try:
|
||||||
return mypy_api.modules[module].names[model_name]
|
return mypy_api.modules[module].names[model_name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return mypy_api.modules['django.db.models'].names['Model']
|
return mypy_api.lookup_qualified('typing.Any')
|
||||||
|
# return mypy_api.modules['typing'].names['Any']
|
||||||
|
|
||||||
|
|
||||||
def get_app_model(model_name: str) -> str:
|
def get_app_model(model_name: str) -> str:
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
from mypy.nodes import AssignmentStmt, CallExpr, RefExpr, StrExpr
|
|
||||||
from mypy.plugin import Plugin, ClassDefContext
|
|
||||||
|
|
||||||
from mypy_django_plugin.helpers import get_app_model
|
|
||||||
from mypy_django_plugin.model_classes import DjangoModelsRegistry
|
|
||||||
|
|
||||||
|
|
||||||
# fields which real type is inside to= expression
|
|
||||||
REFERENCING_DB_FIELDS = {
|
|
||||||
'django.db.models.fields.related.ForeignKey',
|
|
||||||
'django.db.models.fields.related.OneToOneField'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None:
|
|
||||||
to_arg_value = rvalue.args[rvalue.arg_names.index('to')]
|
|
||||||
if isinstance(to_arg_value, StrExpr):
|
|
||||||
referred_model_fullname = get_app_model(to_arg_value.value)
|
|
||||||
else:
|
|
||||||
referred_model_fullname = to_arg_value.fullname
|
|
||||||
|
|
||||||
rvalue.callee.node.metadata['base'] = referred_model_fullname
|
|
||||||
|
|
||||||
|
|
||||||
class CollectModelsInformation(object):
|
|
||||||
def __init__(self, model_registry: DjangoModelsRegistry):
|
|
||||||
self.model_registry = model_registry
|
|
||||||
|
|
||||||
def __call__(self, model_definition: ClassDefContext) -> None:
|
|
||||||
self.model_registry.base_models.add(model_definition.cls.fullname)
|
|
||||||
|
|
||||||
for member in model_definition.cls.defs.body:
|
|
||||||
if isinstance(member, AssignmentStmt):
|
|
||||||
if len(member.lvalues) > 1:
|
|
||||||
return None
|
|
||||||
|
|
||||||
arg_name = member.lvalues[0].name
|
|
||||||
arg_name_as_id = arg_name + '_id'
|
|
||||||
|
|
||||||
rvalue = member.rvalue
|
|
||||||
if isinstance(rvalue, CallExpr):
|
|
||||||
if not isinstance(rvalue.callee, RefExpr):
|
|
||||||
return None
|
|
||||||
|
|
||||||
if rvalue.callee.fullname in REFERENCING_DB_FIELDS:
|
|
||||||
if rvalue.callee.fullname == 'django.db.models.fields.related.ForeignKey':
|
|
||||||
model_definition.cls.info.names[arg_name_as_id] = \
|
|
||||||
model_definition.api.lookup_fully_qualified('builtins.int')
|
|
||||||
|
|
||||||
if rvalue.callee.fullname == 'django.db.models.fields.related.OneToOneField':
|
|
||||||
if 'related_name' in rvalue.arg_names:
|
|
||||||
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
|
|
||||||
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
|
|
||||||
|
|
||||||
if isinstance(referred_to_model, StrExpr):
|
|
||||||
referred_model_fullname = get_app_model(referred_to_model.value)
|
|
||||||
else:
|
|
||||||
referred_model_fullname = referred_to_model.fullname
|
|
||||||
|
|
||||||
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
|
|
||||||
referred_model.node.names[related_arg_value] = \
|
|
||||||
model_definition.api.lookup_fully_qualified_or_none(model_definition.cls.fullname)
|
|
||||||
|
|
||||||
return save_referred_to_model_in_metadata(rvalue)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDjangoModelsPlugin(Plugin):
|
|
||||||
model_registry = DjangoModelsRegistry()
|
|
||||||
|
|
||||||
def get_base_class_hook(self, fullname: str
|
|
||||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
|
||||||
if fullname in self.model_registry:
|
|
||||||
return CollectModelsInformation(self.model_registry)
|
|
||||||
|
|
||||||
return None
|
|
||||||
164
mypy_django_plugin/plugins/callbacks.py
Normal file
164
mypy_django_plugin/plugins/callbacks.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
import dataclasses
|
||||||
|
from mypy import types, nodes
|
||||||
|
from mypy.nodes import CallExpr, StrExpr, AssignmentStmt, RefExpr
|
||||||
|
from mypy.plugin import AttributeContext, ClassDefContext, SemanticAnalyzerPluginInterface
|
||||||
|
from mypy.types import Type, Instance, AnyType, TypeOfAny
|
||||||
|
|
||||||
|
from mypy_django_plugin.helpers import lookup_django_model, get_app_model
|
||||||
|
from mypy_django_plugin.model_classes import DjangoModelsRegistry
|
||||||
|
|
||||||
|
# mapping between field types and plain python types
|
||||||
|
DB_FIELDS_TO_TYPES = {
|
||||||
|
'django.db.models.fields.CharField': 'builtins.str',
|
||||||
|
'django.db.models.fields.TextField': 'builtins.str',
|
||||||
|
'django.db.models.fields.BooleanField': 'builtins.bool',
|
||||||
|
# 'django.db.models.fields.NullBooleanField': 'typing.Optional[builtins.bool]',
|
||||||
|
'django.db.models.fields.IntegerField': 'builtins.int',
|
||||||
|
'django.db.models.fields.AutoField': 'builtins.int',
|
||||||
|
'django.db.models.fields.FloatField': 'builtins.float',
|
||||||
|
'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict',
|
||||||
|
'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# def get_queryset_of(type_fullname: str):
|
||||||
|
# return model_definition.api.lookup_fully_qualified_or_none('django.db.models.QuerySet')
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class DjangoPluginApi(object):
|
||||||
|
mypy_api: SemanticAnalyzerPluginInterface
|
||||||
|
|
||||||
|
def get_queryset_of(self, type_fullname: str = 'django.db.models.base.Model') -> types.Type:
|
||||||
|
queryset_sym = self.mypy_api.lookup_fully_qualified_or_none('django.db.models.QuerySet')
|
||||||
|
if not queryset_sym:
|
||||||
|
return AnyType(TypeOfAny.from_error)
|
||||||
|
|
||||||
|
generic_arg = self.mypy_api.lookup_fully_qualified_or_none(type_fullname)
|
||||||
|
if not generic_arg:
|
||||||
|
return Instance(queryset_sym.node, [AnyType(TypeOfAny.from_error)])
|
||||||
|
|
||||||
|
return Instance(queryset_sym.node, [Instance(generic_arg.node, [])])
|
||||||
|
|
||||||
|
def generate_related_manager_assignment_stmt(self,
|
||||||
|
related_mngr_name: str,
|
||||||
|
queryset_argument_type_fullname: str) -> nodes.AssignmentStmt:
|
||||||
|
rvalue = nodes.TempNode(AnyType(TypeOfAny.special_form))
|
||||||
|
assignment = nodes.AssignmentStmt(lvalues=[nodes.NameExpr(related_mngr_name)],
|
||||||
|
rvalue=rvalue,
|
||||||
|
new_syntax=True,
|
||||||
|
type=self.get_queryset_of(queryset_argument_type_fullname))
|
||||||
|
return assignment
|
||||||
|
|
||||||
|
|
||||||
|
class DetermineFieldPythonTypeCallback(object):
|
||||||
|
def __init__(self, models_registry: DjangoModelsRegistry):
|
||||||
|
self.models_registry = models_registry
|
||||||
|
|
||||||
|
def __call__(self, attr_context: AttributeContext) -> Type:
|
||||||
|
default_attr_type = attr_context.default_attr_type
|
||||||
|
|
||||||
|
if isinstance(default_attr_type, Instance):
|
||||||
|
attr_type_fullname = default_attr_type.type.fullname()
|
||||||
|
if attr_type_fullname in DB_FIELDS_TO_TYPES:
|
||||||
|
return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname])
|
||||||
|
|
||||||
|
# if 'base' in default_attr_type.type.metadata:
|
||||||
|
# referred_base_model = default_attr_type.type.metadata['base']
|
||||||
|
|
||||||
|
if 'members' in attr_context.type.type.metadata:
|
||||||
|
arg_name = attr_context.context.name
|
||||||
|
if arg_name in attr_context.type.type.metadata['members']:
|
||||||
|
referred_base_model = attr_context.type.type.metadata['members'][arg_name]
|
||||||
|
|
||||||
|
typ = lookup_django_model(attr_context.api, referred_base_model)
|
||||||
|
try:
|
||||||
|
return Instance(typ.node, [])
|
||||||
|
except AssertionError as e:
|
||||||
|
return typ.type
|
||||||
|
|
||||||
|
return default_attr_type
|
||||||
|
|
||||||
|
|
||||||
|
# fields which real type is inside to= expression
|
||||||
|
REFERENCING_DB_FIELDS = {
|
||||||
|
'django.db.models.fields.related.ForeignKey',
|
||||||
|
'django.db.models.fields.related.OneToOneField'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None:
|
||||||
|
to_arg_value = rvalue.args[rvalue.arg_names.index('to')]
|
||||||
|
if isinstance(to_arg_value, StrExpr):
|
||||||
|
referred_model_fullname = get_app_model(to_arg_value.value)
|
||||||
|
else:
|
||||||
|
referred_model_fullname = to_arg_value.fullname
|
||||||
|
|
||||||
|
rvalue.callee.node.metadata['base'] = referred_model_fullname
|
||||||
|
|
||||||
|
|
||||||
|
class CollectModelsInformationCallback(object):
|
||||||
|
def __init__(self, model_registry: DjangoModelsRegistry):
|
||||||
|
self.model_registry = model_registry
|
||||||
|
|
||||||
|
def __call__(self, model_definition: ClassDefContext) -> None:
|
||||||
|
self.model_registry.base_models.add(model_definition.cls.fullname)
|
||||||
|
plugin_api = DjangoPluginApi(mypy_api=model_definition.api)
|
||||||
|
|
||||||
|
for member in model_definition.cls.defs.body:
|
||||||
|
if isinstance(member, AssignmentStmt):
|
||||||
|
if len(member.lvalues) > 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
arg_name = member.lvalues[0].name
|
||||||
|
arg_name_as_id = arg_name + '_id'
|
||||||
|
|
||||||
|
rvalue = member.rvalue
|
||||||
|
if isinstance(rvalue, CallExpr):
|
||||||
|
if not isinstance(rvalue.callee, RefExpr):
|
||||||
|
return None
|
||||||
|
|
||||||
|
if rvalue.callee.fullname in REFERENCING_DB_FIELDS:
|
||||||
|
if rvalue.callee.fullname == 'django.db.models.fields.related.ForeignKey':
|
||||||
|
model_definition.cls.info.names[arg_name_as_id] = \
|
||||||
|
model_definition.api.lookup_fully_qualified('builtins.int')
|
||||||
|
|
||||||
|
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
|
||||||
|
if isinstance(referred_to_model, StrExpr):
|
||||||
|
referred_model_fullname = get_app_model(referred_to_model.value)
|
||||||
|
else:
|
||||||
|
referred_model_fullname = referred_to_model.fullname
|
||||||
|
|
||||||
|
rvalue.callee.node.metadata['base'] = referred_model_fullname
|
||||||
|
|
||||||
|
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
|
||||||
|
|
||||||
|
if 'related_name' in rvalue.arg_names:
|
||||||
|
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
|
||||||
|
referred_model_class_def = referred_model.node.defn # type: nodes.ClassDef
|
||||||
|
|
||||||
|
referred_model_class_def.defs.body.append(
|
||||||
|
plugin_api.generate_related_manager_assignment_stmt(related_arg_value,
|
||||||
|
model_definition.cls.fullname))
|
||||||
|
|
||||||
|
if rvalue.callee.fullname == 'django.db.models.fields.related.OneToOneField':
|
||||||
|
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
|
||||||
|
if isinstance(referred_to_model, StrExpr):
|
||||||
|
referred_model_fullname = get_app_model(referred_to_model.value)
|
||||||
|
else:
|
||||||
|
referred_model_fullname = referred_to_model.fullname
|
||||||
|
|
||||||
|
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
|
||||||
|
|
||||||
|
if 'related_name' in rvalue.arg_names:
|
||||||
|
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
|
||||||
|
referred_model.node.names[related_arg_value] = \
|
||||||
|
model_definition.api.lookup_fully_qualified_or_none(model_definition.cls.fullname)
|
||||||
|
|
||||||
|
rvalue.callee.node.metadata['base'] = referred_model_fullname
|
||||||
|
|
||||||
|
rvalue.callee.node.metadata['name'] = arg_name
|
||||||
|
if 'members' not in model_definition.cls.info.metadata:
|
||||||
|
model_definition.cls.info.metadata['members'] = {}
|
||||||
|
|
||||||
|
model_definition.cls.info.metadata['members'][arg_name] = rvalue.callee.node.metadata.get('base', None)
|
||||||
@@ -1,54 +1,24 @@
|
|||||||
from typing import Optional, Callable
|
from typing import Optional, Callable, Type
|
||||||
|
|
||||||
from mypy.plugin import AttributeContext
|
from mypy.plugin import Plugin, ClassDefContext, AttributeContext
|
||||||
from mypy.types import Type, Instance
|
|
||||||
|
|
||||||
from mypy_django_plugin.helpers import lookup_django_model
|
|
||||||
from mypy_django_plugin.model_classes import DjangoModelsRegistry
|
from mypy_django_plugin.model_classes import DjangoModelsRegistry
|
||||||
from mypy_django_plugin.plugins.base import BaseDjangoModelsPlugin
|
from mypy_django_plugin.plugins.callbacks import CollectModelsInformationCallback, DetermineFieldPythonTypeCallback
|
||||||
|
|
||||||
# mapping between field types and plain python types
|
|
||||||
DB_FIELDS_TO_TYPES = {
|
|
||||||
'django.db.models.fields.CharField': 'builtins.str',
|
|
||||||
'django.db.models.fields.TextField': 'builtins.str',
|
|
||||||
'django.db.models.fields.BooleanField': 'builtins.bool',
|
|
||||||
# 'django.db.models.fields.NullBooleanField': 'typing.Optional[builtins.bool]',
|
|
||||||
'django.db.models.fields.IntegerField': 'builtins.int',
|
|
||||||
'django.db.models.fields.AutoField': 'builtins.int',
|
|
||||||
'django.db.models.fields.FloatField': 'builtins.float',
|
|
||||||
'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict',
|
|
||||||
'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DetermineFieldPythonTypeCallback(object):
|
class FieldToPythonTypePlugin(Plugin):
|
||||||
def __init__(self, models_registry: DjangoModelsRegistry):
|
model_registry = DjangoModelsRegistry()
|
||||||
self.models_registry = models_registry
|
|
||||||
|
|
||||||
def __call__(self, attr_context: AttributeContext) -> Type:
|
def get_base_class_hook(self, fullname: str
|
||||||
default_attr_type = attr_context.default_attr_type
|
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||||
|
if fullname in self.model_registry:
|
||||||
|
return CollectModelsInformationCallback(self.model_registry)
|
||||||
|
|
||||||
if isinstance(default_attr_type, Instance):
|
return None
|
||||||
attr_type_fullname = default_attr_type.type.fullname()
|
|
||||||
if attr_type_fullname in DB_FIELDS_TO_TYPES:
|
|
||||||
return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname])
|
|
||||||
|
|
||||||
if 'base' in default_attr_type.type.metadata:
|
|
||||||
referred_base_model = default_attr_type.type.metadata['base']
|
|
||||||
try:
|
|
||||||
node = lookup_django_model(attr_context.api, referred_base_model).node
|
|
||||||
return Instance(node, [])
|
|
||||||
except AssertionError as e:
|
|
||||||
print(e)
|
|
||||||
print('name to lookup:', referred_base_model)
|
|
||||||
pass
|
|
||||||
|
|
||||||
return default_attr_type
|
|
||||||
|
|
||||||
|
|
||||||
class FieldToPythonTypePlugin(BaseDjangoModelsPlugin):
|
|
||||||
def get_attribute_hook(self, fullname: str
|
def get_attribute_hook(self, fullname: str
|
||||||
) -> Optional[Callable[[AttributeContext], Type]]:
|
) -> Optional[Callable[[AttributeContext], Type]]:
|
||||||
|
# print(fullname)
|
||||||
classname, _, attrname = fullname.rpartition('.')
|
classname, _, attrname = fullname.rpartition('.')
|
||||||
if classname and classname in self.model_registry:
|
if classname and classname in self.model_registry:
|
||||||
return DetermineFieldPythonTypeCallback(self.model_registry)
|
return DetermineFieldPythonTypeCallback(self.model_registry)
|
||||||
|
|||||||
Reference in New Issue
Block a user