move to custom pytest plugin test runner, fix tests, add Any fallback to ForeignKey

This commit is contained in:
Maxim Kurnikov
2018-11-28 00:37:04 +03:00
parent f59cfe6371
commit 64bc053056
14 changed files with 451 additions and 366 deletions

View File

@@ -1,5 +1,5 @@
import typing
from typing import Dict, Optional, NamedTuple, Any
from typing import Dict, Optional, NamedTuple
from mypy.nodes import SymbolTableNode, Var, Expression
from mypy.plugin import FunctionContext

View File

@@ -1,10 +1,6 @@
import os
from typing import Callable, Optional, List
from typing import Callable, Optional
from django.apps.registry import Apps
from django.conf import Settings
from mypy import build
from mypy.build import BuildManager
from mypy.options import Options
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
from mypy.types import Type
@@ -21,10 +17,6 @@ base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
class TransformModelClassHook(object):
def __init__(self, settings: Settings, apps: Apps):
self.settings = settings
self.apps = apps
def __call__(self, ctx: ClassDefContext) -> None:
base_model_classes.add(ctx.cls.fullname)
@@ -32,48 +24,27 @@ class TransformModelClassHook(object):
set_objects_queryset_to_model_class(ctx)
def always_return_none(manager: BuildManager):
return None
build.read_plugins_snapshot = always_return_none
class DjangoPlugin(Plugin):
def __init__(self,
options: Options) -> None:
super().__init__(options)
self.django_settings = None
self.apps = None
monkeypatch.replace_apply_function_plugin_method()
django_settings_module = os.environ.get('DJANGO_SETTINGS_MODULE')
if django_settings_module:
self.django_settings = Settings(django_settings_module)
# import django
# django.setup()
#
# from django.apps import apps
# self.apps = apps
#
# models_modules = []
# for app_config in self.apps.app_configs.values():
# models_modules.append(app_config.module.__name__ + '.' + 'models')
#
# monkeypatch.state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(django_settings_module,
# models_modules)
monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(django_settings_module)
self.django_settings = os.environ.get('DJANGO_SETTINGS_MODULE')
if self.django_settings:
monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(self.django_settings)
monkeypatch.inject_dependencies(self.django_settings)
else:
monkeypatch.restore_original_load_graph()
monkeypatch.restore_original_dependencies_handling()
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == helpers.FOREIGN_KEY_FULLNAME:
return ForeignKeyHook(settings=self.django_settings,
apps=self.apps)
return ForeignKeyHook(settings=self.django_settings)
if fullname == helpers.ONETOONE_FIELD_FULLNAME:
return OneToOneFieldHook(settings=self.django_settings,
apps=self.apps)
return OneToOneFieldHook(settings=self.django_settings)
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field
@@ -82,10 +53,10 @@ class DjangoPlugin(Plugin):
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in base_model_classes:
return TransformModelClassHook(self.django_settings, self.apps)
return TransformModelClassHook()
if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS:
return DjangoConfSettingsInitializerHook(settings=self.django_settings)
return DjangoConfSettingsInitializerHook(settings_module=self.django_settings)
return None

View File

