move to custom pytest plugin test runner, fix tests, add Any fallback to ForeignKey

This commit is contained in:
Maxim Kurnikov
2018-11-28 00:37:04 +03:00
parent f59cfe6371
commit 64bc053056
14 changed files with 451 additions and 366 deletions

View File

@@ -1,3 +1,4 @@
-r external/mypy/test-requirements.txt -r external/mypy/test-requirements.txt
-e external/mypy -e external/mypy
-e . -e .
decorator

2
external/mypy vendored

View File

@@ -1,5 +1,5 @@
import typing import typing
from typing import Dict, Optional, NamedTuple, Any from typing import Dict, Optional, NamedTuple
from mypy.nodes import SymbolTableNode, Var, Expression from mypy.nodes import SymbolTableNode, Var, Expression
from mypy.plugin import FunctionContext from mypy.plugin import FunctionContext

View File

@@ -1,10 +1,6 @@
import os import os
from typing import Callable, Optional, List from typing import Callable, Optional
from django.apps.registry import Apps
from django.conf import Settings
from mypy import build
from mypy.build import BuildManager
from mypy.options import Options from mypy.options import Options
from mypy.plugin import Plugin, FunctionContext, ClassDefContext from mypy.plugin import Plugin, FunctionContext, ClassDefContext
from mypy.types import Type from mypy.types import Type
@@ -21,10 +17,6 @@ base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
class TransformModelClassHook(object): class TransformModelClassHook(object):
def __init__(self, settings: Settings, apps: Apps):
self.settings = settings
self.apps = apps
def __call__(self, ctx: ClassDefContext) -> None: def __call__(self, ctx: ClassDefContext) -> None:
base_model_classes.add(ctx.cls.fullname) base_model_classes.add(ctx.cls.fullname)
@@ -32,48 +24,27 @@ class TransformModelClassHook(object):
set_objects_queryset_to_model_class(ctx) set_objects_queryset_to_model_class(ctx)
def always_return_none(manager: BuildManager):
return None
build.read_plugins_snapshot = always_return_none
class DjangoPlugin(Plugin): class DjangoPlugin(Plugin):
def __init__(self, def __init__(self,
options: Options) -> None: options: Options) -> None:
super().__init__(options) super().__init__(options)
self.django_settings = None
self.apps = None
monkeypatch.replace_apply_function_plugin_method() monkeypatch.replace_apply_function_plugin_method()
django_settings_module = os.environ.get('DJANGO_SETTINGS_MODULE') self.django_settings = os.environ.get('DJANGO_SETTINGS_MODULE')
if django_settings_module: if self.django_settings:
self.django_settings = Settings(django_settings_module) monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(self.django_settings)
# import django monkeypatch.inject_dependencies(self.django_settings)
# django.setup() else:
# monkeypatch.restore_original_load_graph()
# from django.apps import apps monkeypatch.restore_original_dependencies_handling()
# self.apps = apps
#
# models_modules = []
# for app_config in self.apps.app_configs.values():
# models_modules.append(app_config.module.__name__ + '.' + 'models')
#
# monkeypatch.state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(django_settings_module,
# models_modules)
monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(django_settings_module)
def get_function_hook(self, fullname: str def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]: ) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == helpers.FOREIGN_KEY_FULLNAME: if fullname == helpers.FOREIGN_KEY_FULLNAME:
return ForeignKeyHook(settings=self.django_settings, return ForeignKeyHook(settings=self.django_settings)
apps=self.apps)
if fullname == helpers.ONETOONE_FIELD_FULLNAME: if fullname == helpers.ONETOONE_FIELD_FULLNAME:
return OneToOneFieldHook(settings=self.django_settings, return OneToOneFieldHook(settings=self.django_settings)
apps=self.apps)
if fullname == 'django.contrib.postgres.fields.array.ArrayField': if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field return determine_type_of_array_field
@@ -82,10 +53,10 @@ class DjangoPlugin(Plugin):
def get_base_class_hook(self, fullname: str def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]: ) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in base_model_classes: if fullname in base_model_classes:
return TransformModelClassHook(self.django_settings, self.apps) return TransformModelClassHook()
if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS: if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS:
return DjangoConfSettingsInitializerHook(settings=self.django_settings) return DjangoConfSettingsInitializerHook(settings_module=self.django_settings)
return None return None

View File

