more fixes for django stubs, first attempt for a plugin

This commit is contained in:
Maxim Kurnikov
2018-10-12 01:56:25 +03:00
parent b93f589cff
commit 2cdefc4662
11 changed files with 273 additions and 165 deletions

View File

@@ -1,77 +0,0 @@
from typing import Callable, Optional
from mypy.nodes import AssignmentStmt, CallExpr, RefExpr, StrExpr
from mypy.plugin import Plugin, ClassDefContext
from mypy_django_plugin.helpers import get_app_model
from mypy_django_plugin.model_classes import DjangoModelsRegistry
# fields which real type is inside to= expression
REFERENCING_DB_FIELDS = {
'django.db.models.fields.related.ForeignKey',
'django.db.models.fields.related.OneToOneField'
}
def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None:
to_arg_value = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(to_arg_value, StrExpr):
referred_model_fullname = get_app_model(to_arg_value.value)
else:
referred_model_fullname = to_arg_value.fullname
rvalue.callee.node.metadata['base'] = referred_model_fullname
class CollectModelsInformation(object):
def __init__(self, model_registry: DjangoModelsRegistry):
self.model_registry = model_registry
def __call__(self, model_definition: ClassDefContext) -> None:
self.model_registry.base_models.add(model_definition.cls.fullname)
for member in model_definition.cls.defs.body:
if isinstance(member, AssignmentStmt):
if len(member.lvalues) > 1:
return None
arg_name = member.lvalues[0].name
arg_name_as_id = arg_name + '_id'
rvalue = member.rvalue
if isinstance(rvalue, CallExpr):
if not isinstance(rvalue.callee, RefExpr):
return None
if rvalue.callee.fullname in REFERENCING_DB_FIELDS:
if rvalue.callee.fullname == 'django.db.models.fields.related.ForeignKey':
model_definition.cls.info.names[arg_name_as_id] = \
model_definition.api.lookup_fully_qualified('builtins.int')
if rvalue.callee.fullname == 'django.db.models.fields.related.OneToOneField':
if 'related_name' in rvalue.arg_names:
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
if isinstance(referred_to_model, StrExpr):
referred_model_fullname = get_app_model(referred_to_model.value)
else:
referred_model_fullname = referred_to_model.fullname
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
referred_model.node.names[related_arg_value] = \
model_definition.api.lookup_fully_qualified_or_none(model_definition.cls.fullname)
return save_referred_to_model_in_metadata(rvalue)
class BaseDjangoModelsPlugin(Plugin):
model_registry = DjangoModelsRegistry()
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in self.model_registry:
return CollectModelsInformation(self.model_registry)
return None

View File