@@ -1,112 +0,0 @@
from typing import Optional, List, Sequence
from mypy.build import BuildManager, Graph, State
from mypy.modulefinder import BuildSource
from mypy.nodes import Expression, Context
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import Type, CallableType, Instance
def state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(settings_module: str,
models_py_modules: List[str]):
from mypy.build import State
old_compute_dependencies = State.compute_dependencies
def patched_compute_dependencies(self: State):
old_compute_dependencies(self)
if self.id == settings_module:
self.dependencies.extend(models_py_modules)
State.compute_dependencies = patched_compute_dependencies
def load_graph_to_add_settings_file_as_a_source_seed(settings_module: str):
from mypy import build
old_load_graph = build.load_graph
def patched_load_graph(sources: List[BuildSource], manager: BuildManager,
old_graph: Optional[Graph] = None,
new_modules: Optional[List[State]] = None):
if all([source.module != settings_module for source in sources]):
sources.append(BuildSource(None, settings_module, None))
return old_load_graph(sources=sources, manager=manager,
old_graph=old_graph,
new_modules=new_modules)
build.load_graph = patched_load_graph
def replace_apply_function_plugin_method():
def apply_function_plugin(self,
arg_types: List[Type],
inferred_ret_type: Type,
arg_names: Optional[Sequence[Optional[str]]],
formal_to_actual: List[List[int]],
args: List[Expression],
num_formals: int,
fullname: str,
object_type: Optional[Type],
context: Context) -> Type:
"""Use special case logic to infer the return type of a specific named function/method.
Caller must ensure that a plugin hook exists. There are two different cases:
- If object_type is None, the caller must ensure that a function hook exists
for fullname.
- If object_type is not None, the caller must ensure that a method hook exists
for fullname.
Return the inferred return type.
"""
formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]]
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
formal_arg_names = [None for _ in range(num_formals)] # type: List[Optional[str]]
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_types[formal].append(arg_types[actual])
formal_arg_exprs[formal].append(args[actual])
if arg_names:
formal_arg_names[formal] = arg_names[actual]
num_passed_positionals = sum([1 if name is None else 0
for name in formal_arg_names])
if arg_names and num_passed_positionals > 0:
object_type_info = None
if object_type is not None:
if isinstance(object_type, CallableType):
# class object, convert to corresponding Instance
object_type = object_type.ret_type
if isinstance(object_type, Instance):
# skip TypedDictType and others
object_type_info = object_type.type
defn_arg_names = self._get_defn_arg_names(fullname, object_type=object_type_info)
if defn_arg_names:
if num_formals < len(defn_arg_names):
# self/cls argument has been passed implicitly
defn_arg_names = defn_arg_names[1:]
formal_arg_names[:num_passed_positionals] = defn_arg_names[:num_passed_positionals]
if object_type is None:
# Apply function plugin
callback = self.plugin.get_function_hook(fullname)
assert callback is not None # Assume that caller ensures this
return callback(
FunctionContext(formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
else:
# Apply method plugin
method_callback = self.plugin.get_method_hook(fullname)
assert method_callback is not None # Assume that caller ensures this
return method_callback(
MethodContext(object_type, formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
from mypy.checkexpr import ExpressionChecker
ExpressionChecker.apply_function_plugin = apply_function_plugin

View File

@@ -0,0 +1,5 @@
from .dependencies import (load_graph_to_add_settings_file_as_a_source_seed,
inject_dependencies,
restore_original_load_graph,
restore_original_dependencies_handling)
from .contexts import replace_apply_function_plugin_method

View File

@@ -0,0 +1,276 @@
from typing import Optional, List, Sequence, NamedTuple, Tuple
from mypy import checkexpr
from mypy.argmap import map_actuals_to_formals
from mypy.checkmember import analyze_member_access
from mypy.expandtype import freshen_function_type_vars
from mypy.messages import MessageBuilder
from mypy.nodes import Expression, Context, TypeInfo, FuncDef, Decorator, RefExpr
from mypy.plugin import CheckerPluginInterface
from mypy.subtypes import is_equivalent
from mypy.types import Type, CallableType, Instance, TypeType, Overloaded, AnyType, TypeOfAny, UnionType, TypeVarType, TupleType
class PatchedExpressionChecker(checkexpr.ExpressionChecker):
def get_argnames_of_func_node(self, func_node: Context) -> Optional[List[str]]:
if isinstance(func_node, FuncDef):
return func_node.arg_names
if isinstance(func_node, Decorator):
return func_node.func.arg_names
if isinstance(func_node, TypeInfo):
# __init__ method
init_node = func_node.get_method('__init__')
assert isinstance(init_node, FuncDef)
return init_node.arg_names
return None
def get_defn_arg_names(self, fullname: str,
object_type: Optional[TypeInfo]) -> Optional[List[str]]:
if object_type is not None:
method_name = fullname.rpartition('.')[-1]
sym = object_type.get(method_name)
if (sym is None or sym.node is None
or not isinstance(sym.node, (FuncDef, Decorator, TypeInfo))):
# arg_names extraction is unsupported for sym.node
return None
return self.get_argnames_of_func_node(sym.node)
sym = self.chk.lookup_qualified(fullname)
if sym.node is None:
return None
return self.get_argnames_of_func_node(sym.node)
def check_call(self, callee: Type, args: List[Expression],
arg_kinds: List[int], context: Context,
arg_names: Optional[Sequence[Optional[str]]] = None,
callable_node: Optional[Expression] = None,
arg_messages: Optional[MessageBuilder] = None,
callable_name: Optional[str] = None,
object_type: Optional[Type] = None) -> Tuple[Type, Type]:
"""Type check a call.
Also infer type arguments if the callee is a generic function.
Return (result type, inferred callee type).
Arguments:
callee: type of the called value
args: actual argument expressions
arg_kinds: contains nodes.ARG_* constant for each argument in args
describing whether the argument is positional, *arg, etc.
arg_names: names of arguments (optional)
callable_node: associate the inferred callable type to this node,
if specified
arg_messages: TODO
callable_name: Fully-qualified name of the function/method to call,
or None if unavailable (examples: 'builtins.open', 'typing.Mapping.get')
object_type: If callable_name refers to a method, the type of the object
on which the method is being called
"""
arg_messages = arg_messages or self.msg
if isinstance(callee, CallableType):
if callable_name is None and callee.name:
callable_name = callee.name
if callee.is_type_obj() and isinstance(callee.ret_type, Instance):
callable_name = callee.ret_type.type.fullname()
if (isinstance(callable_node, RefExpr)
and callable_node.fullname in ('enum.Enum', 'enum.IntEnum',
'enum.Flag', 'enum.IntFlag')):
# An Enum() call that failed SemanticAnalyzerPass2.check_enum_call().
return callee.ret_type, callee
if (callee.is_type_obj() and callee.type_object().is_abstract
# Exception for Type[...]
and not callee.from_type_type
and not callee.type_object().fallback_to_any):
type = callee.type_object()
self.msg.cannot_instantiate_abstract_class(
callee.type_object().name(), type.abstract_attributes,
context)
elif (callee.is_type_obj() and callee.type_object().is_protocol
# Exception for Type[...]
and not callee.from_type_type):
self.chk.fail('Cannot instantiate protocol class "{}"'
.format(callee.type_object().name()), context)
formal_to_actual = map_actuals_to_formals(
arg_kinds, arg_names,
callee.arg_kinds, callee.arg_names,
lambda i: self.accept(args[i]))
if callee.is_generic():
callee = freshen_function_type_vars(callee)
callee = self.infer_function_type_arguments_using_context(
callee, context)
callee = self.infer_function_type_arguments(
callee, args, arg_kinds, formal_to_actual, context)
arg_types = self.infer_arg_types_in_context(
callee, args, arg_kinds, formal_to_actual)
self.check_argument_count(callee, arg_types, arg_kinds,
arg_names, formal_to_actual, context, self.msg)
self.check_argument_types(arg_types, arg_kinds, callee,
formal_to_actual, context,
messages=arg_messages)
if (callee.is_type_obj() and (len(arg_types) == 1)
and is_equivalent(callee.ret_type, self.named_type('builtins.type'))):
callee = callee.copy_modified(ret_type=TypeType.make_normalized(arg_types[0]))
if callable_node:
# Store the inferred callable type.
self.chk.store_type(callable_node, callee)
if (callable_name
and ((object_type is None and self.plugin.get_function_hook(callable_name))
or (object_type is not None
and self.plugin.get_method_hook(callable_name)))):
ret_type = self.apply_function_plugin(
arg_types, callee.ret_type, arg_names, formal_to_actual,
args, len(callee.arg_types), callable_name, object_type, context)
callee = callee.copy_modified(ret_type=ret_type)
return callee.ret_type, callee
elif isinstance(callee, Overloaded):
arg_types = self.infer_arg_types_in_empty_context(args)
return self.check_overload_call(callee=callee,
args=args,
arg_types=arg_types,
arg_kinds=arg_kinds,
arg_names=arg_names,
callable_name=callable_name,
object_type=object_type,
context=context,
arg_messages=arg_messages)
elif isinstance(callee, AnyType) or not self.chk.in_checked_function():
self.infer_arg_types_in_empty_context(args)
if isinstance(callee, AnyType):
return (AnyType(TypeOfAny.from_another_any, source_any=callee),
AnyType(TypeOfAny.from_another_any, source_any=callee))
else:
return AnyType(TypeOfAny.special_form), AnyType(TypeOfAny.special_form)
elif isinstance(callee, UnionType):
self.msg.disable_type_names += 1
results = [self.check_call(subtype, args, arg_kinds, context, arg_names,
arg_messages=arg_messages)
for subtype in callee.relevant_items()]
self.msg.disable_type_names -= 1
return (UnionType.make_simplified_union([res[0] for res in results]),
callee)
elif isinstance(callee, Instance):
call_function = analyze_member_access('__call__', callee, context,
False, False, False, self.named_type,
self.not_ready_callback, self.msg,
original_type=callee, chk=self.chk)
return self.check_call(call_function, args, arg_kinds, context, arg_names,
callable_node, arg_messages)
elif isinstance(callee, TypeVarType):
return self.check_call(callee.upper_bound, args, arg_kinds, context, arg_names,
callable_node, arg_messages)
elif isinstance(callee, TypeType):
# Pass the original Type[] as context since that's where errors should go.
item = self.analyze_type_type_callee(callee.item, callee)
return self.check_call(item, args, arg_kinds, context, arg_names,
callable_node, arg_messages)
elif isinstance(callee, TupleType):
return self.check_call(callee.fallback, args, arg_kinds, context,
arg_names, callable_node, arg_messages, callable_name,
object_type)
else:
return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)
def apply_function_plugin(self,
arg_types: List[Type],
inferred_ret_type: Type,
arg_names: Optional[Sequence[Optional[str]]],
formal_to_actual: List[List[int]],
args: List[Expression],
num_formals: int,
fullname: str,
object_type: Optional[Type],
context: Context) -> Type:
"""Use special case logic to infer the return type of a specific named function/method.
Caller must ensure that a plugin hook exists. There are two different cases:
- If object_type is None, the caller must ensure that a function hook exists
for fullname.
- If object_type is not None, the caller must ensure that a method hook exists
for fullname.
Return the inferred return type.
"""
from mypy.plugin import FunctionContext, MethodContext
formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]]
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
formal_arg_names = [None for _ in range(num_formals)] # type: List[Optional[str]]
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_types[formal].append(arg_types[actual])
formal_arg_exprs[formal].append(args[actual])
if arg_names:
formal_arg_names[formal] = arg_names[actual]
num_passed_positionals = sum([1 if name is None else 0
for name in formal_arg_names])
if arg_names and num_passed_positionals > 0:
object_type_info = None
if object_type is not None:
if isinstance(object_type, CallableType):
# class object, convert to corresponding Instance
object_type = object_type.ret_type
if isinstance(object_type, Instance):
# skip TypedDictType and others
object_type_info = object_type.type
defn_arg_names = self.get_defn_arg_names(fullname, object_type=object_type_info)
if defn_arg_names:
if num_formals < len(defn_arg_names):
# self/cls argument has been passed implicitly
defn_arg_names = defn_arg_names[1:]
formal_arg_names[:num_passed_positionals] = defn_arg_names[:num_passed_positionals]
if object_type is None:
# Apply function plugin
callback = self.plugin.get_function_hook(fullname)
assert callback is not None # Assume that caller ensures this
return callback(
FunctionContext(formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
else:
# Apply method plugin
method_callback = self.plugin.get_method_hook(fullname)
assert method_callback is not None # Assume that caller ensures this
return method_callback(
MethodContext(object_type, formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
def replace_apply_function_plugin_method():
from mypy import plugin
plugin.FunctionContext = NamedTuple(
'FunctionContext', [
('arg_names', Sequence[Optional[str]]), # List of actual argument names
('arg_types', List[List[Type]]), # List of actual caller types for each formal argument
('default_return_type', Type), # Return type inferred from signature
('args', List[List[Expression]]), # Actual expressions for each formal argument
('context', Context),
('api', CheckerPluginInterface)])
plugin.MethodContext = NamedTuple(
'MethodContext', [
('type', Type), # Base object type for method call
('arg_names', Sequence[Optional[str]]), # List of actual argument names
('arg_types', List[List[Type]]),
('default_return_type', Type),
('args', List[List[Expression]]),
('context', Context),
('api', CheckerPluginInterface)])
checkexpr.ExpressionChecker = PatchedExpressionChecker

View File

@@ -0,0 +1,52 @@
from typing import List, Optional
from mypy.build import BuildManager, Graph, State
from mypy.modulefinder import BuildSource
def is_module_present_in_sources(module_name: str, sources: List[BuildSource]):
return any([source.module == module_name for source in sources])
from mypy import build
old_load_graph = build.load_graph
OldState = build.State
def load_graph_to_add_settings_file_as_a_source_seed(settings_module: str):
def patched_load_graph(sources: List[BuildSource], manager: BuildManager,
old_graph: Optional[Graph] = None,
new_modules: Optional[List[State]] = None):
if not is_module_present_in_sources(settings_module, sources):
sources.append(BuildSource(None, settings_module, None))
return old_load_graph(sources=sources, manager=manager,
old_graph=old_graph,
new_modules=new_modules)
build.load_graph = patched_load_graph
def restore_original_load_graph():
from mypy import build
build.load_graph = old_load_graph
def inject_dependencies(settings_module: str):
from mypy import build
class PatchedState(build.State):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.id == 'django.conf':
self.dependencies.append(settings_module)
build.State = PatchedState
def restore_original_dependencies_handling():
from mypy import build
build.State = OldState

View File

@@ -1,126 +1,95 @@
import typing
from typing import Optional, cast, Tuple, Any
from typing import Optional, cast
from django.apps.registry import Apps
from django.conf import Settings
from django.db import models
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, StrExpr
from mypy.nodes import SymbolTable, MDEF, AssignmentStmt
from mypy.plugin import FunctionContext, ClassDefContext
from mypy.types import Type, CallableType, Instance, AnyType
from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny
from mypy_django_plugin import helpers
def get_instance_type_for_class(klass: typing.Type[models.Model],
api: TypeChecker) -> Optional[Instance]:
model_qualname = helpers.get_obj_type_name(klass)
module_name, _, class_name = model_qualname.rpartition('.')
module = api.modules.get(module_name)
if not module or class_name not in module.names:
return
sym = module.names[class_name]
return Instance(sym.node, [])
def extract_to_value_type(ctx: FunctionContext,
apps: Optional[Apps]) -> Tuple[Optional[Instance], bool]:
api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.arg_names:
return None, False
arg = ctx.args[ctx.arg_names.index('to')][0]
arg_type = ctx.arg_types[ctx.arg_names.index('to')][0]
if isinstance(arg_type, CallableType):
return arg_type.ret_type, False
if apps:
if isinstance(arg, StrExpr):
arg_value = arg.value
if '.' not in arg_value:
return None, False
app_label, modelname = arg_value.lower().split('.')
try:
model_cls = apps.get_model(app_label, modelname)
except LookupError:
# no model class found
return None, False
try:
instance = get_instance_type_for_class(model_cls, api=api)
if not instance:
return None, False
return instance, True
except AssertionError:
pass
return None, False
def extract_related_name_value(ctx: FunctionContext) -> str:
return ctx.context.args[ctx.context.arg_names.index('related_name')].value
return ctx.context.args[ctx.arg_names.index('related_name')].value
def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instance: Instance) -> None:
klass_typeinfo.names[name] = helpers.create_new_symtable_node(name,
kind=MDEF,
instance=new_member_instance)
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
return Instance(instance.type, args=new_typevars)
def fill_typevars_with_any(instance: Instance) -> Type:
return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)])
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
if 'to' not in ctx.arg_names:
# shouldn't happen, invalid code
ctx.api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
context=ctx.context)
return None
arg_type = ctx.arg_types[ctx.arg_names.index('to')][0]
if not isinstance(arg_type, CallableType):
ctx.api.msg.warn(f'to= parameter type {arg_type.__class__.__name__} is not supported',
context=ctx.context)
return None
referred_to_type = arg_type.ret_type
for base in referred_to_type.type.bases:
if base.type.fullname() == helpers.MODEL_CLASS_FULLNAME:
break
else:
ctx.api.msg.fail(f'to= parameter value must be '
f'a subclass of {helpers.MODEL_CLASS_FULLNAME}',
context=ctx.context)
return None
return referred_to_type
class ForeignKeyHook(object):
def __init__(self, settings: Settings, apps: Apps):
def __init__(self, settings: Settings):
self.settings = settings
self.apps = apps
def __call__(self, ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
outer_class_info = api.tscope.classes[-1]
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps)
if not referred_to:
return ctx.default_return_type
referred_to_type = get_valid_to_value_or_none(ctx)
if referred_to_type is None:
return fill_typevars_with_any(ctx.default_return_type)
if 'related_name' in ctx.context.arg_names:
if 'related_name' in ctx.arg_names:
related_name = extract_related_name_value(ctx)
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
args=[Instance(outer_class_info, [])])
if isinstance(referred_to, AnyType):
return ctx.default_return_type
sym = helpers.create_new_symtable_node(related_name, MDEF,
instance=queryset_type)
referred_to_type.type.names[related_name] = sym
add_new_class_member(referred_to.type,
related_name, queryset_type)
if is_string_based:
return referred_to
return ctx.default_return_type
return reparametrize_with(ctx.default_return_type, [referred_to_type])
class OneToOneFieldHook(object):
def __init__(self, settings: Optional[Settings], apps: Optional[Apps]):
def __init__(self, settings: Optional[Settings]):
self.settings = settings
self.apps = apps
def __call__(self, ctx: FunctionContext) -> Type:
if 'related_name' not in ctx.context.arg_names:
return ctx.default_return_type
api = cast(TypeChecker, ctx.api)
outer_class_info = api.tscope.classes[-1]
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps)
if referred_to is None:
return ctx.default_return_type
referred_to_type = get_valid_to_value_or_none(ctx)
if referred_to_type is None:
return fill_typevars_with_any(ctx.default_return_type)
if 'related_name' in ctx.context.arg_names:
if 'related_name' in ctx.arg_names:
related_name = extract_related_name_value(ctx)
outer_class_info = ctx.api.tscope.classes[-1]
add_new_class_member(referred_to.type, related_name,
new_member_instance=Instance(outer_class_info, []))
sym = helpers.create_new_symtable_node(related_name, MDEF,
instance=Instance(outer_class_info, []))
referred_to_type.type.names[related_name] = sym
if is_string_based:
return referred_to
return ctx.default_return_type
return reparametrize_with(ctx.default_return_type, [referred_to_type])
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:

