more fixes for django stubs, first attempt for a plugin

This commit is contained in:
Maxim Kurnikov
2018-10-12 01:56:25 +03:00
parent b93f589cff
commit 2cdefc4662
11 changed files with 273 additions and 165 deletions

View File

@@ -1,5 +1,6 @@
from typing import Callable, Optional, Type
from django.contrib.sessions.backends.base import SessionBase
from django.core.handlers.wsgi import WSGIRequest
from django.http.request import HttpRequest
from django.http.response import HttpResponseBase
@@ -8,7 +9,7 @@ from django.utils.deprecation import MiddlewareMixin
class SessionMiddleware(MiddlewareMixin):
get_response: Callable[[WSGIRequest], HttpResponseBase] = ...
SessionStore: Type[SessionStore] = ...
SessionStore: Type[SessionBase] = ...
def __init__(self, get_response: Optional[Callable] = ...) -> None: ...

View File

@@ -1,25 +1,32 @@
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Union
from django.core.cache.backends.base import BaseCache as BaseCache
from django.core.cache.backends.base import CacheKeyWarning as CacheKeyWarning
from django.core.cache.backends.base import \
InvalidCacheBackendError as InvalidCacheBackendError
DEFAULT_CACHE_ALIAS: str
class CacheHandler:
def __init__(self) -> None: ...
def __getitem__(self, alias: str) -> BaseCache: ...
def all(self): ...
class DefaultCacheProxy:
def __getattr__(
self, name: str
self, name: str
) -> Union[Callable, Dict[str, float], OrderedDict, int]: ...
def __setattr__(self, name: str, value: Callable) -> None: ...
def __delattr__(self, name: Any): ...
def __contains__(self, key: str) -> bool: ...
def __eq__(self, other: Any): ...
cache: Any
caches: CacheHandler

View File

