move to plugin common api, move to new functioncontext api

This commit is contained in:
Maxim Kurnikov
2018-12-19 02:41:28 +03:00
parent c9ad40d7e3
commit 094b8421ab
8 changed files with 46 additions and 308 deletions

View File

@@ -1,11 +1,11 @@
import os
from typing import Callable, Optional, cast
from typing import Callable, Optional, cast, Dict
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo
from mypy.options import Options
from mypy.plugin import Plugin, FunctionContext, ClassDefContext, AnalyzeTypeContext
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
from mypy.types import Type, Instance
from mypy.typevars import fill_typevars
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
@@ -14,17 +14,17 @@ from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_ge
from mypy_django_plugin.plugins.settings import DjangoConfSettingsInitializerHook
base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
manager_subclasses = set()
def transform_model_class(ctx: ClassDefContext) -> None:
base_model_classes.add(ctx.cls.fullname)
sym = ctx.api.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
sym.node.metadata['django']['model_bases'][ctx.cls.fullname] = 1
process_model_class(ctx)
def add_new_manager_subclass(ctx: ClassDefContext) -> None:
manager_subclasses.add(ctx.cls.fullname)
def transform_manager_class(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(helpers.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
sym.node.metadata['django']['manager_bases'][ctx.cls.fullname] = 1
def determine_proper_manager_type(ctx: FunctionContext) -> Type:
@@ -52,7 +52,6 @@ class DjangoPlugin(Plugin):
def __init__(self,
options: Options) -> None:
super().__init__(options)
monkeypatch.replace_apply_function_plugin_method()
monkeypatch.make_inner_classes_with_inherit_from_any_compatible_with_each_other()
self.django_settings = os.environ.get('DJANGO_SETTINGS_MODULE')
@@ -63,6 +62,28 @@ class DjangoPlugin(Plugin):
monkeypatch.restore_original_load_graph()
monkeypatch.restore_original_dependencies_handling()
def get_current_model_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
if 'django' not in model_sym.node.metadata:
model_sym.node.metadata['django'] = {
'model_bases': {helpers.MODEL_CLASS_FULLNAME: 1}
}
return model_sym.node.metadata['django']['model_bases']
else:
return {}
def get_current_manager_bases(self) -> Dict[str, int]:
manager_sym = self.lookup_fully_qualified(helpers.MANAGER_CLASS_FULLNAME)
if manager_sym is not None and isinstance(manager_sym.node, TypeInfo):
if 'django' not in manager_sym.node.metadata:
manager_sym.node.metadata['django'] = {
'manager_bases': {helpers.MANAGER_CLASS_FULLNAME: 1}
}
return manager_sym.node.metadata['django']['manager_bases']
else:
return {}
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname in {helpers.FOREIGN_KEY_FULLNAME,
@@ -73,20 +94,22 @@ class DjangoPlugin(Plugin):
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field
if fullname in manager_subclasses:
manager_bases = self.get_current_manager_bases()
if fullname in manager_bases:
return determine_proper_manager_type
return None
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in base_model_classes:
if fullname in self.get_current_model_bases():
return transform_model_class
if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS:
return DjangoConfSettingsInitializerHook(settings_module=self.django_settings)
if fullname in helpers.MANAGER_CLASSES:
return add_new_manager_subclass
if fullname in self.get_current_manager_bases():
return transform_manager_class
return None

View File

@@ -3,5 +3,4 @@ from .dependencies import (load_graph_to_add_settings_file_as_a_source_seed,
restore_original_load_graph,
restore_original_dependencies_handling,
process_settings_before_dependants)
from .contexts import replace_apply_function_plugin_method
from .multiple_inheritance import make_inner_classes_with_inherit_from_any_compatible_with_each_other

View File

@@ -1,275 +0,0 @@
from typing import Optional, List, Sequence, NamedTuple, Tuple
from mypy import checkexpr
from mypy.checkexpr import map_actuals_to_formals
from mypy.checkmember import analyze_member_access
from mypy.expandtype import freshen_function_type_vars
from mypy.messages import MessageBuilder
from mypy.nodes import Expression, Context, TypeInfo, FuncDef, Decorator, RefExpr
from mypy.plugin import CheckerPluginInterface
from mypy.subtypes import is_equivalent
from mypy.types import Type, CallableType, Instance, TypeType, Overloaded, AnyType, TypeOfAny, UnionType, TypeVarType, TupleType
class PatchedExpressionChecker(checkexpr.ExpressionChecker):
def get_argnames_of_func_node(self, func_node: Context) -> Optional[List[str]]:
if isinstance(func_node, FuncDef):
return func_node.arg_names
if isinstance(func_node, Decorator):
return func_node.func.arg_names
if isinstance(func_node, TypeInfo):
# __init__ method
init_node = func_node.get_method('__init__')
assert isinstance(init_node, FuncDef)
return init_node.arg_names
return None
def get_defn_arg_names(self, fullname: str,
object_type: Optional[TypeInfo]) -> Optional[List[str]]:
if object_type is not None:
method_name = fullname.rpartition('.')[-1]
sym = object_type.get(method_name)
if (sym is None or sym.node is None
or not isinstance(sym.node, (FuncDef, Decorator, TypeInfo))):
# arg_names extraction is unsupported for sym.node
return None
return self.get_argnames_of_func_node(sym.node)
sym = self.chk.lookup_qualified(fullname)
if sym.node is None:
return None
return self.get_argnames_of_func_node(sym.node)
def check_call(self, callee: Type, args: List[Expression],
arg_kinds: List[int], context: Context,
arg_names: Optional[Sequence[Optional[str]]] = None,
callable_node: Optional[Expression] = None,
arg_messages: Optional[MessageBuilder] = None,
callable_name: Optional[str] = None,
object_type: Optional[Type] = None) -> Tuple[Type, Type]:
"""Type check a call.
Also infer type arguments if the callee is a generic function.
Return (result type, inferred callee type).
Arguments:
callee: type of the called value
args: actual argument expressions
arg_kinds: contains nodes.ARG_* constant for each argument in args
describing whether the argument is positional, *arg, etc.
arg_names: names of arguments (optional)
callable_node: associate the inferred callable type to this node,
if specified
arg_messages: TODO
callable_name: Fully-qualified name of the function/method to call,
or None if unavailable (examples: 'builtins.open', 'typing.Mapping.get')
object_type: If callable_name refers to a method, the type of the object
on which the method is being called
"""
arg_messages = arg_messages or self.msg
if isinstance(callee, CallableType):
if callable_name is None and callee.name:
callable_name = callee.name
if callee.is_type_obj() and isinstance(callee.ret_type, Instance):
callable_name = callee.ret_type.type.fullname()
if (isinstance(callable_node, RefExpr)
and callable_node.fullname in ('enum.Enum', 'enum.IntEnum',
'enum.Flag', 'enum.IntFlag')):
# An Enum() call that failed SemanticAnalyzerPass2.check_enum_call().
return callee.ret_type, callee
if (callee.is_type_obj() and callee.type_object().is_abstract
# Exception for Type[...]
and not callee.from_type_type
and not callee.type_object().fallback_to_any):
type = callee.type_object()
self.msg.cannot_instantiate_abstract_class(
callee.type_object().name(), type.abstract_attributes,
context)
elif (callee.is_type_obj() and callee.type_object().is_protocol
# Exception for Type[...]
and not callee.from_type_type):
self.chk.fail('Cannot instantiate protocol class "{}"'
.format(callee.type_object().name()), context)
formal_to_actual = map_actuals_to_formals(
arg_kinds, arg_names,
callee.arg_kinds, callee.arg_names,
lambda i: self.accept(args[i]))
if callee.is_generic():
callee = freshen_function_type_vars(callee)
callee = self.infer_function_type_arguments_using_context(
callee, context)
callee = self.infer_function_type_arguments(
callee, args, arg_kinds, formal_to_actual, context)
arg_types = self.infer_arg_types_in_context(
callee, args, arg_kinds, formal_to_actual)
self.check_argument_count(callee, arg_types, arg_kinds,
arg_names, formal_to_actual, context, self.msg)
self.check_argument_types(arg_types, arg_kinds, callee,
formal_to_actual, context,
messages=arg_messages)
if (callee.is_type_obj() and (len(arg_types) == 1)
and is_equivalent(callee.ret_type, self.named_type('builtins.type'))):
callee = callee.copy_modified(ret_type=TypeType.make_normalized(arg_types[0]))
if callable_node:
# Store the inferred callable type.
self.chk.store_type(callable_node, callee)
if (callable_name
and ((object_type is None and self.plugin.get_function_hook(callable_name))
or (object_type is not None
and self.plugin.get_method_hook(callable_name)))):
ret_type = self.apply_function_plugin(
arg_types, callee.ret_type, arg_names, formal_to_actual,
args, len(callee.arg_types), callable_name, object_type, context)
callee = callee.copy_modified(ret_type=ret_type)
return callee.ret_type, callee
elif isinstance(callee, Overloaded):
arg_types = self.infer_arg_types_in_empty_context(args)
return self.check_overload_call(callee=callee,
args=args,
arg_types=arg_types,
arg_kinds=arg_kinds,
arg_names=arg_names,
callable_name=callable_name,
object_type=object_type,
context=context,
arg_messages=arg_messages)
elif isinstance(callee, AnyType) or not self.chk.in_checked_function():
self.infer_arg_types_in_empty_context(args)
if isinstance(callee, AnyType):
return (AnyType(TypeOfAny.from_another_any, source_any=callee),
AnyType(TypeOfAny.from_another_any, source_any=callee))
else:
return AnyType(TypeOfAny.special_form), AnyType(TypeOfAny.special_form)
elif isinstance(callee, UnionType):
self.msg.disable_type_names += 1
results = [self.check_call(subtype, args, arg_kinds, context, arg_names,
arg_messages=arg_messages)
for subtype in callee.relevant_items()]
self.msg.disable_type_names -= 1
return (UnionType.make_simplified_union([res[0] for res in results]),
callee)
elif isinstance(callee, Instance):
call_function = analyze_member_access('__call__', callee, context,
False, False, False, self.named_type,
self.not_ready_callback, self.msg,
original_type=callee, chk=self.chk)
return self.check_call(call_function, args, arg_kinds, context, arg_names,
callable_node, arg_messages)
elif isinstance(callee, TypeVarType):
return self.check_call(callee.upper_bound, args, arg_kinds, context, arg_names,
callable_node, arg_messages)
elif isinstance(callee, TypeType):
# Pass the original Type[] as context since that's where errors should go.
item = self.analyze_type_type_callee(callee.item, callee)
return self.check_call(item, args, arg_kinds, context, arg_names,
callable_node, arg_messages)
elif isinstance(callee, TupleType):
return self.check_call(callee.fallback, args, arg_kinds, context,
arg_names, callable_node, arg_messages, callable_name,
object_type)
else:
return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)
def apply_function_plugin(self,
arg_types: List[Type],
inferred_ret_type: Type,
arg_names: Optional[Sequence[Optional[str]]],
formal_to_actual: List[List[int]],
args: List[Expression],
num_formals: int,
fullname: str,
object_type: Optional[Type],
context: Context) -> Type:
"""Use special case logic to infer the return type of a specific named function/method.
Caller must ensure that a plugin hook exists. There are two different cases:
- If object_type is None, the caller must ensure that a function hook exists
for fullname.
- If object_type is not None, the caller must ensure that a method hook exists
for fullname.
Return the inferred return type.
"""
from mypy.plugin import FunctionContext, MethodContext
formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]]
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
formal_arg_names = [None for _ in range(num_formals)] # type: List[Optional[str]]
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_types[formal].append(arg_types[actual])
formal_arg_exprs[formal].append(args[actual])
if arg_names:
formal_arg_names[formal] = arg_names[actual]
num_passed_positionals = sum([1 if name is None else 0
for name in formal_arg_names])
if arg_names and num_passed_positionals > 0:
object_type_info = None
if object_type is not None:
if isinstance(object_type, CallableType):
# class object, convert to corresponding Instance
object_type = object_type.ret_type
if isinstance(object_type, Instance):
# skip TypedDictType and others
object_type_info = object_type.type
defn_arg_names = self.get_defn_arg_names(fullname, object_type=object_type_info)
if defn_arg_names:
if num_formals < len(defn_arg_names):
# self/cls argument has been passed implicitly
defn_arg_names = defn_arg_names[1:]
formal_arg_names[:num_passed_positionals] = defn_arg_names[:num_passed_positionals]
if object_type is None:
# Apply function plugin
callback = self.plugin.get_function_hook(fullname)
assert callback is not None # Assume that caller ensures this
return callback(
FunctionContext(formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
else:
# Apply method plugin
method_callback = self.plugin.get_method_hook(fullname)
assert method_callback is not None # Assume that caller ensures this
return method_callback(
MethodContext(object_type, formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
def replace_apply_function_plugin_method():
from mypy import plugin
plugin.FunctionContext = NamedTuple(
'FunctionContext', [
('arg_names', Sequence[Optional[str]]), # List of actual argument names
('arg_types', List[List[Type]]), # List of actual caller types for each formal argument
('default_return_type', Type), # Return type inferred from signature
('args', List[List[Expression]]), # Actual expressions for each formal argument
('context', Context),
('api', CheckerPluginInterface)])
plugin.MethodContext = NamedTuple(
'MethodContext', [
('type', Type), # Base object type for method call
('arg_names', Sequence[Optional[str]]), # List of actual argument names
('arg_types', List[List[Type]]),
('default_return_type', Type),
('args', List[List[Expression]]),
('context', Context),
('api', CheckerPluginInterface)])
checkexpr.ExpressionChecker = PatchedExpressionChecker

View File

@@ -3,9 +3,9 @@ from mypy.types import Type
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
if 'base_field' not in ctx.arg_names:
if 'base_field' not in ctx.callee_arg_names:
return ctx.default_return_type
base_field_arg_type = ctx.arg_types[ctx.arg_names.index('base_field')][0]
base_field_arg_type = ctx.arg_types[ctx.callee_arg_names.index('base_field')][0]
return ctx.api.named_generic_type(ctx.context.callee.fullname,
args=[base_field_arg_type.type.names['__get__'].type.ret_type])

View File

@@ -52,15 +52,6 @@ class ModelClassInitializer(metaclass=ABCMeta):
raise NotImplementedError()
def add_new_var_node_to_class(class_type: TypeInfo, name: str, typ: Instance) -> None:
var = Var(name=name, type=typ)
var.info = typ.type
var._fullname = class_type.fullname() + '.' + name
var.is_inferred = True
var.is_initialized_in_class = True
class_type.names[name] = SymbolTableNode(MDEF, var)
def iter_over_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, Expression]]:
for stmt in klass.defs.body:
if not isinstance(stmt, AssignmentStmt):

View File

@@ -19,15 +19,15 @@ def fill_typevars_with_any(instance: Instance) -> Type:
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.arg_names:
if 'to' not in ctx.callee_arg_names:
# shouldn't happen, invalid code
api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
context=ctx.context)
return None
arg_type = ctx.arg_types[ctx.arg_names.index('to')][0]
arg_type = ctx.arg_types[ctx.callee_arg_names.index('to')][0]
if not isinstance(arg_type, CallableType):
to_arg_expr = ctx.args[ctx.arg_names.index('to')][0]
to_arg_expr = ctx.args[ctx.callee_arg_names.index('to')][0]
if not isinstance(to_arg_expr, StrExpr):
# not string, not supported
return None

View File

@@ -20,7 +20,7 @@ setup(
license='BSD',
install_requires=[
'Django>=2.1.1',
'mypy'
'mypy @ git+https://github.com/python/mypy.git#egg=mypy-0.660+dev.01c268644d1d22506442df4e21b39c04710b7e8b'
],
packages=['mypy_django_plugin'],
packages=['mypy_django_plugin']
)