more fixes for django stubs, first attempt for a plugin

This commit is contained in:
Maxim Kurnikov
2018-10-12 01:56:25 +03:00
parent b93f589cff
commit 2cdefc4662
11 changed files with 273 additions and 165 deletions

View File

@@ -1,5 +1,6 @@
from typing import Callable, Optional, Type from typing import Callable, Optional, Type
from django.contrib.sessions.backends.base import SessionBase
from django.core.handlers.wsgi import WSGIRequest from django.core.handlers.wsgi import WSGIRequest
from django.http.request import HttpRequest from django.http.request import HttpRequest
from django.http.response import HttpResponseBase from django.http.response import HttpResponseBase
@@ -8,7 +9,7 @@ from django.utils.deprecation import MiddlewareMixin
class SessionMiddleware(MiddlewareMixin): class SessionMiddleware(MiddlewareMixin):
get_response: Callable[[WSGIRequest], HttpResponseBase] = ... get_response: Callable[[WSGIRequest], HttpResponseBase] = ...
SessionStore: Type[SessionStore] = ... SessionStore: Type[SessionBase] = ...
def __init__(self, get_response: Optional[Callable] = ...) -> None: ... def __init__(self, get_response: Optional[Callable] = ...) -> None: ...

View File

@@ -1,25 +1,32 @@
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Union from typing import Any, Callable, Dict, Union
from django.core.cache.backends.base import BaseCache as BaseCache from django.core.cache.backends.base import BaseCache as BaseCache
from django.core.cache.backends.base import CacheKeyWarning as CacheKeyWarning
from django.core.cache.backends.base import \
InvalidCacheBackendError as InvalidCacheBackendError
DEFAULT_CACHE_ALIAS: str DEFAULT_CACHE_ALIAS: str
class CacheHandler: class CacheHandler:
def __init__(self) -> None: ... def __init__(self) -> None: ...
def __getitem__(self, alias: str) -> BaseCache: ... def __getitem__(self, alias: str) -> BaseCache: ...
def all(self): ... def all(self): ...
class DefaultCacheProxy: class DefaultCacheProxy:
def __getattr__( def __getattr__(
self, name: str self, name: str
) -> Union[Callable, Dict[str, float], OrderedDict, int]: ... ) -> Union[Callable, Dict[str, float], OrderedDict, int]: ...
def __setattr__(self, name: str, value: Callable) -> None: ... def __setattr__(self, name: str, value: Callable) -> None: ...
def __delattr__(self, name: Any): ... def __delattr__(self, name: Any): ...
def __contains__(self, key: str) -> bool: ... def __contains__(self, key: str) -> bool: ...
def __eq__(self, other: Any): ... def __eq__(self, other: Any): ...
cache: Any cache: Any
caches: CacheHandler

View File

