diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index c03af70..eb58acc 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -1,11 +1,11 @@ import os -from typing import Callable, Optional, cast +from typing import Callable, Optional, cast, Dict from mypy.checker import TypeChecker +from mypy.nodes import TypeInfo from mypy.options import Options -from mypy.plugin import Plugin, FunctionContext, ClassDefContext, AnalyzeTypeContext +from mypy.plugin import Plugin, FunctionContext, ClassDefContext from mypy.types import Type, Instance -from mypy.typevars import fill_typevars from mypy_django_plugin import helpers, monkeypatch from mypy_django_plugin.plugins.fields import determine_type_of_array_field @@ -14,17 +14,17 @@ from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_ge from mypy_django_plugin.plugins.settings import DjangoConfSettingsInitializerHook -base_model_classes = {helpers.MODEL_CLASS_FULLNAME} -manager_subclasses = set() - - def transform_model_class(ctx: ClassDefContext) -> None: - base_model_classes.add(ctx.cls.fullname) + sym = ctx.api.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME) + if sym is not None and isinstance(sym.node, TypeInfo): + sym.node.metadata['django']['model_bases'][ctx.cls.fullname] = 1 process_model_class(ctx) -def add_new_manager_subclass(ctx: ClassDefContext) -> None: - manager_subclasses.add(ctx.cls.fullname) +def transform_manager_class(ctx: ClassDefContext) -> None: + sym = ctx.api.lookup_fully_qualified_or_none(helpers.MANAGER_CLASS_FULLNAME) + if sym is not None and isinstance(sym.node, TypeInfo): + sym.node.metadata['django']['manager_bases'][ctx.cls.fullname] = 1 def determine_proper_manager_type(ctx: FunctionContext) -> Type: @@ -52,7 +52,6 @@ class DjangoPlugin(Plugin): def __init__(self, options: Options) -> None: super().__init__(options) - monkeypatch.replace_apply_function_plugin_method() monkeypatch.make_inner_classes_with_inherit_from_any_compatible_with_each_other() self.django_settings = os.environ.get('DJANGO_SETTINGS_MODULE') @@ -63,6 +62,28 @@ class DjangoPlugin(Plugin): monkeypatch.restore_original_load_graph() monkeypatch.restore_original_dependencies_handling() + def get_current_model_bases(self) -> Dict[str, int]: + model_sym = self.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME) + if model_sym is not None and isinstance(model_sym.node, TypeInfo): + if 'django' not in model_sym.node.metadata: + model_sym.node.metadata['django'] = { + 'model_bases': {helpers.MODEL_CLASS_FULLNAME: 1} + } + return model_sym.node.metadata['django']['model_bases'] + else: + return {} + + def get_current_manager_bases(self) -> Dict[str, int]: + manager_sym = self.lookup_fully_qualified(helpers.MANAGER_CLASS_FULLNAME) + if manager_sym is not None and isinstance(manager_sym.node, TypeInfo): + if 'django' not in manager_sym.node.metadata: + manager_sym.node.metadata['django'] = { + 'manager_bases': {helpers.MANAGER_CLASS_FULLNAME: 1} + } + return manager_sym.node.metadata['django']['manager_bases'] + else: + return {} + def get_function_hook(self, fullname: str ) -> Optional[Callable[[FunctionContext], Type]]: if fullname in {helpers.FOREIGN_KEY_FULLNAME, @@ -73,20 +94,22 @@ class DjangoPlugin(Plugin): if fullname == 'django.contrib.postgres.fields.array.ArrayField': return determine_type_of_array_field - if fullname in manager_subclasses: + manager_bases = self.get_current_manager_bases() + if fullname in manager_bases: return determine_proper_manager_type return None def get_base_class_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: - if fullname in base_model_classes: + if fullname in self.get_current_model_bases(): return transform_model_class if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS: return DjangoConfSettingsInitializerHook(settings_module=self.django_settings) - if fullname in helpers.MANAGER_CLASSES: - return add_new_manager_subclass + if fullname in self.get_current_manager_bases(): + return transform_manager_class + return None diff --git a/mypy_django_plugin/monkeypatch/__init__.py b/mypy_django_plugin/monkeypatch/__init__.py index ffeb914..abc23ea 100644 --- a/mypy_django_plugin/monkeypatch/__init__.py +++ b/mypy_django_plugin/monkeypatch/__init__.py @@ -3,5 +3,4 @@ from .dependencies import (load_graph_to_add_settings_file_as_a_source_seed, restore_original_load_graph, restore_original_dependencies_handling, process_settings_before_dependants) -from .contexts import replace_apply_function_plugin_method from .multiple_inheritance import make_inner_classes_with_inherit_from_any_compatible_with_each_other \ No newline at end of file diff --git a/mypy_django_plugin/monkeypatch/contexts.py b/mypy_django_plugin/monkeypatch/contexts.py deleted file mode 100644 index 740c811..0000000 --- a/mypy_django_plugin/monkeypatch/contexts.py +++ /dev/null @@ -1,275 +0,0 @@ -from typing import Optional, List, Sequence, NamedTuple, Tuple - -from mypy import checkexpr -from mypy.checkexpr import map_actuals_to_formals -from mypy.checkmember import analyze_member_access -from mypy.expandtype import freshen_function_type_vars -from mypy.messages import MessageBuilder -from mypy.nodes import Expression, Context, TypeInfo, FuncDef, Decorator, RefExpr -from mypy.plugin import CheckerPluginInterface -from mypy.subtypes import is_equivalent -from mypy.types import Type, CallableType, Instance, TypeType, Overloaded, AnyType, TypeOfAny, UnionType, TypeVarType, TupleType - - -class PatchedExpressionChecker(checkexpr.ExpressionChecker): - def get_argnames_of_func_node(self, func_node: Context) -> Optional[List[str]]: - if isinstance(func_node, FuncDef): - return func_node.arg_names - if isinstance(func_node, Decorator): - return func_node.func.arg_names - if isinstance(func_node, TypeInfo): - # __init__ method - init_node = func_node.get_method('__init__') - assert isinstance(init_node, FuncDef) - return init_node.arg_names - return None - - def get_defn_arg_names(self, fullname: str, - object_type: Optional[TypeInfo]) -> Optional[List[str]]: - if object_type is not None: - method_name = fullname.rpartition('.')[-1] - sym = object_type.get(method_name) - if (sym is None or sym.node is None - or not isinstance(sym.node, (FuncDef, Decorator, TypeInfo))): - # arg_names extraction is unsupported for sym.node - return None - return self.get_argnames_of_func_node(sym.node) - - sym = self.chk.lookup_qualified(fullname) - if sym.node is None: - return None - return self.get_argnames_of_func_node(sym.node) - - def check_call(self, callee: Type, args: List[Expression], - arg_kinds: List[int], context: Context, - arg_names: Optional[Sequence[Optional[str]]] = None, - callable_node: Optional[Expression] = None, - arg_messages: Optional[MessageBuilder] = None, - callable_name: Optional[str] = None, - object_type: Optional[Type] = None) -> Tuple[Type, Type]: - """Type check a call. - - Also infer type arguments if the callee is a generic function. - - Return (result type, inferred callee type). - - Arguments: - callee: type of the called value - args: actual argument expressions - arg_kinds: contains nodes.ARG_* constant for each argument in args - describing whether the argument is positional, *arg, etc. - arg_names: names of arguments (optional) - callable_node: associate the inferred callable type to this node, - if specified - arg_messages: TODO - callable_name: Fully-qualified name of the function/method to call, - or None if unavailable (examples: 'builtins.open', 'typing.Mapping.get') - object_type: If callable_name refers to a method, the type of the object - on which the method is being called - """ - arg_messages = arg_messages or self.msg - if isinstance(callee, CallableType): - if callable_name is None and callee.name: - callable_name = callee.name - if callee.is_type_obj() and isinstance(callee.ret_type, Instance): - callable_name = callee.ret_type.type.fullname() - if (isinstance(callable_node, RefExpr) - and callable_node.fullname in ('enum.Enum', 'enum.IntEnum', - 'enum.Flag', 'enum.IntFlag')): - # An Enum() call that failed SemanticAnalyzerPass2.check_enum_call(). - return callee.ret_type, callee - - if (callee.is_type_obj() and callee.type_object().is_abstract - # Exception for Type[...] - and not callee.from_type_type - and not callee.type_object().fallback_to_any): - type = callee.type_object() - self.msg.cannot_instantiate_abstract_class( - callee.type_object().name(), type.abstract_attributes, - context) - elif (callee.is_type_obj() and callee.type_object().is_protocol - # Exception for Type[...] - and not callee.from_type_type): - self.chk.fail('Cannot instantiate protocol class "{}"' - .format(callee.type_object().name()), context) - - formal_to_actual = map_actuals_to_formals( - arg_kinds, arg_names, - callee.arg_kinds, callee.arg_names, - lambda i: self.accept(args[i])) - - if callee.is_generic(): - callee = freshen_function_type_vars(callee) - callee = self.infer_function_type_arguments_using_context( - callee, context) - callee = self.infer_function_type_arguments( - callee, args, arg_kinds, formal_to_actual, context) - - arg_types = self.infer_arg_types_in_context( - callee, args, arg_kinds, formal_to_actual) - - self.check_argument_count(callee, arg_types, arg_kinds, - arg_names, formal_to_actual, context, self.msg) - - self.check_argument_types(arg_types, arg_kinds, callee, - formal_to_actual, context, - messages=arg_messages) - - if (callee.is_type_obj() and (len(arg_types) == 1) - and is_equivalent(callee.ret_type, self.named_type('builtins.type'))): - callee = callee.copy_modified(ret_type=TypeType.make_normalized(arg_types[0])) - - if callable_node: - # Store the inferred callable type. - self.chk.store_type(callable_node, callee) - - if (callable_name - and ((object_type is None and self.plugin.get_function_hook(callable_name)) - or (object_type is not None - and self.plugin.get_method_hook(callable_name)))): - ret_type = self.apply_function_plugin( - arg_types, callee.ret_type, arg_names, formal_to_actual, - args, len(callee.arg_types), callable_name, object_type, context) - callee = callee.copy_modified(ret_type=ret_type) - return callee.ret_type, callee - elif isinstance(callee, Overloaded): - arg_types = self.infer_arg_types_in_empty_context(args) - return self.check_overload_call(callee=callee, - args=args, - arg_types=arg_types, - arg_kinds=arg_kinds, - arg_names=arg_names, - callable_name=callable_name, - object_type=object_type, - context=context, - arg_messages=arg_messages) - elif isinstance(callee, AnyType) or not self.chk.in_checked_function(): - self.infer_arg_types_in_empty_context(args) - if isinstance(callee, AnyType): - return (AnyType(TypeOfAny.from_another_any, source_any=callee), - AnyType(TypeOfAny.from_another_any, source_any=callee)) - else: - return AnyType(TypeOfAny.special_form), AnyType(TypeOfAny.special_form) - elif isinstance(callee, UnionType): - self.msg.disable_type_names += 1 - results = [self.check_call(subtype, args, arg_kinds, context, arg_names, - arg_messages=arg_messages) - for subtype in callee.relevant_items()] - self.msg.disable_type_names -= 1 - return (UnionType.make_simplified_union([res[0] for res in results]), - callee) - elif isinstance(callee, Instance): - call_function = analyze_member_access('__call__', callee, context, - False, False, False, self.named_type, - self.not_ready_callback, self.msg, - original_type=callee, chk=self.chk) - return self.check_call(call_function, args, arg_kinds, context, arg_names, - callable_node, arg_messages) - elif isinstance(callee, TypeVarType): - return self.check_call(callee.upper_bound, args, arg_kinds, context, arg_names, - callable_node, arg_messages) - elif isinstance(callee, TypeType): - # Pass the original Type[] as context since that's where errors should go. - item = self.analyze_type_type_callee(callee.item, callee) - return self.check_call(item, args, arg_kinds, context, arg_names, - callable_node, arg_messages) - elif isinstance(callee, TupleType): - return self.check_call(callee.fallback, args, arg_kinds, context, - arg_names, callable_node, arg_messages, callable_name, - object_type) - else: - return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error) - - def apply_function_plugin(self, - arg_types: List[Type], - inferred_ret_type: Type, - arg_names: Optional[Sequence[Optional[str]]], - formal_to_actual: List[List[int]], - args: List[Expression], - num_formals: int, - fullname: str, - object_type: Optional[Type], - context: Context) -> Type: - """Use special case logic to infer the return type of a specific named function/method. - - Caller must ensure that a plugin hook exists. There are two different cases: - - - If object_type is None, the caller must ensure that a function hook exists - for fullname. - - If object_type is not None, the caller must ensure that a method hook exists - for fullname. - - Return the inferred return type. - """ - from mypy.plugin import FunctionContext, MethodContext - - formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]] - formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]] - formal_arg_names = [None for _ in range(num_formals)] # type: List[Optional[str]] - for formal, actuals in enumerate(formal_to_actual): - for actual in actuals: - formal_arg_types[formal].append(arg_types[actual]) - formal_arg_exprs[formal].append(args[actual]) - if arg_names: - formal_arg_names[formal] = arg_names[actual] - - num_passed_positionals = sum([1 if name is None else 0 - for name in formal_arg_names]) - if arg_names and num_passed_positionals > 0: - object_type_info = None - if object_type is not None: - if isinstance(object_type, CallableType): - # class object, convert to corresponding Instance - object_type = object_type.ret_type - if isinstance(object_type, Instance): - # skip TypedDictType and others - object_type_info = object_type.type - - defn_arg_names = self.get_defn_arg_names(fullname, object_type=object_type_info) - if defn_arg_names: - if num_formals < len(defn_arg_names): - # self/cls argument has been passed implicitly - defn_arg_names = defn_arg_names[1:] - formal_arg_names[:num_passed_positionals] = defn_arg_names[:num_passed_positionals] - - if object_type is None: - # Apply function plugin - callback = self.plugin.get_function_hook(fullname) - assert callback is not None # Assume that caller ensures this - return callback( - FunctionContext(formal_arg_names, formal_arg_types, - inferred_ret_type, formal_arg_exprs, - context, self.chk)) - else: - # Apply method plugin - method_callback = self.plugin.get_method_hook(fullname) - assert method_callback is not None # Assume that caller ensures this - return method_callback( - MethodContext(object_type, formal_arg_names, formal_arg_types, - inferred_ret_type, formal_arg_exprs, - context, self.chk)) - - -def replace_apply_function_plugin_method(): - from mypy import plugin - - plugin.FunctionContext = NamedTuple( - 'FunctionContext', [ - ('arg_names', Sequence[Optional[str]]), # List of actual argument names - ('arg_types', List[List[Type]]), # List of actual caller types for each formal argument - ('default_return_type', Type), # Return type inferred from signature - ('args', List[List[Expression]]), # Actual expressions for each formal argument - ('context', Context), - ('api', CheckerPluginInterface)]) - - plugin.MethodContext = NamedTuple( - 'MethodContext', [ - ('type', Type), # Base object type for method call - ('arg_names', Sequence[Optional[str]]), # List of actual argument names - ('arg_types', List[List[Type]]), - ('default_return_type', Type), - ('args', List[List[Expression]]), - ('context', Context), - ('api', CheckerPluginInterface)]) - - checkexpr.ExpressionChecker = PatchedExpressionChecker diff --git a/mypy_django_plugin/plugins/fields.py b/mypy_django_plugin/plugins/fields.py index 25e4171..563653a 100644 --- a/mypy_django_plugin/plugins/fields.py +++ b/mypy_django_plugin/plugins/fields.py @@ -3,9 +3,9 @@ from mypy.types import Type def determine_type_of_array_field(ctx: FunctionContext) -> Type: - if 'base_field' not in ctx.arg_names: + if 'base_field' not in ctx.callee_arg_names: return ctx.default_return_type - base_field_arg_type = ctx.arg_types[ctx.arg_names.index('base_field')][0] + base_field_arg_type = ctx.arg_types[ctx.callee_arg_names.index('base_field')][0] return ctx.api.named_generic_type(ctx.context.callee.fullname, args=[base_field_arg_type.type.names['__get__'].type.ret_type]) diff --git a/mypy_django_plugin/plugins/models.py b/mypy_django_plugin/plugins/models.py index fcb9aaf..8afbb15 100644 --- a/mypy_django_plugin/plugins/models.py +++ b/mypy_django_plugin/plugins/models.py @@ -52,15 +52,6 @@ class ModelClassInitializer(metaclass=ABCMeta): raise NotImplementedError() -def add_new_var_node_to_class(class_type: TypeInfo, name: str, typ: Instance) -> None: - var = Var(name=name, type=typ) - var.info = typ.type - var._fullname = class_type.fullname() + '.' + name - var.is_inferred = True - var.is_initialized_in_class = True - class_type.names[name] = SymbolTableNode(MDEF, var) - - def iter_over_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, Expression]]: for stmt in klass.defs.body: if not isinstance(stmt, AssignmentStmt): diff --git a/mypy_django_plugin/plugins/related_fields.py b/mypy_django_plugin/plugins/related_fields.py index fa89b75..e1a3dd3 100644 --- a/mypy_django_plugin/plugins/related_fields.py +++ b/mypy_django_plugin/plugins/related_fields.py @@ -19,15 +19,15 @@ def fill_typevars_with_any(instance: Instance) -> Type: def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: api = cast(TypeChecker, ctx.api) - if 'to' not in ctx.arg_names: + if 'to' not in ctx.callee_arg_names: # shouldn't happen, invalid code api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}', context=ctx.context) return None - arg_type = ctx.arg_types[ctx.arg_names.index('to')][0] + arg_type = ctx.arg_types[ctx.callee_arg_names.index('to')][0] if not isinstance(arg_type, CallableType): - to_arg_expr = ctx.args[ctx.arg_names.index('to')][0] + to_arg_expr = ctx.args[ctx.callee_arg_names.index('to')][0] if not isinstance(to_arg_expr, StrExpr): # not string, not supported return None diff --git a/setup.py b/setup.py index 36ef80a..d1e4965 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ setup( license='BSD', install_requires=[ 'Django>=2.1.1', - 'mypy' + 'mypy @ git+https://github.com/python/mypy.git#egg=mypy-0.660+dev.01c268644d1d22506442df4e21b39c04710b7e8b' ], - packages=['mypy_django_plugin'], + packages=['mypy_django_plugin'] ) diff --git a/test-data/typecheck/settings.test b/test-data/typecheck/settings.test index 8ec5612..a58d3e0 100644 --- a/test-data/typecheck/settings.test +++ b/test-data/typecheck/settings.test @@ -82,4 +82,4 @@ reveal_type(settings.NOT_EXISTING) # E: Revealed type is 'Any' [env DJANGO_SETTINGS_MODULE=mysettings] [file mysettings.py] -[out] +[out] \ No newline at end of file