@@ -0,0 +1,164 @@
import dataclasses
from mypy import types, nodes
from mypy.nodes import CallExpr, StrExpr, AssignmentStmt, RefExpr
from mypy.plugin import AttributeContext, ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.types import Type, Instance, AnyType, TypeOfAny
from mypy_django_plugin.helpers import lookup_django_model, get_app_model
from mypy_django_plugin.model_classes import DjangoModelsRegistry
# mapping between field types and plain python types
DB_FIELDS_TO_TYPES = {
'django.db.models.fields.CharField': 'builtins.str',
'django.db.models.fields.TextField': 'builtins.str',
'django.db.models.fields.BooleanField': 'builtins.bool',
# 'django.db.models.fields.NullBooleanField': 'typing.Optional[builtins.bool]',
'django.db.models.fields.IntegerField': 'builtins.int',
'django.db.models.fields.AutoField': 'builtins.int',
'django.db.models.fields.FloatField': 'builtins.float',
'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict',
'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable'
}
# def get_queryset_of(type_fullname: str):
# return model_definition.api.lookup_fully_qualified_or_none('django.db.models.QuerySet')
@dataclasses.dataclass
class DjangoPluginApi(object):
mypy_api: SemanticAnalyzerPluginInterface
def get_queryset_of(self, type_fullname: str = 'django.db.models.base.Model') -> types.Type:
queryset_sym = self.mypy_api.lookup_fully_qualified_or_none('django.db.models.QuerySet')
if not queryset_sym:
return AnyType(TypeOfAny.from_error)
generic_arg = self.mypy_api.lookup_fully_qualified_or_none(type_fullname)
if not generic_arg:
return Instance(queryset_sym.node, [AnyType(TypeOfAny.from_error)])
return Instance(queryset_sym.node, [Instance(generic_arg.node, [])])
def generate_related_manager_assignment_stmt(self,
related_mngr_name: str,
queryset_argument_type_fullname: str) -> nodes.AssignmentStmt:
rvalue = nodes.TempNode(AnyType(TypeOfAny.special_form))
assignment = nodes.AssignmentStmt(lvalues=[nodes.NameExpr(related_mngr_name)],
rvalue=rvalue,
new_syntax=True,
type=self.get_queryset_of(queryset_argument_type_fullname))
return assignment
class DetermineFieldPythonTypeCallback(object):
def __init__(self, models_registry: DjangoModelsRegistry):
self.models_registry = models_registry
def __call__(self, attr_context: AttributeContext) -> Type:
default_attr_type = attr_context.default_attr_type
if isinstance(default_attr_type, Instance):
attr_type_fullname = default_attr_type.type.fullname()
if attr_type_fullname in DB_FIELDS_TO_TYPES:
return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname])
# if 'base' in default_attr_type.type.metadata:
# referred_base_model = default_attr_type.type.metadata['base']
if 'members' in attr_context.type.type.metadata:
arg_name = attr_context.context.name
if arg_name in attr_context.type.type.metadata['members']:
referred_base_model = attr_context.type.type.metadata['members'][arg_name]
typ = lookup_django_model(attr_context.api, referred_base_model)
try:
return Instance(typ.node, [])
except AssertionError as e:
return typ.type
return default_attr_type
# fields which real type is inside to= expression
REFERENCING_DB_FIELDS = {
'django.db.models.fields.related.ForeignKey',
'django.db.models.fields.related.OneToOneField'
}
def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None:
to_arg_value = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(to_arg_value, StrExpr):
referred_model_fullname = get_app_model(to_arg_value.value)
else:
referred_model_fullname = to_arg_value.fullname
rvalue.callee.node.metadata['base'] = referred_model_fullname
class CollectModelsInformationCallback(object):
def __init__(self, model_registry: DjangoModelsRegistry):
self.model_registry = model_registry
def __call__(self, model_definition: ClassDefContext) -> None:
self.model_registry.base_models.add(model_definition.cls.fullname)
plugin_api = DjangoPluginApi(mypy_api=model_definition.api)
for member in model_definition.cls.defs.body:
if isinstance(member, AssignmentStmt):
if len(member.lvalues) > 1:
return None
arg_name = member.lvalues[0].name
arg_name_as_id = arg_name + '_id'
rvalue = member.rvalue
if isinstance(rvalue, CallExpr):
if not isinstance(rvalue.callee, RefExpr):
return None
if rvalue.callee.fullname in REFERENCING_DB_FIELDS:
if rvalue.callee.fullname == 'django.db.models.fields.related.ForeignKey':
model_definition.cls.info.names[arg_name_as_id] = \
model_definition.api.lookup_fully_qualified('builtins.int')
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(referred_to_model, StrExpr):
referred_model_fullname = get_app_model(referred_to_model.value)
else:
referred_model_fullname = referred_to_model.fullname
rvalue.callee.node.metadata['base'] = referred_model_fullname
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
if 'related_name' in rvalue.arg_names:
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
referred_model_class_def = referred_model.node.defn # type: nodes.ClassDef
referred_model_class_def.defs.body.append(
plugin_api.generate_related_manager_assignment_stmt(related_arg_value,
model_definition.cls.fullname))
if rvalue.callee.fullname == 'django.db.models.fields.related.OneToOneField':
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(referred_to_model, StrExpr):
referred_model_fullname = get_app_model(referred_to_model.value)
else:
referred_model_fullname = referred_to_model.fullname
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
if 'related_name' in rvalue.arg_names:
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
referred_model.node.names[related_arg_value] = \
model_definition.api.lookup_fully_qualified_or_none(model_definition.cls.fullname)
rvalue.callee.node.metadata['base'] = referred_model_fullname
rvalue.callee.node.metadata['name'] = arg_name
if 'members' not in model_definition.cls.info.metadata:
model_definition.cls.info.metadata['members'] = {}
model_definition.cls.info.metadata['members'][arg_name] = rvalue.callee.node.metadata.get('base', None)

View File

@@ -1,54 +1,24 @@
from typing import Optional, Callable
from typing import Optional, Callable, Type
from mypy.plugin import AttributeContext
from mypy.types import Type, Instance
from mypy.plugin import Plugin, ClassDefContext, AttributeContext
from mypy_django_plugin.helpers import lookup_django_model
from mypy_django_plugin.model_classes import DjangoModelsRegistry
from mypy_django_plugin.plugins.base import BaseDjangoModelsPlugin
# mapping between field types and plain python types
DB_FIELDS_TO_TYPES = {
'django.db.models.fields.CharField': 'builtins.str',
'django.db.models.fields.TextField': 'builtins.str',
'django.db.models.fields.BooleanField': 'builtins.bool',
# 'django.db.models.fields.NullBooleanField': 'typing.Optional[builtins.bool]',
'django.db.models.fields.IntegerField': 'builtins.int',
'django.db.models.fields.AutoField': 'builtins.int',
'django.db.models.fields.FloatField': 'builtins.float',
'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict',
'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable'
}
from mypy_django_plugin.plugins.callbacks import CollectModelsInformationCallback, DetermineFieldPythonTypeCallback
class DetermineFieldPythonTypeCallback(object):
def __init__(self, models_registry: DjangoModelsRegistry):
self.models_registry = models_registry
class FieldToPythonTypePlugin(Plugin):
model_registry = DjangoModelsRegistry()
def __call__(self, attr_context: AttributeContext) -> Type:
default_attr_type = attr_context.default_attr_type
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in self.model_registry:
return CollectModelsInformationCallback(self.model_registry)
if isinstance(default_attr_type, Instance):
attr_type_fullname = default_attr_type.type.fullname()
if attr_type_fullname in DB_FIELDS_TO_TYPES:
return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname])
return None
if 'base' in default_attr_type.type.metadata:
referred_base_model = default_attr_type.type.metadata['base']
try:
node = lookup_django_model(attr_context.api, referred_base_model).node
return Instance(node, [])
except AssertionError as e:
print(e)
print('name to lookup:', referred_base_model)
pass
return default_attr_type
class FieldToPythonTypePlugin(BaseDjangoModelsPlugin):
def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
# print(fullname)
classname, _, attrname = fullname.rpartition('.')
if classname and classname in self.model_registry:
return DetermineFieldPythonTypeCallback(self.model_registry)