mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-08 21:14:49 +08:00
latest changes
This commit is contained in:
@@ -1,23 +1,63 @@
|
||||
from typing import Optional, cast
|
||||
import typing
|
||||
from typing import Optional, cast, Tuple, Any
|
||||
|
||||
from django.apps.registry import Apps
|
||||
from django.conf import Settings
|
||||
from django.db import models
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, MemberExpr
|
||||
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, StrExpr
|
||||
from mypy.plugin import FunctionContext, ClassDefContext
|
||||
from mypy.types import Type, CallableType, Instance, AnyType
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def extract_to_value_type(ctx: FunctionContext) -> Optional[Instance]:
|
||||
signature = helpers.get_call_signature_or_none(ctx)
|
||||
if signature is None or 'to' not in signature:
|
||||
return None
|
||||
def get_instance_type_for_class(klass: typing.Type[models.Model],
|
||||
api: TypeChecker) -> Optional[Instance]:
|
||||
model_qualname = helpers.get_obj_type_name(klass)
|
||||
module_name, _, class_name = model_qualname.rpartition('.')
|
||||
module = api.modules.get(module_name)
|
||||
if not module or class_name not in module.names:
|
||||
return
|
||||
|
||||
arg, arg_type = signature['to']
|
||||
if not isinstance(arg_type, CallableType):
|
||||
return None
|
||||
sym = module.names[class_name]
|
||||
return Instance(sym.node, [])
|
||||
|
||||
return arg_type.ret_type
|
||||
|
||||
def extract_to_value_type(ctx: FunctionContext,
|
||||
apps: Optional[Apps]) -> Tuple[Optional[Instance], bool]:
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
|
||||
if 'to' not in ctx.arg_names:
|
||||
return None, False
|
||||
arg = ctx.args[ctx.arg_names.index('to')][0]
|
||||
arg_type = ctx.arg_types[ctx.arg_names.index('to')][0]
|
||||
|
||||
if isinstance(arg_type, CallableType):
|
||||
return arg_type.ret_type, False
|
||||
|
||||
if apps:
|
||||
if isinstance(arg, StrExpr):
|
||||
arg_value = arg.value
|
||||
if '.' not in arg_value:
|
||||
return None, False
|
||||
|
||||
app_label, modelname = arg_value.lower().split('.')
|
||||
try:
|
||||
model_cls = apps.get_model(app_label, modelname)
|
||||
except LookupError:
|
||||
# no model class found
|
||||
return None, False
|
||||
try:
|
||||
instance = get_instance_type_for_class(model_cls, api=api)
|
||||
if not instance:
|
||||
return None, False
|
||||
return instance, True
|
||||
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
return None, False
|
||||
|
||||
|
||||
def extract_related_name_value(ctx: FunctionContext) -> str:
|
||||
@@ -30,45 +70,58 @@ def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instanc
|
||||
instance=new_member_instance)
|
||||
|
||||
|
||||
def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
outer_class_info = api.tscope.classes[-1]
|
||||
class ForeignKeyHook(object):
|
||||
def __init__(self, settings: Settings, apps: Apps):
|
||||
self.settings = settings
|
||||
self.apps = apps
|
||||
|
||||
def __call__(self, ctx: FunctionContext) -> Type:
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
outer_class_info = api.tscope.classes[-1]
|
||||
|
||||
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps)
|
||||
if not referred_to:
|
||||
return ctx.default_return_type
|
||||
|
||||
if 'related_name' in ctx.context.arg_names:
|
||||
related_name = extract_related_name_value(ctx)
|
||||
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
|
||||
args=[Instance(outer_class_info, [])])
|
||||
if isinstance(referred_to, AnyType):
|
||||
return ctx.default_return_type
|
||||
|
||||
add_new_class_member(referred_to.type,
|
||||
related_name, queryset_type)
|
||||
if is_string_based:
|
||||
return referred_to
|
||||
|
||||
if 'related_name' not in ctx.context.arg_names:
|
||||
return ctx.default_return_type
|
||||
|
||||
referred_to = extract_to_value_type(ctx)
|
||||
if not referred_to:
|
||||
|
||||
class OneToOneFieldHook(object):
|
||||
def __init__(self, settings: Optional[Settings], apps: Optional[Apps]):
|
||||
self.settings = settings
|
||||
self.apps = apps
|
||||
|
||||
def __call__(self, ctx: FunctionContext) -> Type:
|
||||
if 'related_name' not in ctx.context.arg_names:
|
||||
return ctx.default_return_type
|
||||
|
||||
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps)
|
||||
if referred_to is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
if 'related_name' in ctx.context.arg_names:
|
||||
related_name = extract_related_name_value(ctx)
|
||||
outer_class_info = ctx.api.tscope.classes[-1]
|
||||
add_new_class_member(referred_to.type, related_name,
|
||||
new_member_instance=Instance(outer_class_info, []))
|
||||
|
||||
if is_string_based:
|
||||
return referred_to
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
related_name = extract_related_name_value(ctx)
|
||||
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
|
||||
args=[Instance(outer_class_info, [])])
|
||||
if isinstance(referred_to, AnyType):
|
||||
# referred_to defined as string, which is unsupported for now
|
||||
return ctx.default_return_type
|
||||
|
||||
add_new_class_member(referred_to.type,
|
||||
related_name, queryset_type)
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def set_related_name_instance_for_onetoonefield(ctx: FunctionContext) -> Type:
|
||||
if 'related_name' not in ctx.context.arg_names:
|
||||
return ctx.default_return_type
|
||||
|
||||
referred_to = extract_to_value_type(ctx)
|
||||
if referred_to is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
related_name = extract_related_name_value(ctx)
|
||||
outer_class_info = ctx.api.tscope.classes[-1]
|
||||
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
add_new_class_member(referred_to.type, related_name,
|
||||
new_member_instance=api.named_type(outer_class_info.fullname()))
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
|
||||
api = ctx.api
|
||||
|
||||
Reference in New Issue
Block a user