# Mirror of https://github.com/davidhalter/django-stubs.git
# (synced 2025-12-08 04:54:48 +08:00; 116 lines, 4.4 KiB, Python)
import typing
from typing import Optional, cast

from django.conf import Settings
from mypy.checker import TypeChecker
from mypy.nodes import SymbolTable, MDEF, AssignmentStmt, CallExpr, NameExpr
from mypy.plugin import FunctionContext, ClassDefContext
from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny

from mypy_django_plugin import helpers
|
|
|
|
|
|
def extract_related_name_value(ctx: FunctionContext) -> str:
    """Return the string passed as ``related_name=`` to a related-field call.

    In this plugin ``ctx.arg_names`` is used as the callee's *formal*
    parameter names (the same index is used with the per-formal
    ``ctx.arg_types`` elsewhere). ``ctx.args`` is likewise per-formal: each
    entry is the list of actual expressions bound to that parameter. The
    formal index must therefore select from ``ctx.args``, not from the call
    expression's own ``args``, which are in call-site order.

    Callers must ensure ``related_name`` was actually passed at the call site.
    The value is assumed to be a string literal expression exposing ``.value``.
    """
    related_name_exprs = ctx.args[ctx.arg_names.index('related_name')]
    return related_name_exprs[0].value
|
|
|
|
|
|
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]) -> Instance:
    """Return a new ``Instance`` of the same class with different type arguments.

    :param instance: instance type whose ``TypeInfo`` is reused.
    :param new_typevars: replacement type arguments, in declaration order.
    :return: a fresh ``Instance``; ``instance`` itself is not modified.
    """
    # Carry the original position over so diagnostics against the new type
    # still point at the same source location.
    return Instance(instance.type, args=new_typevars,
                    line=instance.line, column=instance.column)
|
|
|
|
|
|
def fill_typevars_with_any(instance: Instance) -> Type:
    """Return ``instance`` re-parameterized with ``Any`` for each type variable.

    Used as the fallback result when the concrete type argument of a field
    cannot be determined.
    """
    # One Any per declared type variable, so the argument count always
    # matches the generic's arity.
    any_args = [AnyType(TypeOfAny.unannotated) for _ in instance.type.type_vars]
    return reparametrize_with(instance, any_args)
|
|
|
|
|
|
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
    """Resolve the model class passed as ``to=`` to a related-field call.

    :param ctx: mypy function-hook context for the field constructor call.
    :return: the ``Instance`` of the referenced model, or ``None`` when it
        cannot be determined. In some of the ``None`` cases an error or
        warning has been reported against the call.
    """
    if 'to' not in ctx.arg_names:
        # The callee signature declares no ``to`` parameter at all.
        ctx.api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
                         context=ctx.context)
        return None

    to_arg_types = ctx.arg_types[ctx.arg_names.index('to')]
    if not to_arg_types:
        # ``to`` is declared but was not passed at this call site. mypy's own
        # call checking is expected to report the missing argument, so avoid
        # a duplicate error and just fall back.
        return None

    arg_type = to_arg_types[0]
    if not isinstance(arg_type, CallableType):
        ctx.api.msg.warn(f'to= parameter type {arg_type.__class__.__name__} is not supported',
                         context=ctx.context)
        return None

    referred_to_type = arg_type.ret_type
    if not isinstance(referred_to_type, Instance):
        # E.g. the class resolved to Any; there is no TypeInfo to inspect.
        return None

    # has_base() walks the whole MRO, so models deriving from Model
    # indirectly (through an abstract base or mixin) are accepted as well.
    if not referred_to_type.type.has_base(helpers.MODEL_CLASS_FULLNAME):
        ctx.api.msg.fail(f'to= parameter value must be '
                         f'a subclass of {helpers.MODEL_CLASS_FULLNAME}',
                         context=ctx.context)
        return None

    return referred_to_type
