Files
django-stubs/mypy_django_plugin/plugins/callbacks.py
2018-10-12 01:56:25 +03:00

165 lines
7.9 KiB
Python

import dataclasses
from typing import Optional

from mypy import types, nodes
from mypy.nodes import CallExpr, StrExpr, AssignmentStmt, RefExpr
from mypy.plugin import AttributeContext, ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.types import Type, Instance, AnyType, TypeOfAny

from mypy_django_plugin.helpers import lookup_django_model, get_app_model
from mypy_django_plugin.model_classes import DjangoModelsRegistry
# Maps the fully-qualified name of a Django field class to the fully-qualified
# name of the plain Python type that a model attribute declared with it holds.
DB_FIELDS_TO_TYPES = {
    **{
        'django.db.models.fields.' + field_class: python_type
        for field_class, python_type in (
            ('CharField', 'builtins.str'),
            ('TextField', 'builtins.str'),
            ('BooleanField', 'builtins.bool'),
            # NullBooleanField is not mapped; a sketched entry targeted
            # 'typing.Optional[builtins.bool]'.
            ('IntegerField', 'builtins.int'),
            ('AutoField', 'builtins.int'),
            ('FloatField', 'builtins.float'),
        )
    },
    # PostgreSQL-specific fields from django.contrib.postgres.
    'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict',
    'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable',
}
# def get_queryset_of(type_fullname: str):
# return model_definition.api.lookup_fully_qualified_or_none('django.db.models.QuerySet')
@dataclasses.dataclass
class DjangoPluginApi:
    """Small facade over mypy's semantic-analyzer plugin interface.

    Offers helpers that build Django-specific mypy types and AST nodes by
    looking classes up through their fully-qualified names.
    """

    # Interface used for every fully-qualified name lookup below.
    mypy_api: SemanticAnalyzerPluginInterface

    def get_queryset_of(self, type_fullname: str = 'django.db.models.base.Model') -> types.Type:
        """Return the mypy type ``QuerySet[<type_fullname>]``.

        Args:
            type_fullname: fully-qualified name of the QuerySet's type argument.

        Returns:
            ``AnyType`` (from_error) if ``django.db.models.QuerySet`` cannot be
            found; ``QuerySet[Any]`` if *type_fullname* cannot be found;
            otherwise ``QuerySet[<type_fullname>]``.
        """
        lookup = self.mypy_api.lookup_fully_qualified_or_none

        queryset_sym = lookup('django.db.models.QuerySet')
        if not queryset_sym:
            return AnyType(TypeOfAny.from_error)

        model_sym = lookup(type_fullname)
        if model_sym:
            type_arg = Instance(model_sym.node, [])
        else:
            type_arg = AnyType(TypeOfAny.from_error)
        return Instance(queryset_sym.node, [type_arg])

    def generate_related_manager_assignment_stmt(self,
                                                 related_mngr_name: str,
                                                 queryset_argument_type_fullname: str) -> nodes.AssignmentStmt:
        """Build the annotated statement ``<related_mngr_name>: QuerySet[<model>] = <placeholder>``.

        Args:
            related_mngr_name: name bound by the generated assignment.
            queryset_argument_type_fullname: fully-qualified name passed to
                ``get_queryset_of`` to build the annotation.

        Returns:
            A new-syntax ``AssignmentStmt`` whose rvalue is a ``TempNode``
            placeholder of special-form ``Any`` type.
        """
        placeholder = nodes.TempNode(AnyType(TypeOfAny.special_form))
        target = nodes.NameExpr(related_mngr_name)
        annotation = self.get_queryset_of(queryset_argument_type_fullname)
        return nodes.AssignmentStmt(lvalues=[target],
                                    rvalue=placeholder,
                                    new_syntax=True,
                                    type=annotation)