@@ -1,3 +1,4 @@
from io import BufferedReader, StringIO
from typing import Any, Iterator, Optional, Union
from django.core.files.utils import FileProxyMixin
@@ -5,7 +6,7 @@ from django.core.files.utils import FileProxyMixin
class File(FileProxyMixin):
DEFAULT_CHUNK_SIZE: Any = ...
file: _io.BufferedReader = ...
file: BufferedReader = ...
name: str = ...
mode: str = ...
def __init__(self, file: Any, name: Optional[str] = ...) -> None: ...
@@ -14,7 +15,7 @@ class File(FileProxyMixin):
def size(self) -> int: ...
def chunks(
self, chunk_size: Optional[int] = ...
) -> Iterator[Union[bytes, str]]: ...
) -> Iterator[Union[bytes, bytearray]]: ...
def multiple_chunks(self, chunk_size: Optional[Any] = ...): ...
def __iter__(self) -> Iterator[Union[bytes, str]]: ...
def __enter__(self) -> File: ...
@@ -23,7 +24,7 @@ class File(FileProxyMixin):
def close(self) -> None: ...
class ContentFile(File):
file: _io.StringIO
file: StringIO
name: None
size: Any = ...
def __init__(

View File

@@ -1,76 +1,110 @@
from typing import Any, Callable, List, Optional, Tuple, Type, Union
from typing import Any, Callable, List, Optional, Tuple, Type, Union, Generic, TypeVar
from django.core.exceptions import ObjectDoesNotExist
from django.db.models.base import Model
from django.db.models.expressions import F
from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.fields.related import ForeignObject
from django.db.models.fields.related import ForeignObject, RelatedField, OneToOneField
from django.db.models.fields.reverse_related import ManyToManyRel, OneToOneRel
from django.db.models.query import QuerySet
_T = TypeVar('_T')
class ForwardManyToOneDescriptor:
RelatedObjectDoesNotExist: Type[django.core.exceptions.ObjectDoesNotExist]
field: django.db.models.fields.related.ForeignObject = ...
RelatedObjectDoesNotExist: Type[ObjectDoesNotExist]
field: ForeignObject = ...
def __init__(self, field_with_rel: ForeignObject) -> None: ...
def RelatedObjectDoesNotExist(self) -> Type[ObjectDoesNotExist]: ...
def is_cached(self, instance: Model) -> bool: ...
def get_queryset(self, **hints: Any) -> QuerySet: ...
def get_prefetch_queryset(
self, instances: List[Model], queryset: Optional[QuerySet] = ...
self, instances: List[Model], queryset: Optional[QuerySet] = ...
) -> Tuple[QuerySet, Callable, Callable, bool, str, bool]: ...
def get_object(self, instance: Model) -> Model: ...
def __get__(
self, instance: Optional[Model], cls: Type[Model] = ...
self, instance: Optional[Model], cls: Type[Model] = ...
) -> Optional[Union[Model, ForwardManyToOneDescriptor]]: ...
def __set__(
self, instance: Model, value: Optional[Union[Model, F]]
self, instance: Model, value: Optional[Union[Model, F]]
) -> None: ...
def __reduce__(self) -> Tuple[Callable, Tuple[Type[Model], str]]: ...
class ForwardOneToOneDescriptor(ForwardManyToOneDescriptor):
RelatedObjectDoesNotExist: Type[django.core.exceptions.ObjectDoesNotExist]
field: django.db.models.fields.related.OneToOneField
RelatedObjectDoesNotExist: Type[ObjectDoesNotExist]
field: OneToOneField
def get_object(self, instance: Model) -> Model: ...
def __set__(self, instance: Model, value: Optional[Model]) -> None: ...
class ReverseOneToOneDescriptor:
RelatedObjectDoesNotExist: Type[django.core.exceptions.ObjectDoesNotExist]
related: django.db.models.fields.reverse_related.OneToOneRel = ...
RelatedObjectDoesNotExist: Type[ObjectDoesNotExist]
related: OneToOneRel = ...
def __init__(self, related: OneToOneRel) -> None: ...
def RelatedObjectDoesNotExist(self) -> Type[ObjectDoesNotExist]: ...
def is_cached(self, instance: Model) -> bool: ...
def get_queryset(self, **hints: Any) -> QuerySet: ...
def get_prefetch_queryset(
self, instances: List[Model], queryset: Optional[QuerySet] = ...
self, instances: List[Model], queryset: Optional[QuerySet] = ...
) -> Tuple[QuerySet, Callable, Callable, bool, str, bool]: ...
def __get__(
self, instance: Optional[Model], cls: Type[Model] = ...
self, instance: Optional[Model], cls: Type[Model] = ...
) -> Union[Model, ReverseOneToOneDescriptor]: ...
def __set__(self, instance: Model, value: Optional[Model]) -> None: ...
def __reduce__(self) -> Tuple[Callable, Tuple[Type[Model], str]]: ...
class ReverseManyToOneDescriptor:
rel: django.db.models.fields.mixins.FieldCacheMixin = ...
field: django.db.models.fields.mixins.FieldCacheMixin = ...
rel: FieldCacheMixin = ...
field: FieldCacheMixin = ...
def __init__(self, rel: FieldCacheMixin) -> None: ...
def related_manager_cls(self): ...
def __get__(
self, instance: Optional[Model], cls: Type[Model] = ...
self, instance: Optional[Model], cls: Type[Model] = ...
) -> ReverseManyToOneDescriptor: ...
def __set__(self, instance: Model, value: List[Model]) -> Any: ...
def create_reverse_many_to_one_manager(superclass: Any, rel: Any): ...
class ManyToManyDescriptor(ReverseManyToOneDescriptor):
field: django.db.models.fields.related.RelatedField
rel: django.db.models.fields.reverse_related.ManyToManyRel
field: RelatedField
rel: ManyToManyRel
reverse: bool = ...
def __init__(self, rel: ManyToManyRel, reverse: bool = ...) -> None: ...
@property
def through(self) -> Type[Model]: ...
def related_manager_cls(self): ...
class _ForwardManyToManyManager(Generic[_T]):
def all(self) -> QuerySet: ...
def create_forward_many_to_many_manager(
superclass: Any, rel: Any, reverse: Any
): ...
superclass: Any, rel: Any, reverse: Any
) -> _ForwardManyToManyManager: ...

View File

@@ -138,12 +138,7 @@ class QuerySet(Generic[_T]):
def get(
self, *args: Any, **kwargs: Any
) -> Union[
Dict[str, Union[date, Decimal, float, str]],
Tuple[Union[Decimal, str]],
Model,
str,
]: ...
) -> _T: ...
def create(self, **kwargs: Any) -> _T: ...

View File