@@ -1,3 +1,4 @@
from io import BufferedReader, StringIO
from typing import Any, Iterator, Optional, Union from typing import Any, Iterator, Optional, Union
from django.core.files.utils import FileProxyMixin from django.core.files.utils import FileProxyMixin
@@ -5,7 +6,7 @@ from django.core.files.utils import FileProxyMixin
class File(FileProxyMixin): class File(FileProxyMixin):
DEFAULT_CHUNK_SIZE: Any = ... DEFAULT_CHUNK_SIZE: Any = ...
file: _io.BufferedReader = ... file: BufferedReader = ...
name: str = ... name: str = ...
mode: str = ... mode: str = ...
def __init__(self, file: Any, name: Optional[str] = ...) -> None: ... def __init__(self, file: Any, name: Optional[str] = ...) -> None: ...
@@ -14,7 +15,7 @@ class File(FileProxyMixin):
def size(self) -> int: ... def size(self) -> int: ...
def chunks( def chunks(
self, chunk_size: Optional[int] = ... self, chunk_size: Optional[int] = ...
) -> Iterator[Union[bytes, str]]: ... ) -> Iterator[Union[bytes, bytearray]]: ...
def multiple_chunks(self, chunk_size: Optional[Any] = ...): ... def multiple_chunks(self, chunk_size: Optional[Any] = ...): ...
def __iter__(self) -> Iterator[Union[bytes, str]]: ... def __iter__(self) -> Iterator[Union[bytes, str]]: ...
def __enter__(self) -> File: ... def __enter__(self) -> File: ...
@@ -23,7 +24,7 @@ class File(FileProxyMixin):
def close(self) -> None: ... def close(self) -> None: ...
class ContentFile(File): class ContentFile(File):
file: _io.StringIO file: StringIO
name: None name: None
size: Any = ... size: Any = ...
def __init__( def __init__(

View File

@@ -1,76 +1,110 @@
from typing import Any, Callable, List, Optional, Tuple, Type, Union from typing import Any, Callable, List, Optional, Tuple, Type, Union, Generic, TypeVar
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.expressions import F from django.db.models.expressions import F
from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.fields.related import ForeignObject from django.db.models.fields.related import ForeignObject, RelatedField, OneToOneField
from django.db.models.fields.reverse_related import ManyToManyRel, OneToOneRel from django.db.models.fields.reverse_related import ManyToManyRel, OneToOneRel
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
_T = TypeVar('_T')
class ForwardManyToOneDescriptor: class ForwardManyToOneDescriptor:
RelatedObjectDoesNotExist: Type[django.core.exceptions.ObjectDoesNotExist] RelatedObjectDoesNotExist: Type[ObjectDoesNotExist]
field: django.db.models.fields.related.ForeignObject = ... field: ForeignObject = ...
def __init__(self, field_with_rel: ForeignObject) -> None: ... def __init__(self, field_with_rel: ForeignObject) -> None: ...
def RelatedObjectDoesNotExist(self) -> Type[ObjectDoesNotExist]: ...
def is_cached(self, instance: Model) -> bool: ... def is_cached(self, instance: Model) -> bool: ...
def get_queryset(self, **hints: Any) -> QuerySet: ... def get_queryset(self, **hints: Any) -> QuerySet: ...
def get_prefetch_queryset( def get_prefetch_queryset(
self, instances: List[Model], queryset: Optional[QuerySet] = ... self, instances: List[Model], queryset: Optional[QuerySet] = ...
) -> Tuple[QuerySet, Callable, Callable, bool, str, bool]: ... ) -> Tuple[QuerySet, Callable, Callable, bool, str, bool]: ...
def get_object(self, instance: Model) -> Model: ... def get_object(self, instance: Model) -> Model: ...
def __get__( def __get__(
self, instance: Optional[Model], cls: Type[Model] = ... self, instance: Optional[Model], cls: Type[Model] = ...
) -> Optional[Union[Model, ForwardManyToOneDescriptor]]: ... ) -> Optional[Union[Model, ForwardManyToOneDescriptor]]: ...
def __set__( def __set__(
self, instance: Model, value: Optional[Union[Model, F]] self, instance: Model, value: Optional[Union[Model, F]]
) -> None: ... ) -> None: ...
def __reduce__(self) -> Tuple[Callable, Tuple[Type[Model], str]]: ... def __reduce__(self) -> Tuple[Callable, Tuple[Type[Model], str]]: ...
class ForwardOneToOneDescriptor(ForwardManyToOneDescriptor): class ForwardOneToOneDescriptor(ForwardManyToOneDescriptor):
RelatedObjectDoesNotExist: Type[django.core.exceptions.ObjectDoesNotExist] RelatedObjectDoesNotExist: Type[ObjectDoesNotExist]
field: django.db.models.fields.related.OneToOneField field: OneToOneField
def get_object(self, instance: Model) -> Model: ... def get_object(self, instance: Model) -> Model: ...
def __set__(self, instance: Model, value: Optional[Model]) -> None: ... def __set__(self, instance: Model, value: Optional[Model]) -> None: ...
class ReverseOneToOneDescriptor: class ReverseOneToOneDescriptor:
RelatedObjectDoesNotExist: Type[django.core.exceptions.ObjectDoesNotExist] RelatedObjectDoesNotExist: Type[ObjectDoesNotExist]
related: django.db.models.fields.reverse_related.OneToOneRel = ... related: OneToOneRel = ...
def __init__(self, related: OneToOneRel) -> None: ... def __init__(self, related: OneToOneRel) -> None: ...
def RelatedObjectDoesNotExist(self) -> Type[ObjectDoesNotExist]: ...
def is_cached(self, instance: Model) -> bool: ... def is_cached(self, instance: Model) -> bool: ...
def get_queryset(self, **hints: Any) -> QuerySet: ... def get_queryset(self, **hints: Any) -> QuerySet: ...
def get_prefetch_queryset( def get_prefetch_queryset(
self, instances: List[Model], queryset: Optional[QuerySet] = ... self, instances: List[Model], queryset: Optional[QuerySet] = ...
) -> Tuple[QuerySet, Callable, Callable, bool, str, bool]: ... ) -> Tuple[QuerySet, Callable, Callable, bool, str, bool]: ...
def __get__( def __get__(
self, instance: Optional[Model], cls: Type[Model] = ... self, instance: Optional[Model], cls: Type[Model] = ...
) -> Union[Model, ReverseOneToOneDescriptor]: ... ) -> Union[Model, ReverseOneToOneDescriptor]: ...
def __set__(self, instance: Model, value: Optional[Model]) -> None: ... def __set__(self, instance: Model, value: Optional[Model]) -> None: ...
def __reduce__(self) -> Tuple[Callable, Tuple[Type[Model], str]]: ... def __reduce__(self) -> Tuple[Callable, Tuple[Type[Model], str]]: ...
class ReverseManyToOneDescriptor: class ReverseManyToOneDescriptor:
rel: django.db.models.fields.mixins.FieldCacheMixin = ... rel: FieldCacheMixin = ...
field: django.db.models.fields.mixins.FieldCacheMixin = ... field: FieldCacheMixin = ...
def __init__(self, rel: FieldCacheMixin) -> None: ... def __init__(self, rel: FieldCacheMixin) -> None: ...
def related_manager_cls(self): ... def related_manager_cls(self): ...
def __get__( def __get__(
self, instance: Optional[Model], cls: Type[Model] = ... self, instance: Optional[Model], cls: Type[Model] = ...
) -> ReverseManyToOneDescriptor: ... ) -> ReverseManyToOneDescriptor: ...
def __set__(self, instance: Model, value: List[Model]) -> Any: ... def __set__(self, instance: Model, value: List[Model]) -> Any: ...
def create_reverse_many_to_one_manager(superclass: Any, rel: Any): ... def create_reverse_many_to_one_manager(superclass: Any, rel: Any): ...
class ManyToManyDescriptor(ReverseManyToOneDescriptor): class ManyToManyDescriptor(ReverseManyToOneDescriptor):
field: django.db.models.fields.related.RelatedField field: RelatedField
rel: django.db.models.fields.reverse_related.ManyToManyRel rel: ManyToManyRel
reverse: bool = ... reverse: bool = ...
def __init__(self, rel: ManyToManyRel, reverse: bool = ...) -> None: ... def __init__(self, rel: ManyToManyRel, reverse: bool = ...) -> None: ...
@property @property
def through(self) -> Type[Model]: ... def through(self) -> Type[Model]: ...
def related_manager_cls(self): ... def related_manager_cls(self): ...
class _ForwardManyToManyManager(Generic[_T]):
def all(self) -> QuerySet: ...
def create_forward_many_to_many_manager( def create_forward_many_to_many_manager(
superclass: Any, rel: Any, reverse: Any superclass: Any, rel: Any, reverse: Any
): ... ) -> _ForwardManyToManyManager: ...

View File

@@ -138,12 +138,7 @@ class QuerySet(Generic[_T]):
def get( def get(
self, *args: Any, **kwargs: Any self, *args: Any, **kwargs: Any
) -> Union[ ) -> _T: ...
Dict[str, Union[date, Decimal, float, str]],
Tuple[Union[Decimal, str]],
Model,
str,
]: ...
def create(self, **kwargs: Any) -> _T: ... def create(self, **kwargs: Any) -> _T: ...

View File

@@ -1,5 +1,6 @@
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
from django.core.cache import BaseCache
from django.http.request import HttpRequest from django.http.request import HttpRequest
from django.http.response import HttpResponse, HttpResponseBase from django.http.response import HttpResponse, HttpResponseBase
from django.utils.deprecation import MiddlewareMixin from django.utils.deprecation import MiddlewareMixin
@@ -9,32 +10,39 @@ class UpdateCacheMiddleware(MiddlewareMixin):
cache_timeout: float = ... cache_timeout: float = ...
key_prefix: str = ... key_prefix: str = ...
cache_alias: str = ... cache_alias: str = ...
cache: django.core.cache.backends.base.BaseCache = ... cache: BaseCache = ...
get_response: Optional[Callable] = ... get_response: Optional[Callable] = ...
def __init__(self, get_response: Optional[Callable] = ...) -> None: ... def __init__(self, get_response: Optional[Callable] = ...) -> None: ...
def process_response( def process_response(
self, request: HttpRequest, response: Union[HttpResponseBase, str] self, request: HttpRequest, response: Union[HttpResponseBase, str]
) -> Union[HttpResponseBase, str]: ... ) -> Union[HttpResponseBase, str]: ...
class FetchFromCacheMiddleware(MiddlewareMixin): class FetchFromCacheMiddleware(MiddlewareMixin):
key_prefix: str = ... key_prefix: str = ...
cache_alias: str = ... cache_alias: str = ...
cache: django.core.cache.backends.base.BaseCache = ... cache: BaseCache = ...
get_response: Optional[Callable] = ... get_response: Optional[Callable] = ...
def __init__(self, get_response: Optional[Callable] = ...) -> None: ... def __init__(self, get_response: Optional[Callable] = ...) -> None: ...
def process_request( def process_request(
self, request: HttpRequest self, request: HttpRequest
) -> Optional[HttpResponse]: ... ) -> Optional[HttpResponse]: ...
class CacheMiddleware(UpdateCacheMiddleware, FetchFromCacheMiddleware): class CacheMiddleware(UpdateCacheMiddleware, FetchFromCacheMiddleware):
get_response: None = ... get_response: None = ...
key_prefix: str = ... key_prefix: str = ...
cache_alias: str = ... cache_alias: str = ...
cache_timeout: float = ... cache_timeout: float = ...
cache: django.core.cache.backends.locmem.LocMemCache = ... cache: BaseCache = ...
def __init__( def __init__(
self, self,
get_response: None = ..., get_response: None = ...,
cache_timeout: Optional[float] = ..., cache_timeout: Optional[float] = ...,
**kwargs: Any **kwargs: Any
) -> None: ... ) -> None: ...

View File

@@ -3,3 +3,7 @@ from .testcases import (
TransactionTestCase as TransactionTestCase, TransactionTestCase as TransactionTestCase,
SimpleTestCase as SimpleTestCase SimpleTestCase as SimpleTestCase
) )
from .utils import (
override_settings as override_settings
)

View File

@@ -16,7 +16,8 @@ def lookup_django_model(mypy_api: TypeChecker, fullname: str) -> SymbolTableNode
try: try:
return mypy_api.modules[module].names[model_name] return mypy_api.modules[module].names[model_name]
except KeyError: except KeyError:
return mypy_api.modules['django.db.models'].names['Model'] return mypy_api.lookup_qualified('typing.Any')
# return mypy_api.modules['typing'].names['Any']
def get_app_model(model_name: str) -> str: def get_app_model(model_name: str) -> str:

View File

@@ -1,77 +0,0 @@
from typing import Callable, Optional
from mypy.nodes import AssignmentStmt, CallExpr, RefExpr, StrExpr
from mypy.plugin import Plugin, ClassDefContext
from mypy_django_plugin.helpers import get_app_model
from mypy_django_plugin.model_classes import DjangoModelsRegistry
# fields which real type is inside to= expression
REFERENCING_DB_FIELDS = {
'django.db.models.fields.related.ForeignKey',
'django.db.models.fields.related.OneToOneField'
}
def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None:
to_arg_value = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(to_arg_value, StrExpr):
referred_model_fullname = get_app_model(to_arg_value.value)
else:
referred_model_fullname = to_arg_value.fullname
rvalue.callee.node.metadata['base'] = referred_model_fullname
class CollectModelsInformation(object):
def __init__(self, model_registry: DjangoModelsRegistry):
self.model_registry = model_registry
def __call__(self, model_definition: ClassDefContext) -> None:
self.model_registry.base_models.add(model_definition.cls.fullname)
for member in model_definition.cls.defs.body:
if isinstance(member, AssignmentStmt):
if len(member.lvalues) > 1:
return None
arg_name = member.lvalues[0].name
arg_name_as_id = arg_name + '_id'
rvalue = member.rvalue
if isinstance(rvalue, CallExpr):
if not isinstance(rvalue.callee, RefExpr):
return None
if rvalue.callee.fullname in REFERENCING_DB_FIELDS:
if rvalue.callee.fullname == 'django.db.models.fields.related.ForeignKey':
model_definition.cls.info.names[arg_name_as_id] = \
model_definition.api.lookup_fully_qualified('builtins.int')
if rvalue.callee.fullname == 'django.db.models.fields.related.OneToOneField':
if 'related_name' in rvalue.arg_names:
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
if isinstance(referred_to_model, StrExpr):
referred_model_fullname = get_app_model(referred_to_model.value)
else:
referred_model_fullname = referred_to_model.fullname
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
referred_model.node.names[related_arg_value] = \
model_definition.api.lookup_fully_qualified_or_none(model_definition.cls.fullname)
return save_referred_to_model_in_metadata(rvalue)
class BaseDjangoModelsPlugin(Plugin):
model_registry = DjangoModelsRegistry()
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in self.model_registry:
return CollectModelsInformation(self.model_registry)
return None

View File

@@ -0,0 +1,164 @@
import dataclasses
from mypy import types, nodes
from mypy.nodes import CallExpr, StrExpr, AssignmentStmt, RefExpr
from mypy.plugin import AttributeContext, ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.types import Type, Instance, AnyType, TypeOfAny
from mypy_django_plugin.helpers import lookup_django_model, get_app_model
from mypy_django_plugin.model_classes import DjangoModelsRegistry
# mapping between field types and plain python types
DB_FIELDS_TO_TYPES = {
'django.db.models.fields.CharField': 'builtins.str',
'django.db.models.fields.TextField': 'builtins.str',
'django.db.models.fields.BooleanField': 'builtins.bool',
# 'django.db.models.fields.NullBooleanField': 'typing.Optional[builtins.bool]',
'django.db.models.fields.IntegerField': 'builtins.int',
'django.db.models.fields.AutoField': 'builtins.int',
'django.db.models.fields.FloatField': 'builtins.float',
'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict',
'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable'
}
# def get_queryset_of(type_fullname: str):
# return model_definition.api.lookup_fully_qualified_or_none('django.db.models.QuerySet')
@dataclasses.dataclass
class DjangoPluginApi(object):
mypy_api: SemanticAnalyzerPluginInterface
def get_queryset_of(self, type_fullname: str = 'django.db.models.base.Model') -> types.Type:
queryset_sym = self.mypy_api.lookup_fully_qualified_or_none('django.db.models.QuerySet')
if not queryset_sym:
return AnyType(TypeOfAny.from_error)
generic_arg = self.mypy_api.lookup_fully_qualified_or_none(type_fullname)
if not generic_arg:
return Instance(queryset_sym.node, [AnyType(TypeOfAny.from_error)])
return Instance(queryset_sym.node, [Instance(generic_arg.node, [])])
def generate_related_manager_assignment_stmt(self,
related_mngr_name: str,
queryset_argument_type_fullname: str) -> nodes.AssignmentStmt:
rvalue = nodes.TempNode(AnyType(TypeOfAny.special_form))
assignment = nodes.AssignmentStmt(lvalues=[nodes.NameExpr(related_mngr_name)],
rvalue=rvalue,
new_syntax=True,
type=self.get_queryset_of(queryset_argument_type_fullname))
return assignment
class DetermineFieldPythonTypeCallback(object):
def __init__(self, models_registry: DjangoModelsRegistry):
self.models_registry = models_registry
def __call__(self, attr_context: AttributeContext) -> Type:
default_attr_type = attr_context.default_attr_type
if isinstance(default_attr_type, Instance):
attr_type_fullname = default_attr_type.type.fullname()
if attr_type_fullname in DB_FIELDS_TO_TYPES:
return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname])
# if 'base' in default_attr_type.type.metadata:
# referred_base_model = default_attr_type.type.metadata['base']
if 'members' in attr_context.type.type.metadata:
arg_name = attr_context.context.name
if arg_name in attr_context.type.type.metadata['members']:
referred_base_model = attr_context.type.type.metadata['members'][arg_name]
typ = lookup_django_model(attr_context.api, referred_base_model)
try:
return Instance(typ.node, [])
except AssertionError as e:
return typ.type
return default_attr_type
# fields which real type is inside to= expression
REFERENCING_DB_FIELDS = {
'django.db.models.fields.related.ForeignKey',
'django.db.models.fields.related.OneToOneField'
}
def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None:
to_arg_value = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(to_arg_value, StrExpr):
referred_model_fullname = get_app_model(to_arg_value.value)
else:
referred_model_fullname = to_arg_value.fullname
rvalue.callee.node.metadata['base'] = referred_model_fullname
class CollectModelsInformationCallback(object):
def __init__(self, model_registry: DjangoModelsRegistry):
self.model_registry = model_registry
def __call__(self, model_definition: ClassDefContext) -> None:
self.model_registry.base_models.add(model_definition.cls.fullname)
plugin_api = DjangoPluginApi(mypy_api=model_definition.api)
for member in model_definition.cls.defs.body:
if isinstance(member, AssignmentStmt):
if len(member.lvalues) > 1:
return None
arg_name = member.lvalues[0].name
arg_name_as_id = arg_name + '_id'
rvalue = member.rvalue
if isinstance(rvalue, CallExpr):
if not isinstance(rvalue.callee, RefExpr):
return None
if rvalue.callee.fullname in REFERENCING_DB_FIELDS:
if rvalue.callee.fullname == 'django.db.models.fields.related.ForeignKey':
model_definition.cls.info.names[arg_name_as_id] = \
model_definition.api.lookup_fully_qualified('builtins.int')
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(referred_to_model, StrExpr):
referred_model_fullname = get_app_model(referred_to_model.value)
else:
referred_model_fullname = referred_to_model.fullname
rvalue.callee.node.metadata['base'] = referred_model_fullname
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
if 'related_name' in rvalue.arg_names:
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
referred_model_class_def = referred_model.node.defn # type: nodes.ClassDef
referred_model_class_def.defs.body.append(
plugin_api.generate_related_manager_assignment_stmt(related_arg_value,
model_definition.cls.fullname))
if rvalue.callee.fullname == 'django.db.models.fields.related.OneToOneField':
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(referred_to_model, StrExpr):
referred_model_fullname = get_app_model(referred_to_model.value)
else:
referred_model_fullname = referred_to_model.fullname
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
if 'related_name' in rvalue.arg_names:
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
referred_model.node.names[related_arg_value] = \
model_definition.api.lookup_fully_qualified_or_none(model_definition.cls.fullname)
rvalue.callee.node.metadata['base'] = referred_model_fullname
rvalue.callee.node.metadata['name'] = arg_name
if 'members' not in model_definition.cls.info.metadata:
model_definition.cls.info.metadata['members'] = {}
model_definition.cls.info.metadata['members'][arg_name] = rvalue.callee.node.metadata.get('base', None)

