mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-07 12:44:29 +08:00
add support for FIELDNAME_id for foreignkey/onetoonefield
This commit is contained in:
@@ -1,12 +1,16 @@
|
|||||||
from typing import Optional, Callable, cast
|
from typing import Optional, Callable, cast
|
||||||
|
|
||||||
from mypy.checker import TypeChecker
|
from mypy.checker import TypeChecker
|
||||||
from mypy.nodes import Var, MDEF, SymbolTableNode
|
from mypy.nodes import Var, MDEF, SymbolTableNode, TypeInfo, SymbolTable
|
||||||
from mypy.plugin import Plugin, FunctionContext
|
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
|
||||||
|
from mypy.semanal import SemanticAnalyzerPass2
|
||||||
from mypy.types import Type, CallableType, Instance
|
from mypy.types import Type, CallableType, Instance
|
||||||
|
|
||||||
|
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
|
||||||
|
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
|
||||||
|
|
||||||
def extract_to_value_type(ctx: FunctionContext) -> Optional[Type]:
|
|
||||||
|
def extract_to_value_type(ctx: FunctionContext) -> Optional[Instance]:
|
||||||
assert 'to' in ctx.context.arg_names
|
assert 'to' in ctx.context.arg_names
|
||||||
to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0]
|
to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0]
|
||||||
if not isinstance(to_arg_value, CallableType):
|
if not isinstance(to_arg_value, CallableType):
|
||||||
@@ -19,7 +23,22 @@ def extract_related_name_value(ctx: FunctionContext) -> str:
|
|||||||
return ctx.context.args[ctx.context.arg_names.index('related_name')].value
|
return ctx.context.args[ctx.context.arg_names.index('related_name')].value
|
||||||
|
|
||||||
|
|
||||||
|
def create_new_symtable_node_for_class_member(name: str, instance: Instance) -> SymbolTableNode:
|
||||||
|
new_var = Var(name, instance)
|
||||||
|
new_var.info = instance.type
|
||||||
|
|
||||||
|
return SymbolTableNode(MDEF, new_var, plugin_generated=True)
|
||||||
|
|
||||||
|
|
||||||
|
def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instance: Instance) -> None:
|
||||||
|
klass_typeinfo.names[name] = create_new_symtable_node_for_class_member(name,
|
||||||
|
instance=new_member_instance)
|
||||||
|
|
||||||
|
|
||||||
def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
|
def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
|
||||||
|
api = cast(TypeChecker, ctx.api)
|
||||||
|
outer_class_info = api.tscope.classes[-1]
|
||||||
|
|
||||||
if 'related_name' not in ctx.context.arg_names:
|
if 'related_name' not in ctx.context.arg_names:
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
@@ -28,15 +47,10 @@ def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
|
|||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
related_name = extract_related_name_value(ctx)
|
related_name = extract_related_name_value(ctx)
|
||||||
outer_class_info = ctx.api.tscope.classes[-1]
|
queryset_type = api.named_generic_type('django.db.models.QuerySet',
|
||||||
|
args=[Instance(outer_class_info, [])])
|
||||||
queryset_type = ctx.api.named_generic_type('django.db.models.QuerySet',
|
add_new_class_member(referred_to.type,
|
||||||
args=[Instance(outer_class_info, [])])
|
related_name, queryset_type)
|
||||||
related_var = Var(related_name,
|
|
||||||
queryset_type)
|
|
||||||
related_var.info = queryset_type.type
|
|
||||||
referred_to.type.names[related_name] = SymbolTableNode(MDEF, related_var,
|
|
||||||
plugin_generated=True)
|
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
|
||||||
@@ -52,15 +66,27 @@ def set_related_name_instance_for_onetoonefield(ctx: FunctionContext) -> Type:
|
|||||||
outer_class_info = ctx.api.tscope.classes[-1]
|
outer_class_info = ctx.api.tscope.classes[-1]
|
||||||
|
|
||||||
api = cast(TypeChecker, ctx.api)
|
api = cast(TypeChecker, ctx.api)
|
||||||
related_instance_type = api.named_type(outer_class_info.fullname())
|
add_new_class_member(referred_to.type, related_name,
|
||||||
related_var = Var(related_name, related_instance_type)
|
new_member_instance=api.named_type(outer_class_info.fullname()))
|
||||||
related_var.info = related_instance_type.type
|
|
||||||
|
|
||||||
referred_to.type.names[related_name] = SymbolTableNode(MDEF, related_var,
|
|
||||||
plugin_generated=True)
|
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
|
||||||
|
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
|
||||||
|
api = ctx.api
|
||||||
|
|
||||||
|
new_symtable_nodes = SymbolTable()
|
||||||
|
for (name, symtable_node), assignment_stmt in zip(ctx.cls.info.names.items(), ctx.cls.defs.body):
|
||||||
|
rvalue_callee = assignment_stmt.rvalue.callee
|
||||||
|
if rvalue_callee.fullname in {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME}:
|
||||||
|
name += '_id'
|
||||||
|
new_node = create_new_symtable_node_for_class_member(name,
|
||||||
|
instance=api.named_type('__builtins__.int'))
|
||||||
|
new_symtable_nodes[name] = new_node
|
||||||
|
|
||||||
|
for name, node in new_symtable_nodes.items():
|
||||||
|
ctx.cls.info.names[name] = node
|
||||||
|
|
||||||
|
|
||||||
class RelatedFieldsPlugin(Plugin):
|
class RelatedFieldsPlugin(Plugin):
|
||||||
def get_function_hook(self, fullname: str
|
def get_function_hook(self, fullname: str
|
||||||
) -> Optional[Callable[[FunctionContext], Type]]:
|
) -> Optional[Callable[[FunctionContext], Type]]:
|
||||||
@@ -72,6 +98,12 @@ class RelatedFieldsPlugin(Plugin):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_base_class_hook(self, fullname: str
|
||||||
|
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||||
|
if fullname == 'django.db.models.base.Model':
|
||||||
|
return set_fieldname_attrs_for_related_fields
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def plugin(version):
|
def plugin(version):
|
||||||
return RelatedFieldsPlugin
|
return RelatedFieldsPlugin
|
||||||
|
|||||||
@@ -15,6 +15,20 @@ publisher = Publisher()
|
|||||||
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
|
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
|
||||||
[out]
|
[out]
|
||||||
|
|
||||||
|
[case testEveryForeignKeyCreatesFieldNameWithIdAttribute]
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class Publisher(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Book(models.Model):
|
||||||
|
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE,
|
||||||
|
related_name='books')
|
||||||
|
|
||||||
|
book = Book()
|
||||||
|
reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
|
||||||
|
[out]
|
||||||
|
|
||||||
[case testOneToOneField]
|
[case testOneToOneField]
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
@@ -30,3 +44,15 @@ reveal_type(profile.user) # E: Revealed type is 'main.User*'
|
|||||||
user = User()
|
user = User()
|
||||||
reveal_type(user.profile) # E: Revealed type is 'main.Profile'
|
reveal_type(user.profile) # E: Revealed type is 'main.Profile'
|
||||||
|
|
||||||
|
[case testOneToOneFieldAttrWithUnderscoreID]
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class User(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Profile(models.Model):
|
||||||
|
user = models.OneToOneField(to=User, on_delete=models.CASCADE, related_name='profile')
|
||||||
|
|
||||||
|
profile = Profile()
|
||||||
|
reveal_type(profile.user_id) # E: Revealed type is 'builtins.int'
|
||||||
|
[out]
|
||||||
Reference in New Issue
Block a user