class DetermineFieldPythonTypeCallback(object):
    """Attribute-access callback that maps Django field types to Python types.

    * If the attribute's default type is a field class listed in
      ``DB_FIELDS_TO_TYPES``, the mapped Python type is returned
      (e.g. ``CharField`` -> ``str``).
    * Otherwise, if the owning class recorded the attribute in its
      ``metadata['members']`` with a referred model name, the attribute is
      typed as an instance of that model.
    * In every other case mypy's default attribute type is returned unchanged.
    """

    def __init__(self, models_registry: DjangoModelsRegistry):
        self.models_registry = models_registry

    def __call__(self, attr_context: AttributeContext) -> Type:
        """Return the type to use for the accessed attribute.

        Args:
            attr_context: mypy's attribute hook context (owner type, default
                attribute type, access expression and checker API).

        Returns:
            The mapped Python type, the referred model instance type, or
            ``attr_context.default_attr_type`` when no rule applies.
        """
        default_attr_type = attr_context.default_attr_type
        if not isinstance(default_attr_type, Instance):
            return default_attr_type

        attr_type_fullname = default_attr_type.type.fullname()
        if attr_type_fullname in DB_FIELDS_TO_TYPES:
            return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname])

        # The owner is not necessarily an Instance (it can be a class object,
        # a union, ...); only an Instance exposes the TypeInfo metadata.
        owner_type = attr_context.type
        if not isinstance(owner_type, Instance):
            return default_attr_type

        members = owner_type.type.metadata.get('members', {})
        arg_name = getattr(attr_context.context, 'name', None)
        # Non-relation members are recorded with a None model name, so there is
        # no model to look up for them.
        referred_base_model = members.get(arg_name)
        if referred_base_model is None:
            return default_attr_type

        model_sym = lookup_django_model(attr_context.api, referred_base_model)
        if model_sym is None:
            return default_attr_type
        try:
            return Instance(model_sym.node, [])
        except AssertionError:
            # Kept from the original implementation: if building the Instance
            # trips an assertion inside mypy, fall back to the symbol's own type
            # (and to the default type if the symbol has none).
            fallback = model_sym.type
            return fallback if fallback is not None else default_attr_type
# Relation fields whose real target type is given by their ``to=`` argument
# rather than by the field class itself.
REFERENCING_DB_FIELDS = {
    'django.db.models.fields.related.' + field_class
    for field_class in ('ForeignKey', 'OneToOneField')
}
def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None:
    """Record the model referenced by a relation-field call in the field class's metadata.

    The referred model comes from the ``to`` argument, given either as a
    keyword (``ForeignKey(to=Model, ...)``) or as the first positional argument
    (``ForeignKey(Model, ...)``). String references are resolved with
    ``get_app_model``; other references use the expression's ``fullname``.
    The result is stored in ``rvalue.callee.node.metadata['base']``.

    Args:
        rvalue: the field constructor call, e.g. ``models.ForeignKey(...)``;
            its callee is assumed to resolve to the field class's TypeInfo.

    Raises:
        ValueError: if the call has no ``to`` argument.
    """
    if 'to' in rvalue.arg_names:
        to_arg_value = rvalue.args[rvalue.arg_names.index('to')]
    elif rvalue.args and rvalue.arg_kinds[0] == nodes.ARG_POS:
        # ``to`` is the first positional parameter of ForeignKey/OneToOneField,
        # and the positional form is the most common spelling.
        to_arg_value = rvalue.args[0]
    else:
        raise ValueError("relation field call has no 'to' argument")

    if isinstance(to_arg_value, StrExpr):
        referred_model_fullname = get_app_model(to_arg_value.value)
    else:
        referred_model_fullname = to_arg_value.fullname
    rvalue.callee.node.metadata['base'] = referred_model_fullname
def _find_to_argument(call: CallExpr) -> Optional[nodes.Expression]:
    """Return the expression passed as ``to`` to a relation-field call, or None.

    Recognises the keyword form ``ForeignKey(to=Model, ...)`` and the positional
    form ``ForeignKey(Model, ...)``.
    """
    if 'to' in call.arg_names:
        return call.args[call.arg_names.index('to')]
    if call.args and call.arg_kinds[0] == nodes.ARG_POS:
        # ``to`` is the first positional parameter of ForeignKey/OneToOneField.
        return call.args[0]
    return None


def _resolve_referred_model_fullname(call: CallExpr) -> Optional[str]:
    """Return the fully-qualified name of the model referenced by *call*'s ``to`` argument.

    String references are resolved with ``get_app_model``; name/member
    references use their ``fullname``. Returns None when ``to`` is missing or is
    another kind of expression.
    """
    to_arg = _find_to_argument(call)
    if isinstance(to_arg, StrExpr):
        return get_app_model(to_arg.value)
    if isinstance(to_arg, RefExpr):
        return to_arg.fullname
    return None


def _get_related_name(call: CallExpr) -> Optional[str]:
    """Return the string literal passed as ``related_name=``, or None if absent or not a literal."""
    if 'related_name' not in call.arg_names:
        return None
    related_name_arg = call.args[call.arg_names.index('related_name')]
    if isinstance(related_name_arg, StrExpr):
        return related_name_arg.value
    return None


