add support for FIELDNAME_id for foreignkey/onetoonefield

This commit is contained in:
Maxim Kurnikov
2018-11-13 18:36:46 +03:00
parent dc4f606f63
commit 17be428776
2 changed files with 76 additions and 18 deletions

View File

@@ -1,12 +1,16 @@
from typing import Optional, Callable, cast
from mypy.checker import TypeChecker
from mypy.nodes import Var, MDEF, SymbolTableNode
from mypy.plugin import Plugin, FunctionContext
from mypy.nodes import Var, MDEF, SymbolTableNode, TypeInfo, SymbolTable
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Type, CallableType, Instance
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
def extract_to_value_type(ctx: FunctionContext) -> Optional[Type]:
def extract_to_value_type(ctx: FunctionContext) -> Optional[Instance]:
assert 'to' in ctx.context.arg_names
to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0]
if not isinstance(to_arg_value, CallableType):
@@ -19,7 +23,22 @@ def extract_related_name_value(ctx: FunctionContext) -> str:
return ctx.context.args[ctx.context.arg_names.index('related_name')].value
def create_new_symtable_node_for_class_member(name: str, instance: Instance) -> SymbolTableNode:
new_var = Var(name, instance)
new_var.info = instance.type
return SymbolTableNode(MDEF, new_var, plugin_generated=True)
def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instance: Instance) -> None:
klass_typeinfo.names[name] = create_new_symtable_node_for_class_member(name,
instance=new_member_instance)
def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
outer_class_info = api.tscope.classes[-1]
if 'related_name' not in ctx.context.arg_names:
return ctx.default_return_type
@@ -28,15 +47,10 @@ def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
return ctx.default_return_type
related_name = extract_related_name_value(ctx)
outer_class_info = ctx.api.tscope.classes[-1]
queryset_type = ctx.api.named_generic_type('django.db.models.QuerySet',
args=[Instance(outer_class_info, [])])
related_var = Var(related_name,
queryset_type)
related_var.info = queryset_type.type
referred_to.type.names[related_name] = SymbolTableNode(MDEF, related_var,
plugin_generated=True)
queryset_type = api.named_generic_type('django.db.models.QuerySet',
args=[Instance(outer_class_info, [])])
add_new_class_member(referred_to.type,
related_name, queryset_type)
return ctx.default_return_type
@@ -52,15 +66,27 @@ def set_related_name_instance_for_onetoonefield(ctx: FunctionContext) -> Type:
outer_class_info = ctx.api.tscope.classes[-1]
api = cast(TypeChecker, ctx.api)
related_instance_type = api.named_type(outer_class_info.fullname())
related_var = Var(related_name, related_instance_type)
related_var.info = related_instance_type.type
referred_to.type.names[related_name] = SymbolTableNode(MDEF, related_var,
plugin_generated=True)
add_new_class_member(referred_to.type, related_name,
new_member_instance=api.named_type(outer_class_info.fullname()))
return ctx.default_return_type
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
api = ctx.api
new_symtable_nodes = SymbolTable()
for (name, symtable_node), assignment_stmt in zip(ctx.cls.info.names.items(), ctx.cls.defs.body):
rvalue_callee = assignment_stmt.rvalue.callee
if rvalue_callee.fullname in {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME}:
name += '_id'
new_node = create_new_symtable_node_for_class_member(name,
instance=api.named_type('__builtins__.int'))
new_symtable_nodes[name] = new_node
for name, node in new_symtable_nodes.items():
ctx.cls.info.names[name] = node
class RelatedFieldsPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
@@ -72,6 +98,12 @@ class RelatedFieldsPlugin(Plugin):
return None
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname == 'django.db.models.base.Model':
return set_fieldname_attrs_for_related_fields
return None
def plugin(version):
return RelatedFieldsPlugin

View File

@@ -15,6 +15,20 @@ publisher = Publisher()
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
[out]
[case testEveryForeignKeyCreatesFieldNameWithIdAttribute]
from django.db import models
class Publisher(models.Model):
pass
class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE,
related_name='books')
book = Book()
reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
[out]
[case testOneToOneField]
from django.db import models
@@ -30,3 +44,15 @@ reveal_type(profile.user) # E: Revealed type is 'main.User*'
user = User()
reveal_type(user.profile) # E: Revealed type is 'main.Profile'
[case testOneToOneFieldAttrWithUnderscoreID]
from django.db import models
class User(models.Model):
pass
class Profile(models.Model):
user = models.OneToOneField(to=User, on_delete=models.CASCADE, related_name='profile')
profile = Profile()
reveal_type(profile.user_id) # E: Revealed type is 'builtins.int'
[out]