View File

@@ -1,32 +1,42 @@
from typing import cast
from typing import Optional, Any, cast
from django.conf import Settings
from mypy.nodes import MDEF
from mypy.nodes import Var, Context, GDEF
from mypy.options import Options
from mypy.plugin import ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Instance, AnyType, TypeOfAny
from mypy.types import Instance
from mypy_django_plugin import helpers
def add_settings_to_django_conf_object(ctx: ClassDefContext,
settings_module: str) -> Optional[Any]:
api = cast(SemanticAnalyzerPass2, ctx.api)
if settings_module not in api.modules:
return None
settings_file = api.modules[settings_module]
for name, sym in settings_file.names.items():
if name.isupper():
if not isinstance(sym.node, Var) or not isinstance(sym.type, Instance):
error_context = Context()
error_context.set_line(sym.node)
api.msg.fail("Need type annotation for '{}'".format(sym.node.name()),
context=error_context,
file=settings_file.path,
origin=Context())
continue
sym_copy = sym.copy()
sym_copy.node.info = sym_copy.type.type
sym_copy.kind = GDEF
ctx.cls.info.names[name] = sym_copy
class DjangoConfSettingsInitializerHook(object):
def __init__(self, settings: Settings):
self.settings = settings
def __init__(self, settings_module: str):
self.settings_module = settings_module
def __call__(self, ctx: ClassDefContext) -> None:
api = cast(SemanticAnalyzerPass2, ctx.api)
if self.settings:
for name, value in self.settings.__dict__.items():
if name.isupper():
if value is None:
# TODO: change to Optional[Any] later
ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF,
instance=api.builtin_type('builtins.object'))
continue
if not self.settings_module:
return
type_fullname = helpers.get_obj_type_name(type(value))
sym = api.lookup_fully_qualified_or_none(type_fullname)
if sym is not None:
args = len(sym.node.type_vars) * [AnyType(TypeOfAny.from_omitted_generics)]
ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF,
instance=Instance(sym.node, args))
add_settings_to_django_conf_object(ctx, self.settings_module)