@@ -1,112 +0,0 @@
from typing import Optional, List, Sequence
from mypy.build import BuildManager, Graph, State
from mypy.modulefinder import BuildSource
from mypy.nodes import Expression, Context
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import Type, CallableType, Instance
def state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(settings_module: str,
models_py_modules: List[str]):
from mypy.build import State
old_compute_dependencies = State.compute_dependencies
def patched_compute_dependencies(self: State):
old_compute_dependencies(self)
if self.id == settings_module:
self.dependencies.extend(models_py_modules)
State.compute_dependencies = patched_compute_dependencies
def load_graph_to_add_settings_file_as_a_source_seed(settings_module: str):
from mypy import build
old_load_graph = build.load_graph
def patched_load_graph(sources: List[BuildSource], manager: BuildManager,
old_graph: Optional[Graph] = None,
new_modules: Optional[List[State]] = None):
if all([source.module != settings_module for source in sources]):
sources.append(BuildSource(None, settings_module, None))
return old_load_graph(sources=sources, manager=manager,
old_graph=old_graph,
new_modules=new_modules)
build.load_graph = patched_load_graph
def replace_apply_function_plugin_method():
def apply_function_plugin(self,
arg_types: List[Type],
inferred_ret_type: Type,
arg_names: Optional[Sequence[Optional[str]]],
formal_to_actual: List[List[int]],
args: List[Expression],
num_formals: int,
fullname: str,
object_type: Optional[Type],
context: Context) -> Type:
"""Use special case logic to infer the return type of a specific named function/method.
Caller must ensure that a plugin hook exists. There are two different cases:
- If object_type is None, the caller must ensure that a function hook exists
for fullname.
- If object_type is not None, the caller must ensure that a method hook exists
for fullname.
Return the inferred return type.
"""
formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]]
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
formal_arg_names = [None for _ in range(num_formals)] # type: List[Optional[str]]
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_types[formal].append(arg_types[actual])
formal_arg_exprs[formal].append(args[actual])
if arg_names:
formal_arg_names[formal] = arg_names[actual]
num_passed_positionals = sum([1 if name is None else 0
for name in formal_arg_names])
if arg_names and num_passed_positionals > 0:
object_type_info = None
if object_type is not None:
if isinstance(object_type, CallableType):
# class object, convert to corresponding Instance
object_type = object_type.ret_type
if isinstance(object_type, Instance):
# skip TypedDictType and others
object_type_info = object_type.type
defn_arg_names = self._get_defn_arg_names(fullname, object_type=object_type_info)
if defn_arg_names:
if num_formals < len(defn_arg_names):
# self/cls argument has been passed implicitly
defn_arg_names = defn_arg_names[1:]
formal_arg_names[:num_passed_positionals] = defn_arg_names[:num_passed_positionals]
if object_type is None:
# Apply function plugin
callback = self.plugin.get_function_hook(fullname)
assert callback is not None # Assume that caller ensures this
return callback(
FunctionContext(formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
else:
# Apply method plugin
method_callback = self.plugin.get_method_hook(fullname)
assert method_callback is not None # Assume that caller ensures this
return method_callback(
MethodContext(object_type, formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
from mypy.checkexpr import ExpressionChecker
ExpressionChecker.apply_function_plugin = apply_function_plugin

View File

@@ -0,0 +1,5 @@
from .dependencies import (load_graph_to_add_settings_file_as_a_source_seed,
inject_dependencies,
restore_original_load_graph,
restore_original_dependencies_handling)
from .contexts import replace_apply_function_plugin_method

View File

@@ -0,0 +1,276 @@
from typing import Optional, List, Sequence, NamedTuple, Tuple
from mypy import checkexpr
from mypy.argmap import map_actuals_to_formals
from mypy.checkmember import analyze_member_access
from mypy.expandtype import freshen_function_type_vars
from mypy.messages import MessageBuilder
from mypy.nodes import Expression, Context, TypeInfo, FuncDef, Decorator, RefExpr
from mypy.plugin import CheckerPluginInterface
from mypy.subtypes import is_equivalent
from mypy.types import Type, CallableType, Instance, TypeType, Overloaded, AnyType, TypeOfAny, UnionType, TypeVarType, TupleType
class PatchedExpressionChecker(checkexpr.ExpressionChecker):
def get_argnames_of_func_node(self, func_node: Context) -> Optional[List[str]]:
if isinstance(func_node, FuncDef):
return func_node.arg_names
if isinstance(func_node, Decorator):
return func_node.func.arg_names
if isinstance(func_node, TypeInfo):
# __init__ method
init_node = func_node.get_method('__init__')
assert isinstance(init_node, FuncDef)
return init_node.arg_names
return None
def get_defn_arg_names(self, fullname: str,
object_type: Optional[TypeInfo]) -> Optional[List[str]]:
if object_type is not None:
method_name = fullname.rpartition('.')[-1]
sym = object_type.get(method_name)
if (sym is None or sym.node is None
or not isinstance(sym.node, (FuncDef, Decorator, TypeInfo))):
# arg_names extraction is unsupported for sym.node
return None
return self.get_argnames_of_func_node(sym.node)
sym = self.chk.lookup_qualified(fullname)
if sym.node is None:
return None
return self.get_argnames_of_func_node(sym.node)
def check_call(self, callee: Type, args: List[Expression],
arg_kinds: List[int], context: Context,
arg_names: Optional[Sequence[Optional[str]]] = None,
callable_node: Optional[Expression] = None,
arg_messages: Optional[MessageBuilder] = None,
callable_name: Optional[str] = None,
object_type: Optional[Type] = None) -> Tuple[Type, Type]:
"""Type check a call.
Also infer type arguments if the callee is a generic function.
Return (result type, inferred callee type).
Arguments:
callee: type of the called value
args: actual argument expressions
arg_kinds: contains nodes.ARG_* constant for each argument in args
describing whether the argument is positional, *arg, etc.
arg_names: names of arguments (optional)
callable_node: associate the inferred callable type to this node,
if specified
arg_messages: TODO
callable_name: Fully-qualified name of the function/method to call,
or None if unavailable (examples: 'builtins.open', 'typing.Mapping.get')
object_type: If callable_name refers to a method, the type of the object
on which the method is being called
"""
arg_messages = arg_messages or self.msg
if isinstance(callee, CallableType):
if callable_name is None and callee.name:
callable_name = callee.name
if callee.is_type_obj() and isinstance(callee.ret_type, Instance):
callable_name = callee.ret_type.type.fullname()
if (isinstance(callable_node, RefExpr)
and callable_node.fullname in ('enum.Enum', 'enum.IntEnum',
'enum.Flag', 'enum.IntFlag')):
# An Enum() call that failed SemanticAnalyzerPass2.check_enum_call().
return callee.ret_type, callee
if (callee.is_type_obj() and callee.type_object().is_abstract
# Exception for Type[...]
and not callee.from_type_type
and not callee.type_object().fallback_to_any):
type = callee.type_object()
self.msg.cannot_instantiate_abstract_class(
callee.type_object().name(), type.abstract_attributes,
context)
elif (callee.is_type_obj() and callee.type_object().is_protocol
# Exception for Type[...]
and not callee.from_type_type):
self.chk.fail('Cannot instantiate protocol class "{}"'
.format(callee.type_object().name()), context)
formal_to_actual = map_actuals_to_formals(
arg_kinds, arg_names,
callee.arg_kinds, callee.arg_names,
lambda i: self.accept(args[i]))
if callee.is_generic():
callee = freshen_function_type_vars(callee)
callee = self.infer_function_type_arguments_using_context(
callee, context)
callee = self.infer_function_type_arguments(
callee, args, arg_kinds, formal_to_actual, context)
arg_types = self.infer_arg_types_in_context(
callee, args, arg_kinds, formal_to_actual)
self.check_argument_count(callee, arg_types, arg_kinds,
arg_names, formal_to_actual, context, self.msg)
self.check_argument_types(arg_types, arg_kinds, callee,
formal_to_actual, context,
messages=arg_messages)
if (callee.is_type_obj() and (len(arg_types) == 1)
and is_equivalent(callee.ret_type, self.named_type('builtins.type'))):
callee = callee.copy_modified(ret_type=TypeType.make_normalized(arg_types[0]))
if callable_node:
# Store the inferred callable type.
self.chk.store_type(callable_node, callee)
if (callable_name
and ((object_type is None and self.plugin.get_function_hook(callable_name))
or (object_type is not None
and self.plugin.get_method_hook(callable_name)))):
ret_type = self.apply_function_plugin(
arg_types, callee.ret_type, arg_names, formal_to_actual,
args, len(callee.arg_types), callable_name, object_type, context)
callee = callee.copy_modified(ret_type=ret_type)
return callee.ret_type, callee
elif isinstance(callee, Overloaded):
arg_types = self.infer_arg_types_in_empty_context(args)
return self.check_overload_call(callee=callee,
args=args,
arg_types=arg_types,
arg_kinds=arg_kinds,
arg_names=arg_names,
callable_name=callable_name,
object_type=object_type,
context=context,
arg_messages=arg_messages)
elif isinstance(callee, AnyType) or not self.chk.in_checked_function():
self.infer_arg_types_in_empty_context(args)
if isinstance(callee, AnyType):
return (AnyType(TypeOfAny.from_another_any, source_any=callee),
AnyType(TypeOfAny.from_another_any, source_any=callee))
else:
return AnyType(TypeOfAny.special_form), AnyType(TypeOfAny.special_form)
elif isinstance(callee, UnionType):
self.msg.disable_type_names += 1
results = [self.check_call(subtype, args, arg_kinds, context, arg_names,
arg_messages=arg_messages)
for subtype in callee.relevant_items()]
self.msg.disable_type_names -= 1
return (UnionType.make_simplified_union([res[0] for res in results]),
callee)
elif isinstance(callee, Instance):
call_function = analyze_member_access('__call__', callee, context,
False, False, False, self.named_type,
self.not_ready_callback, self.msg,
original_type=callee, chk=self.chk)
return self.check_call(call_function, args, arg_kinds, context, arg_names,
callable_node, arg_messages)
elif isinstance(callee, TypeVarType):
return self.check_call(callee.upper_bound, args, arg_kinds, context, arg_names,
callable_node, arg_messages)
elif isinstance(callee, TypeType):
# Pass the original Type[] as context since that's where errors should go.
item = self.analyze_type_type_callee(callee.item, callee)
return self.check_call(item, args, arg_kinds, context, arg_names,
callable_node, arg_messages)
elif isinstance(callee, TupleType):
return self.check_call(callee.fallback, args, arg_kinds, context,
arg_names, callable_node, arg_messages, callable_name,
object_type)
else:
return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)
def apply_function_plugin(self,
arg_types: List[Type],
inferred_ret_type: Type,
arg_names: Optional[Sequence[Optional[str]]],
formal_to_actual: List[List[int]],
args: List[Expression],
num_formals: int,
fullname: str,
object_type: Optional[Type],
context: Context) -> Type:
"""Use special case logic to infer the return type of a specific named function/method.
Caller must ensure that a plugin hook exists. There are two different cases:
- If object_type is None, the caller must ensure that a function hook exists
for fullname.
- If object_type is not None, the caller must ensure that a method hook exists
for fullname.
Return the inferred return type.
"""
from mypy.plugin import FunctionContext, MethodContext
formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]]
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
formal_arg_names = [None for _ in range(num_formals)] # type: List[Optional[str]]
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_types[formal].append(arg_types[actual])
formal_arg_exprs[formal].append(args[actual])
if arg_names:
formal_arg_names[formal] = arg_names[actual]
num_passed_positionals = sum([1 if name is None else 0
for name in formal_arg_names])
if arg_names and num_passed_positionals > 0:
object_type_info = None
if object_type is not None:
if isinstance(object_type, CallableType):
# class object, convert to corresponding Instance
object_type = object_type.ret_type
if isinstance(object_type, Instance):
# skip TypedDictType and others
object_type_info = object_type.type
defn_arg_names = self.get_defn_arg_names(fullname, object_type=object_type_info)
if defn_arg_names:
if num_formals < len(defn_arg_names):
# self/cls argument has been passed implicitly
defn_arg_names = defn_arg_names[1:]
formal_arg_names[:num_passed_positionals] = defn_arg_names[:num_passed_positionals]
if object_type is None:
# Apply function plugin
callback = self.plugin.get_function_hook(fullname)
assert callback is not None # Assume that caller ensures this
return callback(
FunctionContext(formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
else:
# Apply method plugin
method_callback = self.plugin.get_method_hook(fullname)
assert method_callback is not None # Assume that caller ensures this
return method_callback(
MethodContext(object_type, formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
def replace_apply_function_plugin_method():
from mypy import plugin
plugin.FunctionContext = NamedTuple(
'FunctionContext', [
('arg_names', Sequence[Optional[str]]), # List of actual argument names
('arg_types', List[List[Type]]), # List of actual caller types for each formal argument
('default_return_type', Type), # Return type inferred from signature
('args', List[List[Expression]]), # Actual expressions for each formal argument
('context', Context),
('api', CheckerPluginInterface)])
plugin.MethodContext = NamedTuple(
'MethodContext', [
('type', Type), # Base object type for method call
('arg_names', Sequence[Optional[str]]), # List of actual argument names
('arg_types', List[List[Type]]),
('default_return_type', Type),
('args', List[List[Expression]]),
('context', Context),
('api', CheckerPluginInterface)])
checkexpr.ExpressionChecker = PatchedExpressionChecker

View File

@@ -0,0 +1,52 @@
from typing import List, Optional
from mypy.build import BuildManager, Graph, State
from mypy.modulefinder import BuildSource
def is_module_present_in_sources(module_name: str, sources: List[BuildSource]):
return any([source.module == module_name for source in sources])
from mypy import build
old_load_graph = build.load_graph
OldState = build.State
def load_graph_to_add_settings_file_as_a_source_seed(settings_module: str):
def patched_load_graph(sources: List[BuildSource], manager: BuildManager,
old_graph: Optional[Graph] = None,
new_modules: Optional[List[State]] = None):
if not is_module_present_in_sources(settings_module, sources):
sources.append(BuildSource(None, settings_module, None))
return old_load_graph(sources=sources, manager=manager,
old_graph=old_graph,
new_modules=new_modules)
build.load_graph = patched_load_graph
def restore_original_load_graph():
from mypy import build
build.load_graph = old_load_graph
def inject_dependencies(settings_module: str):
from mypy import build
class PatchedState(build.State):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.id == 'django.conf':
self.dependencies.append(settings_module)
build.State = PatchedState
def restore_original_dependencies_handling():
from mypy import build
build.State = OldState

View File

@@ -1,126 +1,95 @@
import typing import typing
from typing import Optional, cast, Tuple, Any from typing import Optional, cast
from django.apps.registry import Apps
from django.conf import Settings from django.conf import Settings
from django.db import models
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, StrExpr from mypy.nodes import SymbolTable, MDEF, AssignmentStmt
from mypy.plugin import FunctionContext, ClassDefContext from mypy.plugin import FunctionContext, ClassDefContext
from mypy.types import Type, CallableType, Instance, AnyType from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny
from mypy_django_plugin import helpers from mypy_django_plugin import helpers
def get_instance_type_for_class(klass: typing.Type[models.Model],
api: TypeChecker) -> Optional[Instance]:
model_qualname = helpers.get_obj_type_name(klass)
module_name, _, class_name = model_qualname.rpartition('.')
module = api.modules.get(module_name)
if not module or class_name not in module.names:
return
sym = module.names[class_name]
return Instance(sym.node, [])
def extract_to_value_type(ctx: FunctionContext,
apps: Optional[Apps]) -> Tuple[Optional[Instance], bool]:
api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.arg_names:
return None, False
arg = ctx.args[ctx.arg_names.index('to')][0]
arg_type = ctx.arg_types[ctx.arg_names.index('to')][0]
if isinstance(arg_type, CallableType):
return arg_type.ret_type, False
if apps:
if isinstance(arg, StrExpr):
arg_value = arg.value
if '.' not in arg_value:
return None, False
app_label, modelname = arg_value.lower().split('.')
try:
model_cls = apps.get_model(app_label, modelname)
except LookupError:
# no model class found
return None, False
try:
instance = get_instance_type_for_class(model_cls, api=api)
if not instance:
return None, False
return instance, True
except AssertionError:
pass
return None, False
def extract_related_name_value(ctx: FunctionContext) -> str: def extract_related_name_value(ctx: FunctionContext) -> str:
return ctx.context.args[ctx.context.arg_names.index('related_name')].value return ctx.context.args[ctx.arg_names.index('related_name')].value
def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instance: Instance) -> None: def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
klass_typeinfo.names[name] = helpers.create_new_symtable_node(name, return Instance(instance.type, args=new_typevars)
kind=MDEF,
instance=new_member_instance)
def fill_typevars_with_any(instance: Instance) -> Type:
return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)])
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
if 'to' not in ctx.arg_names:
# shouldn't happen, invalid code
ctx.api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
context=ctx.context)
return None
arg_type = ctx.arg_types[ctx.arg_names.index('to')][0]
if not isinstance(arg_type, CallableType):
ctx.api.msg.warn(f'to= parameter type {arg_type.__class__.__name__} is not supported',
context=ctx.context)
return None
referred_to_type = arg_type.ret_type
for base in referred_to_type.type.bases:
if base.type.fullname() == helpers.MODEL_CLASS_FULLNAME:
break
else:
ctx.api.msg.fail(f'to= parameter value must be '
f'a subclass of {helpers.MODEL_CLASS_FULLNAME}',
context=ctx.context)
return None
return referred_to_type
class ForeignKeyHook(object): class ForeignKeyHook(object):
def __init__(self, settings: Settings, apps: Apps): def __init__(self, settings: Settings):
self.settings = settings self.settings = settings
self.apps = apps
def __call__(self, ctx: FunctionContext) -> Type: def __call__(self, ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api) api = cast(TypeChecker, ctx.api)
outer_class_info = api.tscope.classes[-1] outer_class_info = api.tscope.classes[-1]
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps) referred_to_type = get_valid_to_value_or_none(ctx)
if not referred_to: if referred_to_type is None:
return ctx.default_return_type return fill_typevars_with_any(ctx.default_return_type)
if 'related_name' in ctx.context.arg_names: if 'related_name' in ctx.arg_names:
related_name = extract_related_name_value(ctx) related_name = extract_related_name_value(ctx)
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME, queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
args=[Instance(outer_class_info, [])]) args=[Instance(outer_class_info, [])])
if isinstance(referred_to, AnyType): sym = helpers.create_new_symtable_node(related_name, MDEF,
return ctx.default_return_type instance=queryset_type)
referred_to_type.type.names[related_name] = sym
add_new_class_member(referred_to.type, return reparametrize_with(ctx.default_return_type, [referred_to_type])
related_name, queryset_type)
if is_string_based:
return referred_to
return ctx.default_return_type
class OneToOneFieldHook(object): class OneToOneFieldHook(object):
def __init__(self, settings: Optional[Settings], apps: Optional[Apps]): def __init__(self, settings: Optional[Settings]):
self.settings = settings self.settings = settings
self.apps = apps
def __call__(self, ctx: FunctionContext) -> Type: def __call__(self, ctx: FunctionContext) -> Type:
if 'related_name' not in ctx.context.arg_names: api = cast(TypeChecker, ctx.api)
return ctx.default_return_type outer_class_info = api.tscope.classes[-1]
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps) referred_to_type = get_valid_to_value_or_none(ctx)
if referred_to is None: if referred_to_type is None:
return ctx.default_return_type return fill_typevars_with_any(ctx.default_return_type)
if 'related_name' in ctx.context.arg_names: if 'related_name' in ctx.arg_names:
related_name = extract_related_name_value(ctx) related_name = extract_related_name_value(ctx)
outer_class_info = ctx.api.tscope.classes[-1] sym = helpers.create_new_symtable_node(related_name, MDEF,
add_new_class_member(referred_to.type, related_name, instance=Instance(outer_class_info, []))
new_member_instance=Instance(outer_class_info, [])) referred_to_type.type.names[related_name] = sym
if is_string_based: return reparametrize_with(ctx.default_return_type, [referred_to_type])
return referred_to
return ctx.default_return_type
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None: def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:

View File

@@ -1,32 +1,42 @@
from typing import cast from typing import Optional, Any, cast
from django.conf import Settings from mypy.nodes import Var, Context, GDEF
from mypy.nodes import MDEF from mypy.options import Options
from mypy.plugin import ClassDefContext from mypy.plugin import ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2 from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Instance, AnyType, TypeOfAny from mypy.types import Instance
from mypy_django_plugin import helpers
def add_settings_to_django_conf_object(ctx: ClassDefContext,
settings_module: str) -> Optional[Any]:
api = cast(SemanticAnalyzerPass2, ctx.api)
if settings_module not in api.modules:
return None
settings_file = api.modules[settings_module]
for name, sym in settings_file.names.items():
if name.isupper():
if not isinstance(sym.node, Var) or not isinstance(sym.type, Instance):
error_context = Context()
error_context.set_line(sym.node)
api.msg.fail("Need type annotation for '{}'".format(sym.node.name()),
context=error_context,
file=settings_file.path,
origin=Context())
continue
sym_copy = sym.copy()
sym_copy.node.info = sym_copy.type.type
sym_copy.kind = GDEF
ctx.cls.info.names[name] = sym_copy
class DjangoConfSettingsInitializerHook(object): class DjangoConfSettingsInitializerHook(object):
def __init__(self, settings: Settings): def __init__(self, settings_module: str):
self.settings = settings self.settings_module = settings_module
def __call__(self, ctx: ClassDefContext) -> None: def __call__(self, ctx: ClassDefContext) -> None:
api = cast(SemanticAnalyzerPass2, ctx.api) if not self.settings_module:
if self.settings: return
for name, value in self.settings.__dict__.items():
if name.isupper():
if value is None:
# TODO: change to Optional[Any] later
ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF,
instance=api.builtin_type('builtins.object'))
continue
type_fullname = helpers.get_obj_type_name(type(value)) add_settings_to_django_conf_object(ctx, self.settings_module)
sym = api.lookup_fully_qualified_or_none(type_fullname)
if sym is not None:
args = len(sym.node.type_vars) * [AnyType(TypeOfAny.from_omitted_generics)]
ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF,
instance=Instance(sym.node, args))