View File

@@ -1,54 +1,24 @@
from typing import Optional, Callable from typing import Optional, Callable, Type
from mypy.plugin import AttributeContext from mypy.plugin import Plugin, ClassDefContext, AttributeContext
from mypy.types import Type, Instance
from mypy_django_plugin.helpers import lookup_django_model
from mypy_django_plugin.model_classes import DjangoModelsRegistry from mypy_django_plugin.model_classes import DjangoModelsRegistry
from mypy_django_plugin.plugins.base import BaseDjangoModelsPlugin from mypy_django_plugin.plugins.callbacks import CollectModelsInformationCallback, DetermineFieldPythonTypeCallback
# mapping between field types and plain python types
DB_FIELDS_TO_TYPES = {
'django.db.models.fields.CharField': 'builtins.str',
'django.db.models.fields.TextField': 'builtins.str',
'django.db.models.fields.BooleanField': 'builtins.bool',
# 'django.db.models.fields.NullBooleanField': 'typing.Optional[builtins.bool]',
'django.db.models.fields.IntegerField': 'builtins.int',
'django.db.models.fields.AutoField': 'builtins.int',
'django.db.models.fields.FloatField': 'builtins.float',
'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict',
'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable'
}
class DetermineFieldPythonTypeCallback(object): class FieldToPythonTypePlugin(Plugin):
def __init__(self, models_registry: DjangoModelsRegistry): model_registry = DjangoModelsRegistry()
self.models_registry = models_registry
def __call__(self, attr_context: AttributeContext) -> Type: def get_base_class_hook(self, fullname: str
default_attr_type = attr_context.default_attr_type ) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in self.model_registry:
return CollectModelsInformationCallback(self.model_registry)
if isinstance(default_attr_type, Instance): return None
attr_type_fullname = default_attr_type.type.fullname()
if attr_type_fullname in DB_FIELDS_TO_TYPES:
return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname])
if 'base' in default_attr_type.type.metadata:
referred_base_model = default_attr_type.type.metadata['base']
try:
node = lookup_django_model(attr_context.api, referred_base_model).node
return Instance(node, [])
except AssertionError as e:
print(e)
print('name to lookup:', referred_base_model)
pass
return default_attr_type
class FieldToPythonTypePlugin(BaseDjangoModelsPlugin):
def get_attribute_hook(self, fullname: str def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]: ) -> Optional[Callable[[AttributeContext], Type]]:
# print(fullname)
classname, _, attrname = fullname.rpartition('.') classname, _, attrname = fullname.rpartition('.')
if classname and classname in self.model_registry: if classname and classname in self.model_registry:
return DetermineFieldPythonTypeCallback(self.model_registry) return DetermineFieldPythonTypeCallback(self.model_registry)