Files
django-stubs/mypy_django_plugin/plugins/related_fields.py

116 lines
4.4 KiB
Python

import typing
from typing import Optional, cast
from django.conf import Settings
from mypy.checker import TypeChecker
from mypy.nodes import SymbolTable, MDEF, AssignmentStmt
from mypy.plugin import FunctionContext, ClassDefContext
from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny
from mypy_django_plugin import helpers
def extract_related_name_value(ctx: FunctionContext) -> str:
return ctx.context.args[ctx.arg_names.index('related_name')].value
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
return Instance(instance.type, args=new_typevars)
def fill_typevars_with_any(instance: Instance) -> Type:
return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)])
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
if 'to' not in ctx.arg_names:
# shouldn't happen, invalid code
ctx.api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
context=ctx.context)
return None
arg_type = ctx.arg_types[ctx.arg_names.index('to')][0]
if not isinstance(arg_type, CallableType):
ctx.api.msg.warn(f'to= parameter type {arg_type.__class__.__name__} is not supported',
context=ctx.context)
return None
referred_to_type = arg_type.ret_type
for base in referred_to_type.type.bases:
if base.type.fullname() == helpers.MODEL_CLASS_FULLNAME:
break
else:
ctx.api.msg.fail(f'to= parameter value must be '
f'a subclass of {helpers.MODEL_CLASS_FULLNAME}',
context=ctx.context)
return None
return referred_to_type
class ForeignKeyHook(object):
def __init__(self, settings: Settings):
self.settings = settings
def __call__(self, ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
outer_class_info = api.tscope.classes[-1]
referred_to_type = get_valid_to_value_or_none(ctx)
if referred_to_type is None:
return fill_typevars_with_any(ctx.default_return_type)
if 'related_name' in ctx.arg_names:
related_name = extract_related_name_value(ctx)
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
args=[Instance(outer_class_info, [])])
sym = helpers.create_new_symtable_node(related_name, MDEF,
instance=queryset_type)
referred_to_type.type.names[related_name] = sym
return reparametrize_with(ctx.default_return_type, [referred_to_type])
class OneToOneFieldHook(object):
def __init__(self, settings: Optional[Settings]):
self.settings = settings
def __call__(self, ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
outer_class_info = api.tscope.classes[-1]
referred_to_type = get_valid_to_value_or_none(ctx)
if referred_to_type is None:
return fill_typevars_with_any(ctx.default_return_type)
if 'related_name' in ctx.arg_names:
related_name = extract_related_name_value(ctx)
sym = helpers.create_new_symtable_node(related_name, MDEF,
instance=Instance(outer_class_info, []))
referred_to_type.type.names[related_name] = sym
return reparametrize_with(ctx.default_return_type, [referred_to_type])
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
api = ctx.api
new_symtable_nodes = SymbolTable()
for (name, symtable_node), stmt in zip(ctx.cls.info.names.items(), ctx.cls.defs.body):
if not isinstance(stmt, AssignmentStmt):
continue
if not hasattr(stmt.rvalue, 'callee'):
continue
rvalue_callee = stmt.rvalue.callee
if rvalue_callee.fullname in {helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME}:
name += '_id'
new_node = helpers.create_new_symtable_node(name,
kind=MDEF,
instance=api.named_type('__builtins__.int'))
new_symtable_nodes[name] = new_node
for name, node in new_symtable_nodes.items():
ctx.cls.info.names[name] = node