View File

@@ -132,13 +132,11 @@ def assert_string_arrays_equal(expected: List[str], actual: List[str]) -> None:
num_skip_end = _num_skipped_suffix_lines(expected, actual) num_skip_end = _num_skipped_suffix_lines(expected, actual)
error_message += 'Expected:\n' error_message += 'Expected:\n'
# sys.stderr.write('Expected:\n')
# If omit some lines at the beginning, indicate it by displaying a line # If omit some lines at the beginning, indicate it by displaying a line
# with '...'. # with '...'.
if num_skip_start > 0: if num_skip_start > 0:
error_message += ' ...\n' error_message += ' ...\n'
# sys.stderr.write(' ...\n')
# Keep track of the first different line. # Keep track of the first different line.
first_diff = -1 first_diff = -1
@@ -151,51 +149,37 @@ def assert_string_arrays_equal(expected: List[str], actual: List[str]) -> None:
if first_diff < 0: if first_diff < 0:
first_diff = i first_diff = i
error_message += ' {:<45} (diff)'.format(expected[i]) error_message += ' {:<45} (diff)'.format(expected[i])
# sys.stderr.write(' {:<45} (diff)'.format(expected[i]))
else: else:
e = expected[i] e = expected[i]
error_message += ' ' + e[:width] error_message += ' ' + e[:width]
# sys.stderr.write(' ' + e[:width])
if len(e) > width: if len(e) > width:
error_message += '...' error_message += '...'
# sys.stderr.write('...')
error_message += '\n' error_message += '\n'
# sys.stderr.write('\n')
if num_skip_end > 0: if num_skip_end > 0:
error_message += ' ...\n' error_message += ' ...\n'
# sys.stderr.write(' ...\n')
error_message += 'Actual:\n' error_message += 'Actual:\n'
# sys.stderr.write('Actual:\n')
if num_skip_start > 0: if num_skip_start > 0:
error_message += ' ...\n' error_message += ' ...\n'
# sys.stderr.write(' ...\n')
for j in range(num_skip_start, len(actual) - num_skip_end): for j in range(num_skip_start, len(actual) - num_skip_end):
if j >= len(expected) or expected[j] != actual[j]: if j >= len(expected) or expected[j] != actual[j]:
error_message += ' {:<45} (diff)'.format(actual[j]) error_message += ' {:<45} (diff)'.format(actual[j])
# sys.stderr.write(' {:<45} (diff)'.format(actual[j]))
else: else:
a = actual[j] a = actual[j]
error_message += ' ' + a[:width] error_message += ' ' + a[:width]
# sys.stderr.write(' ' + a[:width])
if len(a) > width: if len(a) > width:
error_message += '...' error_message += '...'
# sys.stderr.write('...')
error_message += '\n' error_message += '\n'
# sys.stderr.write('\n')
if actual == []: if actual == []:
error_message += ' (empty)\n' error_message += ' (empty)\n'
# sys.stderr.write(' (empty)\n')
if num_skip_end > 0: if num_skip_end > 0:
error_message += ' ...\n' error_message += ' ...\n'
# sys.stderr.write(' ...\n')
error_message += '\n' error_message += '\n'
# sys.stderr.write('\n')
if first_diff >= 0 and first_diff < len(actual) and ( if 0 <= first_diff < len(actual) and (
len(expected[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT len(expected[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT
or len(actual[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT): or len(actual[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT):
# Display message that helps visualize the differences between two # Display message that helps visualize the differences between two

View File

@@ -189,9 +189,9 @@ class MypyTypecheckItem(pytest.Item):
main_fpath.write_text(self.source_code) main_fpath.write_text(self.source_code)
mypy_cmd_options.append(str(main_fpath)) mypy_cmd_options.append(str(main_fpath))
stdout, _, _ = mypy_api.run(mypy_cmd_options) stdout, stderr, returncode = mypy_api.run(mypy_cmd_options)
output_lines = [] output_lines = []
for line in stdout.splitlines(): for line in (stdout + stderr).splitlines():
if ':' not in line: if ':' not in line:
continue continue
out_fpath, res_line = line.split(':', 1) out_fpath, res_line = line.split(':', 1)
@@ -199,15 +199,18 @@ class MypyTypecheckItem(pytest.Item):
output_lines.append(line.strip().replace('.py', '')) output_lines.append(line.strip().replace('.py', ''))
for module in test_specific_modules: for module in test_specific_modules:
if module in sys.modules: parts = module.split('.')
del sys.modules[module] for i in range(len(parts)):
raise ValueError parent_module = '.'.join(parts[:i + 1])
if parent_module in sys.modules:
del sys.modules[parent_module]
assert_string_arrays_equal(expected=self.expected_output_lines, assert_string_arrays_equal(expected=self.expected_output_lines,
actual=output_lines) actual=output_lines)
def prepare_mypy_cmd_options(self, config_file_path: Path) -> List[str]: def prepare_mypy_cmd_options(self, config_file_path: Path) -> List[str]:
mypy_cmd_options = [ mypy_cmd_options = [
'--show-traceback', '--raise-exceptions',
'--no-silence-site-packages' '--no-silence-site-packages'
] ]
python_version = '.'.join([str(part) for part in sys.version_info[:2]]) python_version = '.'.join([str(part) for part in sys.version_info[:2]])
@@ -238,7 +241,7 @@ class MypyTypecheckItem(pytest.Item):
exception_repr.reprtraceback.reprentries = [repr_tb_entry] exception_repr.reprtraceback.reprentries = [repr_tb_entry]
return exception_repr return exception_repr
else: else:
return super().repr_failure(excinfo, style='short') return super().repr_failure(excinfo, style='native')
def reportinfo(self): def reportinfo(self):
return self.fspath, None, get_class_qualname(self.klass) + '::' + self.name return self.fspath, None, get_class_qualname(self.klass) + '::' + self.name
@@ -266,7 +269,7 @@ class MypyTestsCollector(pytest.Class):
current_testcase = cast(MypyTypecheckTestCase, self.obj()) current_testcase = cast(MypyTypecheckTestCase, self.obj())
ini_file_contents = self.get_ini_file_contents(current_testcase.ini_file()) ini_file_contents = self.get_ini_file_contents(current_testcase.ini_file())
for attr_name in dir(current_testcase): for attr_name in dir(current_testcase):
if attr_name.startswith('_test_'): if attr_name.startswith('test_'):
attr = getattr(self.obj, attr_name) attr = getattr(self.obj, attr_name)
if inspect.isfunction(attr): if inspect.isfunction(attr):
first_line_lnum, source_lines = get_func_first_lnum(attr) first_line_lnum, source_lines = get_func_first_lnum(attr)

View File

@@ -9,7 +9,7 @@ class TestParseSettingsFromFile(BaseDjangoPluginTestCase):
reveal_type(settings.ROOT_DIR) # E: Revealed type is 'builtins.str' reveal_type(settings.ROOT_DIR) # E: Revealed type is 'builtins.str'
reveal_type(settings.OBJ) # E: Revealed type is 'django.utils.functional.LazyObject' reveal_type(settings.OBJ) # E: Revealed type is 'django.utils.functional.LazyObject'
reveal_type(settings.NUMBERS) # E: Revealed type is 'builtins.list[Any]' reveal_type(settings.NUMBERS) # E: Revealed type is 'builtins.list[builtins.str]'
reveal_type(settings.DICT) # E: Revealed type is 'builtins.dict[Any, Any]' reveal_type(settings.DICT) # E: Revealed type is 'builtins.dict[Any, Any]'
@file('mysettings.py') @file('mysettings.py')
@@ -34,4 +34,4 @@ class TestSettingInitializableToNone(BaseDjangoPluginTestCase):
@file('mysettings.py') @file('mysettings.py')
def mysettings_py_file(self): def mysettings_py_file(self):
SECRET_KEY = 112233 SECRET_KEY = 112233
NONE_SETTING = None NONE_SETTING: object = None

View File

@@ -1,74 +0,0 @@
from test.pytest_plugin import file, reveal_type, env
from test.pytest_tests.base import BaseDjangoPluginTestCase
class TestForeignKey(BaseDjangoPluginTestCase):
@env(DJANGO_SETTINGS_MODULE='mysettings')
def _test_to_parameter_could_be_specified_as_string(self):
from apps.myapp.models import Publisher
publisher = Publisher()
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[apps.myapp2.models.Book]'
# @env(DJANGO_SETTINGS_MODULE='mysettings')
# def _test_creates_underscore_id_attr(self):
# from apps.myapp2.models import Book
#
# book = Book()
# reveal_type(book.publisher) # E: Revealed type is 'apps.myapp.models.Publisher'
# reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
@file('mysettings.py')
def mysettings(self):
SECRET_KEY = '112233'
ROOT_DIR = '<TMP>'
APPS_DIR = '<TMP>/apps'
INSTALLED_APPS = ('apps.myapp', 'apps.myapp2')
@file('apps/myapp/models.py', make_parent_packages=True)
def apps_myapp_models(self):
from django.db import models
class Publisher(models.Model):
pass
@file('apps/myapp2/models.py', make_parent_packages=True)
def apps_myapp2_models(self):
from django.db import models
class Book(models.Model):
publisher = models.ForeignKey(to='myapp.Publisher', on_delete=models.CASCADE,
related_name='books')
class TestOneToOneField(BaseDjangoPluginTestCase):
@env(DJANGO_SETTINGS_MODULE='mysettings')
def test_to_parameter_could_be_specified_as_string(self):
from apps.myapp.models import User
user = User()
reveal_type(user.profile) # E: Revealed type is 'apps.myapp2.models.Profile'
@file('mysettings.py')
def mysettings(self):
SECRET_KEY = '112233'
ROOT_DIR = '<TMP>'
APPS_DIR = '<TMP>/apps'
INSTALLED_APPS = ('apps.myapp', 'apps.myapp2')
@file('apps/myapp/models.py', make_parent_packages=True)
def apps_myapp_models(self):
from django.db import models
class User(models.Model):
pass
@file('apps/myapp2/models.py', make_parent_packages=True)
def apps_myapp2_models(self):
from django.db import models
class Profile(models.Model):
user = models.OneToOneField(to='myapp.User', on_delete=models.CASCADE,
related_name='profile')