|
|
|
|
|
|
class ForeignKeyHook(object):
    """mypy function hook for ``ForeignKey(...)`` calls.

    Parameterizes the returned field type with the referenced model. When the
    call passes ``related_name=``, it also adds that name to the referenced
    model, typed as ``QuerySet[<class declaring the field>]``.
    """

    def __init__(self, settings: Optional[Settings]):
        # Stored for parity with the other hooks; not read by this hook.
        self.settings = settings

    def __call__(self, ctx: FunctionContext) -> Type:
        api = cast(TypeChecker, ctx.api)

        referred_to_type = get_valid_to_value_or_none(ctx)
        if referred_to_type is None:
            # Unknown target model: degrade to the field parameterized with Any.
            return fill_typevars_with_any(ctx.default_return_type)

        # ``ctx.arg_names`` are the callee's formal parameter names, so the
        # membership test only says the signature has ``related_name``.
        # ``ctx.args`` holds the expressions actually passed for it.
        related_name_passed = ('related_name' in ctx.arg_names
                               and bool(ctx.args[ctx.arg_names.index('related_name')]))
        # The reverse accessor's item type is the innermost class being
        # checked. Outside any class there is nothing to attach it to.
        if related_name_passed and api.tscope.classes:
            related_name = extract_related_name_value(ctx)
            # Django: a related_name of '+' (or ending with '+') means
            # "no reverse relation", so no attribute must be created.
            if not related_name.endswith('+'):
                outer_class_info = api.tscope.classes[-1]
                queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
                                                       args=[Instance(outer_class_info, [])])
                sym = helpers.create_new_symtable_node(related_name, MDEF,
                                                       instance=queryset_type)
                referred_to_type.type.names[related_name] = sym

        return reparametrize_with(ctx.default_return_type, [referred_to_type])
|
|
|
|
|
|
class OneToOneFieldHook(object):
    """mypy function hook for ``OneToOneField(...)`` calls.

    Parameterizes the returned field type with the referenced model. When the
    call passes ``related_name=``, it also adds that name to the referenced
    model, typed as an instance of the class declaring the field.
    """

    def __init__(self, settings: Optional[Settings]):
        # Stored for parity with the other hooks; not read by this hook.
        self.settings = settings

    def __call__(self, ctx: FunctionContext) -> Type:
        api = cast(TypeChecker, ctx.api)

        referred_to_type = get_valid_to_value_or_none(ctx)
        if referred_to_type is None:
            # Unknown target model: degrade to the field parameterized with Any.
            return fill_typevars_with_any(ctx.default_return_type)

        # ``ctx.arg_names`` are the callee's formal parameter names, so the
        # membership test only says the signature has ``related_name``.
        # ``ctx.args`` holds the expressions actually passed for it.
        related_name_passed = ('related_name' in ctx.arg_names
                               and bool(ctx.args[ctx.arg_names.index('related_name')]))
        # The reverse accessor's type is the innermost class being checked.
        # Outside any class there is nothing to attach it to.
        if related_name_passed and api.tscope.classes:
            related_name = extract_related_name_value(ctx)
            # Django: a related_name of '+' (or ending with '+') means
            # "no reverse relation", so no attribute must be created.
            if not related_name.endswith('+'):
                outer_class_info = api.tscope.classes[-1]
                sym = helpers.create_new_symtable_node(related_name, MDEF,
                                                       instance=Instance(outer_class_info, []))
                referred_to_type.type.names[related_name] = sym

        return reparametrize_with(ctx.default_return_type, [referred_to_type])
|
|
|
|
|
|
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
    """Add ``<field>_id: int`` attributes for ForeignKey/OneToOneField fields.

    For each class-body statement of the form ``name = ForeignKey(...)`` or
    ``name = OneToOneField(...)``, an attribute ``name_id`` typed ``int`` is
    added to the class's symbol table. All other assignments are left as-is.
    """
    api = ctx.api
    related_field_fullnames = {helpers.FOREIGN_KEY_FULLNAME,
                               helpers.ONETOONE_FIELD_FULLNAME}

    new_symtable_nodes = SymbolTable()
    # Walk the body statements themselves and take each field name from the
    # assignment target. The symbol table's entries do not line up one-to-one
    # with body statements (methods, docstrings, Meta, ...).
    for stmt in ctx.cls.defs.body:
        if not isinstance(stmt, AssignmentStmt):
            continue
        if not isinstance(stmt.rvalue, CallExpr):
            continue
        if len(stmt.lvalues) != 1 or not isinstance(stmt.lvalues[0], NameExpr):
            continue

        callee_fullname = getattr(stmt.rvalue.callee, 'fullname', None)
        if callee_fullname not in related_field_fullnames:
            # Only related fields get an ``_id`` companion. Other fields must
            # keep their own declared type.
            continue

        id_attr_name = stmt.lvalues[0].name + '_id'
        new_symtable_nodes[id_attr_name] = helpers.create_new_symtable_node(
            id_attr_name,
            kind=MDEF,
            instance=api.named_type('__builtins__.int'))

    # Applied after the scan so the class's symbol table is not mutated
    # while the class body is being inspected.
    for name, node in new_symtable_nodes.items():
        ctx.cls.info.names[name] = node
|