def _lookup_model_info(api: SemanticAnalyzerPluginInterface,
                       fullname: Optional[str]) -> Optional[nodes.TypeInfo]:
    """Return the TypeInfo named *fullname*, or None if it is unknown or not a class."""
    if not fullname:
        return None
    sym = api.lookup_fully_qualified_or_none(fullname)
    if sym is None or not isinstance(sym.node, nodes.TypeInfo):
        return None
    return sym.node


class CollectModelsInformationCallback(object):
    """Class-definition callback that collects relation information for Django models.

    For each analysed class it registers the class fullname in
    ``model_registry.base_models`` and walks class-body statements of the form
    ``name = SomeClass(...)``:

    * ``ForeignKey`` / ``OneToOneField`` members get their referred model resolved
      and written to the field class's ``metadata['base']`` (plus
      ``metadata['name']``), and reverse accessors named by ``related_name`` are
      added to the referred model.
    * Every such member is recorded in the class's ``metadata['members']``,
      mapped to the field class's ``metadata.get('base')``. For non-relation
      fields this is normally None.

    Statements of any other shape are skipped.
    """

    def __init__(self, model_registry: DjangoModelsRegistry):
        self.model_registry = model_registry

    def __call__(self, model_definition: ClassDefContext) -> None:
        """Collect model information from the class described by *model_definition*."""
        self.model_registry.base_models.add(model_definition.cls.fullname)
        plugin_api = DjangoPluginApi(mypy_api=model_definition.api)

        for member in model_definition.cls.defs.body:
            # Skip, rather than abort the whole class on, anything that is not
            # a single-target ``name = Callable(...)`` assignment.
            if not isinstance(member, AssignmentStmt) or len(member.lvalues) != 1:
                continue
            lvalue = member.lvalues[0]
            rvalue = member.rvalue
            if not isinstance(lvalue, nodes.NameExpr) or not isinstance(rvalue, CallExpr):
                continue
            if not isinstance(rvalue.callee, RefExpr):
                continue
            field_info = rvalue.callee.node
            # Only class calls carry ``metadata``; plain function calls do not.
            if not isinstance(field_info, nodes.TypeInfo):
                continue

            arg_name = lvalue.name
            field_fullname = rvalue.callee.fullname
            if field_fullname in REFERENCING_DB_FIELDS:
                if field_fullname == 'django.db.models.fields.related.ForeignKey':
                    self._process_foreign_key(model_definition, plugin_api, rvalue, field_info, arg_name)
                if field_fullname == 'django.db.models.fields.related.OneToOneField':
                    self._process_one_to_one_field(model_definition, rvalue, field_info)
                field_info.metadata['name'] = arg_name

            members = model_definition.cls.info.metadata.setdefault('members', {})
            members[arg_name] = field_info.metadata.get('base', None)

    @staticmethod
    def _process_foreign_key(model_definition: ClassDefContext,
                             plugin_api: DjangoPluginApi,
                             call: CallExpr,
                             field_info: nodes.TypeInfo,
                             arg_name: str) -> None:
        """Handle ``arg_name = ForeignKey(...)`` declared on the analysed class."""
        api = model_definition.api
        # NOTE(review): this binds ``<arg_name>_id`` to the symbol of the ``int``
        # class itself; confirm attribute access yields ``int``, not ``Type[int]``.
        model_definition.cls.info.names[arg_name + '_id'] = api.lookup_fully_qualified('builtins.int')

        referred_model_fullname = _resolve_referred_model_fullname(call)
        # Always overwrite: the callee TypeInfo is the field class itself, so a
        # value left by another model's field must not leak into this member.
        field_info.metadata['base'] = referred_model_fullname

        related_name = _get_related_name(call)
        referred_info = _lookup_model_info(api, referred_model_fullname)
        if related_name is not None and referred_info is not None:
            referred_info.defn.defs.body.append(
                plugin_api.generate_related_manager_assignment_stmt(related_name,
                                                                    model_definition.cls.fullname))

    @staticmethod
    def _process_one_to_one_field(model_definition: ClassDefContext,
                                  call: CallExpr,
                                  field_info: nodes.TypeInfo) -> None:
        """Handle a ``OneToOneField(...)`` member declared on the analysed class."""
        api = model_definition.api
        referred_model_fullname = _resolve_referred_model_fullname(call)
        related_name = _get_related_name(call)
        referred_info = _lookup_model_info(api, referred_model_fullname)
        if related_name is not None and referred_info is not None:
            own_sym = api.lookup_fully_qualified_or_none(model_definition.cls.fullname)
            if own_sym is not None:
                # NOTE(review): binds the reverse accessor to this model's class
                # symbol; confirm it is seen as an instance on attribute access.
                referred_info.names[related_name] = own_sym
        field_info.metadata['base'] = referred_model_fullname