@@ -1,5 +1,6 @@
from typing import Any, Callable, Optional, Union
from django.core.cache import BaseCache
from django.http.request import HttpRequest
from django.http.response import HttpResponse, HttpResponseBase
from django.utils.deprecation import MiddlewareMixin
@@ -9,32 +10,39 @@ class UpdateCacheMiddleware(MiddlewareMixin):
cache_timeout: float = ...
key_prefix: str = ...
cache_alias: str = ...
cache: django.core.cache.backends.base.BaseCache = ...
cache: BaseCache = ...
get_response: Optional[Callable] = ...
def __init__(self, get_response: Optional[Callable] = ...) -> None: ...
def process_response(
self, request: HttpRequest, response: Union[HttpResponseBase, str]
self, request: HttpRequest, response: Union[HttpResponseBase, str]
) -> Union[HttpResponseBase, str]: ...
class FetchFromCacheMiddleware(MiddlewareMixin):
key_prefix: str = ...
cache_alias: str = ...
cache: django.core.cache.backends.base.BaseCache = ...
cache: BaseCache = ...
get_response: Optional[Callable] = ...
def __init__(self, get_response: Optional[Callable] = ...) -> None: ...
def process_request(
self, request: HttpRequest
self, request: HttpRequest
) -> Optional[HttpResponse]: ...
class CacheMiddleware(UpdateCacheMiddleware, FetchFromCacheMiddleware):
get_response: None = ...
key_prefix: str = ...
cache_alias: str = ...
cache_timeout: float = ...
cache: django.core.cache.backends.locmem.LocMemCache = ...
cache: BaseCache = ...
def __init__(
self,
get_response: None = ...,
cache_timeout: Optional[float] = ...,
**kwargs: Any
self,
get_response: None = ...,
cache_timeout: Optional[float] = ...,
**kwargs: Any
) -> None: ...

View File

@@ -2,4 +2,8 @@ from .testcases import (
TestCase as TestCase,
TransactionTestCase as TransactionTestCase,
SimpleTestCase as SimpleTestCase
)
from .utils import (
override_settings as override_settings
)

View File

@@ -16,7 +16,8 @@ def lookup_django_model(mypy_api: TypeChecker, fullname: str) -> SymbolTableNode
try:
return mypy_api.modules[module].names[model_name]
except KeyError:
return mypy_api.modules['django.db.models'].names['Model']
return mypy_api.lookup_qualified('typing.Any')
# return mypy_api.modules['typing'].names['Any']
def get_app_model(model_name: str) -> str:

View File

@@ -1,77 +0,0 @@
from typing import Callable, Optional
from mypy.nodes import AssignmentStmt, CallExpr, RefExpr, StrExpr
from mypy.plugin import Plugin, ClassDefContext
from mypy_django_plugin.helpers import get_app_model
from mypy_django_plugin.model_classes import DjangoModelsRegistry
# fields which real type is inside to= expression
REFERENCING_DB_FIELDS = {
'django.db.models.fields.related.ForeignKey',
'django.db.models.fields.related.OneToOneField'
}
def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None:
to_arg_value = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(to_arg_value, StrExpr):
referred_model_fullname = get_app_model(to_arg_value.value)
else:
referred_model_fullname = to_arg_value.fullname
rvalue.callee.node.metadata['base'] = referred_model_fullname
class CollectModelsInformation(object):
def __init__(self, model_registry: DjangoModelsRegistry):
self.model_registry = model_registry
def __call__(self, model_definition: ClassDefContext) -> None:
self.model_registry.base_models.add(model_definition.cls.fullname)
for member in model_definition.cls.defs.body:
if isinstance(member, AssignmentStmt):
if len(member.lvalues) > 1:
return None
arg_name = member.lvalues[0].name
arg_name_as_id = arg_name + '_id'
rvalue = member.rvalue
if isinstance(rvalue, CallExpr):
if not isinstance(rvalue.callee, RefExpr):
return None
if rvalue.callee.fullname in REFERENCING_DB_FIELDS:
if rvalue.callee.fullname == 'django.db.models.fields.related.ForeignKey':
model_definition.cls.info.names[arg_name_as_id] = \
model_definition.api.lookup_fully_qualified('builtins.int')
if rvalue.callee.fullname == 'django.db.models.fields.related.OneToOneField':
if 'related_name' in rvalue.arg_names:
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
if isinstance(referred_to_model, StrExpr):
referred_model_fullname = get_app_model(referred_to_model.value)
else:
referred_model_fullname = referred_to_model.fullname
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
referred_model.node.names[related_arg_value] = \
model_definition.api.lookup_fully_qualified_or_none(model_definition.cls.fullname)
return save_referred_to_model_in_metadata(rvalue)
class BaseDjangoModelsPlugin(Plugin):
model_registry = DjangoModelsRegistry()
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in self.model_registry:
return CollectModelsInformation(self.model_registry)
return None

