add properly typed FOREIGN_KEY_FIELD_NAME_id fields to models

This commit is contained in:
Maxim Kurnikov
2019-02-13 21:05:02 +03:00
parent 82de0a8791
commit 26a80a8279
5 changed files with 153 additions and 61 deletions

View File

@@ -4,8 +4,8 @@ from typing import Callable, Dict, Optional, cast
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo
from mypy.options import Options
from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import Instance, Type, TypeType
from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin, AttributeContext
from mypy.types import Instance, Type, TypeType, AnyType, TypeOfAny
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.config import Config
@@ -85,6 +85,27 @@ def return_user_model_hook(ctx: FunctionContext) -> Type:
return TypeType(Instance(model_info, []))
def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: AttributeContext) -> Type:
if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'):
return ctx.default_attr_type
if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(helpers.MODEL_CLASS_FULLNAME):
return ctx.default_attr_type
field_name = ctx.context.name.split('_')[0]
sym = ctx.type.type.get(field_name)
if sym and isinstance(sym.type, Instance) and len(sym.type.args) > 0:
to_arg = sym.type.args[0]
if isinstance(to_arg, AnyType):
return AnyType(TypeOfAny.special_form)
model_type: TypeInfo = to_arg.type
primary_key_type = helpers.extract_primary_key_type_for_get(model_type)
if primary_key_type:
return primary_key_type
return ctx.default_attr_type
class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None:
super().__init__(options)
@@ -186,6 +207,14 @@ class DjangoPlugin(Plugin):
return None
def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
# sym = self.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME)
# if sym and isinstance(sym.node, TypeInfo):
# if fullname.rpartition('.')[-1] in helpers.get_related_field_primary_key_names(sym.node):
return extract_and_return_primary_key_of_bound_related_field_parameter
def plugin(version):
return DjangoPlugin