View File

@@ -0,0 +1,164 @@
import dataclasses
from mypy import types, nodes
from mypy.nodes import CallExpr, StrExpr, AssignmentStmt, RefExpr
from mypy.plugin import AttributeContext, ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.types import Type, Instance, AnyType, TypeOfAny
from mypy_django_plugin.helpers import lookup_django_model, get_app_model
from mypy_django_plugin.model_classes import DjangoModelsRegistry
# mapping between field types and plain python types
DB_FIELDS_TO_TYPES = {
'django.db.models.fields.CharField': 'builtins.str',
'django.db.models.fields.TextField': 'builtins.str',
'django.db.models.fields.BooleanField': 'builtins.bool',
# 'django.db.models.fields.NullBooleanField': 'typing.Optional[builtins.bool]',
'django.db.models.fields.IntegerField': 'builtins.int',
'django.db.models.fields.AutoField': 'builtins.int',
'django.db.models.fields.FloatField': 'builtins.float',
'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict',
'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable'
}
# def get_queryset_of(type_fullname: str):
# return model_definition.api.lookup_fully_qualified_or_none('django.db.models.QuerySet')
@dataclasses.dataclass
class DjangoPluginApi(object):
mypy_api: SemanticAnalyzerPluginInterface
def get_queryset_of(self, type_fullname: str = 'django.db.models.base.Model') -> types.Type:
queryset_sym = self.mypy_api.lookup_fully_qualified_or_none('django.db.models.QuerySet')
if not queryset_sym:
return AnyType(TypeOfAny.from_error)
generic_arg = self.mypy_api.lookup_fully_qualified_or_none(type_fullname)
if not generic_arg:
return Instance(queryset_sym.node, [AnyType(TypeOfAny.from_error)])
return Instance(queryset_sym.node, [Instance(generic_arg.node, [])])
def generate_related_manager_assignment_stmt(self,
related_mngr_name: str,
queryset_argument_type_fullname: str) -> nodes.AssignmentStmt:
rvalue = nodes.TempNode(AnyType(TypeOfAny.special_form))
assignment = nodes.AssignmentStmt(lvalues=[nodes.NameExpr(related_mngr_name)],
rvalue=rvalue,
new_syntax=True,
type=self.get_queryset_of(queryset_argument_type_fullname))
return assignment
class DetermineFieldPythonTypeCallback(object):
def __init__(self, models_registry: DjangoModelsRegistry):
self.models_registry = models_registry
def __call__(self, attr_context: AttributeContext) -> Type:
default_attr_type = attr_context.default_attr_type
if isinstance(default_attr_type, Instance):
attr_type_fullname = default_attr_type.type.fullname()
if attr_type_fullname in DB_FIELDS_TO_TYPES:
return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname])
# if 'base' in default_attr_type.type.metadata:
# referred_base_model = default_attr_type.type.metadata['base']
if 'members' in attr_context.type.type.metadata:
arg_name = attr_context.context.name
if arg_name in attr_context.type.type.metadata['members']:
referred_base_model = attr_context.type.type.metadata['members'][arg_name]
typ = lookup_django_model(attr_context.api, referred_base_model)
try:
return Instance(typ.node, [])
except AssertionError as e:
return typ.type
return default_attr_type
# fields which real type is inside to= expression
REFERENCING_DB_FIELDS = {
'django.db.models.fields.related.ForeignKey',
'django.db.models.fields.related.OneToOneField'
}
def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None:
to_arg_value = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(to_arg_value, StrExpr):
referred_model_fullname = get_app_model(to_arg_value.value)
else:
referred_model_fullname = to_arg_value.fullname
rvalue.callee.node.metadata['base'] = referred_model_fullname
class CollectModelsInformationCallback(object):
def __init__(self, model_registry: DjangoModelsRegistry):
self.model_registry = model_registry
def __call__(self, model_definition: ClassDefContext) -> None:
self.model_registry.base_models.add(model_definition.cls.fullname)
plugin_api = DjangoPluginApi(mypy_api=model_definition.api)
for member in model_definition.cls.defs.body:
if isinstance(member, AssignmentStmt):
if len(member.lvalues) > 1:
return None
arg_name = member.lvalues[0].name
arg_name_as_id = arg_name + '_id'
rvalue = member.rvalue
if isinstance(rvalue, CallExpr):
if not isinstance(rvalue.callee, RefExpr):
return None
if rvalue.callee.fullname in REFERENCING_DB_FIELDS:
if rvalue.callee.fullname == 'django.db.models.fields.related.ForeignKey':
model_definition.cls.info.names[arg_name_as_id] = \
model_definition.api.lookup_fully_qualified('builtins.int')
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(referred_to_model, StrExpr):
referred_model_fullname = get_app_model(referred_to_model.value)
else:
referred_model_fullname = referred_to_model.fullname
rvalue.callee.node.metadata['base'] = referred_model_fullname
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
if 'related_name' in rvalue.arg_names:
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
referred_model_class_def = referred_model.node.defn # type: nodes.ClassDef
referred_model_class_def.defs.body.append(
plugin_api.generate_related_manager_assignment_stmt(related_arg_value,
model_definition.cls.fullname))
if rvalue.callee.fullname == 'django.db.models.fields.related.OneToOneField':
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(referred_to_model, StrExpr):
referred_model_fullname = get_app_model(referred_to_model.value)
else:
referred_model_fullname = referred_to_model.fullname
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
if 'related_name' in rvalue.arg_names:
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
referred_model.node.names[related_arg_value] = \
model_definition.api.lookup_fully_qualified_or_none(model_definition.cls.fullname)
rvalue.callee.node.metadata['base'] = referred_model_fullname
rvalue.callee.node.metadata['name'] = arg_name
if 'members' not in model_definition.cls.info.metadata:
model_definition.cls.info.metadata['members'] = {}
model_definition.cls.info.metadata['members'][arg_name] = rvalue.callee.node.metadata.get('base', None)

View File

@@ -1,54 +1,24 @@
from typing import Optional, Callable
from typing import Optional, Callable, Type
from mypy.plugin import AttributeContext
from mypy.types import Type, Instance
from mypy.plugin import Plugin, ClassDefContext, AttributeContext
from mypy_django_plugin.helpers import lookup_django_model
from mypy_django_plugin.model_classes import DjangoModelsRegistry
from mypy_django_plugin.plugins.base import BaseDjangoModelsPlugin
# mapping between field types and plain python types
DB_FIELDS_TO_TYPES = {
'django.db.models.fields.CharField': 'builtins.str',
'django.db.models.fields.TextField': 'builtins.str',
'django.db.models.fields.BooleanField': 'builtins.bool',
# 'django.db.models.fields.NullBooleanField': 'typing.Optional[builtins.bool]',
'django.db.models.fields.IntegerField': 'builtins.int',
'django.db.models.fields.AutoField': 'builtins.int',
'django.db.models.fields.FloatField': 'builtins.float',
'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict',
'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable'
}
from mypy_django_plugin.plugins.callbacks import CollectModelsInformationCallback, DetermineFieldPythonTypeCallback
class DetermineFieldPythonTypeCallback(object):
def __init__(self, models_registry: DjangoModelsRegistry):
self.models_registry = models_registry
class FieldToPythonTypePlugin(Plugin):
model_registry = DjangoModelsRegistry()
def __call__(self, attr_context: AttributeContext) -> Type:
default_attr_type = attr_context.default_attr_type
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in self.model_registry:
return CollectModelsInformationCallback(self.model_registry)
if isinstance(default_attr_type, Instance):
attr_type_fullname = default_attr_type.type.fullname()
if attr_type_fullname in DB_FIELDS_TO_TYPES:
return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname])
return None
if 'base' in default_attr_type.type.metadata:
referred_base_model = default_attr_type.type.metadata['base']
try:
node = lookup_django_model(attr_context.api, referred_base_model).node
return Instance(node, [])
except AssertionError as e:
print(e)
print('name to lookup:', referred_base_model)
pass
return default_attr_type
class FieldToPythonTypePlugin(BaseDjangoModelsPlugin):
def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
# print(fullname)
classname, _, attrname = fullname.rpartition('.')
if classname and classname in self.model_registry:
return DetermineFieldPythonTypeCallback(self.model_registry)