diff --git a/dev-requirements.txt b/dev-requirements.txt index faa3f42..382bdf7 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,4 @@ -r external/mypy/test-requirements.txt -e external/mypy --e . \ No newline at end of file +-e . +decorator \ No newline at end of file diff --git a/external/mypy b/external/mypy index 1a9e280..b790539 160000 --- a/external/mypy +++ b/external/mypy @@ -1 +1 @@ -Subproject commit 1a9e2804cdad401a3019eabd37002f32d08fe0ec +Subproject commit b7905398258304bb366539776a36acd74d6f2a10 diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index 5a1c618..0f003e1 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -1,5 +1,5 @@ import typing -from typing import Dict, Optional, NamedTuple, Any +from typing import Dict, Optional, NamedTuple from mypy.nodes import SymbolTableNode, Var, Expression from mypy.plugin import FunctionContext diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 4f4368a..e1fe731 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -1,10 +1,6 @@ import os -from typing import Callable, Optional, List +from typing import Callable, Optional -from django.apps.registry import Apps -from django.conf import Settings -from mypy import build -from mypy.build import BuildManager from mypy.options import Options from mypy.plugin import Plugin, FunctionContext, ClassDefContext from mypy.types import Type @@ -21,10 +17,6 @@ base_model_classes = {helpers.MODEL_CLASS_FULLNAME} class TransformModelClassHook(object): - def __init__(self, settings: Settings, apps: Apps): - self.settings = settings - self.apps = apps - def __call__(self, ctx: ClassDefContext) -> None: base_model_classes.add(ctx.cls.fullname) @@ -32,48 +24,27 @@ class TransformModelClassHook(object): set_objects_queryset_to_model_class(ctx) -def always_return_none(manager: BuildManager): - return None - - -build.read_plugins_snapshot = always_return_none - - class DjangoPlugin(Plugin): def __init__(self, options: Options) -> None: super().__init__(options) - self.django_settings = None - self.apps = None - monkeypatch.replace_apply_function_plugin_method() - django_settings_module = os.environ.get('DJANGO_SETTINGS_MODULE') - if django_settings_module: - self.django_settings = Settings(django_settings_module) - # import django - # django.setup() - # - # from django.apps import apps - # self.apps = apps - # - # models_modules = [] - # for app_config in self.apps.app_configs.values(): - # models_modules.append(app_config.module.__name__ + '.' + 'models') - # - # monkeypatch.state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(django_settings_module, - # models_modules) - monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(django_settings_module) + self.django_settings = os.environ.get('DJANGO_SETTINGS_MODULE') + if self.django_settings: + monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(self.django_settings) + monkeypatch.inject_dependencies(self.django_settings) + else: + monkeypatch.restore_original_load_graph() + monkeypatch.restore_original_dependencies_handling() def get_function_hook(self, fullname: str ) -> Optional[Callable[[FunctionContext], Type]]: if fullname == helpers.FOREIGN_KEY_FULLNAME: - return ForeignKeyHook(settings=self.django_settings, - apps=self.apps) + return ForeignKeyHook(settings=self.django_settings) if fullname == helpers.ONETOONE_FIELD_FULLNAME: - return OneToOneFieldHook(settings=self.django_settings, - apps=self.apps) + return OneToOneFieldHook(settings=self.django_settings) if fullname == 'django.contrib.postgres.fields.array.ArrayField': return determine_type_of_array_field @@ -82,10 +53,10 @@ class DjangoPlugin(Plugin): def get_base_class_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: if fullname in base_model_classes: - return TransformModelClassHook(self.django_settings, self.apps) + return TransformModelClassHook() if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS: - return DjangoConfSettingsInitializerHook(settings=self.django_settings) + return DjangoConfSettingsInitializerHook(settings_module=self.django_settings) return None diff --git a/mypy_django_plugin/monkeypatch.py b/mypy_django_plugin/monkeypatch.py deleted file mode 100644 index e02b243..0000000 --- a/mypy_django_plugin/monkeypatch.py +++ /dev/null @@ -1,112 +0,0 @@ -from typing import Optional, List, Sequence - -from mypy.build import BuildManager, Graph, State -from mypy.modulefinder import BuildSource -from mypy.nodes import Expression, Context -from mypy.plugin import FunctionContext, MethodContext -from mypy.types import Type, CallableType, Instance - - -def state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(settings_module: str, - models_py_modules: List[str]): - from mypy.build import State - - old_compute_dependencies = State.compute_dependencies - - def patched_compute_dependencies(self: State): - old_compute_dependencies(self) - if self.id == settings_module: - self.dependencies.extend(models_py_modules) - - State.compute_dependencies = patched_compute_dependencies - - -def load_graph_to_add_settings_file_as_a_source_seed(settings_module: str): - from mypy import build - - old_load_graph = build.load_graph - - def patched_load_graph(sources: List[BuildSource], manager: BuildManager, - old_graph: Optional[Graph] = None, - new_modules: Optional[List[State]] = None): - if all([source.module != settings_module for source in sources]): - sources.append(BuildSource(None, settings_module, None)) - - return old_load_graph(sources=sources, manager=manager, - old_graph=old_graph, - new_modules=new_modules) - - build.load_graph = patched_load_graph - - -def replace_apply_function_plugin_method(): - def apply_function_plugin(self, - arg_types: List[Type], - inferred_ret_type: Type, - arg_names: Optional[Sequence[Optional[str]]], - formal_to_actual: List[List[int]], - args: List[Expression], - num_formals: int, - fullname: str, - object_type: Optional[Type], - context: Context) -> Type: - """Use special case logic to infer the return type of a specific named function/method. - - Caller must ensure that a plugin hook exists. There are two different cases: - - - If object_type is None, the caller must ensure that a function hook exists - for fullname. - - If object_type is not None, the caller must ensure that a method hook exists - for fullname. - - Return the inferred return type. - """ - formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]] - formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]] - formal_arg_names = [None for _ in range(num_formals)] # type: List[Optional[str]] - for formal, actuals in enumerate(formal_to_actual): - for actual in actuals: - formal_arg_types[formal].append(arg_types[actual]) - formal_arg_exprs[formal].append(args[actual]) - if arg_names: - formal_arg_names[formal] = arg_names[actual] - - num_passed_positionals = sum([1 if name is None else 0 - for name in formal_arg_names]) - if arg_names and num_passed_positionals > 0: - object_type_info = None - if object_type is not None: - if isinstance(object_type, CallableType): - # class object, convert to corresponding Instance - object_type = object_type.ret_type - if isinstance(object_type, Instance): - # skip TypedDictType and others - object_type_info = object_type.type - - defn_arg_names = self._get_defn_arg_names(fullname, object_type=object_type_info) - if defn_arg_names: - if num_formals < len(defn_arg_names): - # self/cls argument has been passed implicitly - defn_arg_names = defn_arg_names[1:] - formal_arg_names[:num_passed_positionals] = defn_arg_names[:num_passed_positionals] - - if object_type is None: - # Apply function plugin - callback = self.plugin.get_function_hook(fullname) - assert callback is not None # Assume that caller ensures this - return callback( - FunctionContext(formal_arg_names, formal_arg_types, - inferred_ret_type, formal_arg_exprs, - context, self.chk)) - else: - # Apply method plugin - method_callback = self.plugin.get_method_hook(fullname) - assert method_callback is not None # Assume that caller ensures this - return method_callback( - MethodContext(object_type, formal_arg_names, formal_arg_types, - inferred_ret_type, formal_arg_exprs, - context, self.chk)) - - from mypy.checkexpr import ExpressionChecker - ExpressionChecker.apply_function_plugin = apply_function_plugin - diff --git a/mypy_django_plugin/monkeypatch/__init__.py b/mypy_django_plugin/monkeypatch/__init__.py new file mode 100644 index 0000000..60c7ea4 --- /dev/null +++ b/mypy_django_plugin/monkeypatch/__init__.py @@ -0,0 +1,5 @@ +from .dependencies import (load_graph_to_add_settings_file_as_a_source_seed, + inject_dependencies, + restore_original_load_graph, + restore_original_dependencies_handling) +from .contexts import replace_apply_function_plugin_method \ No newline at end of file diff --git a/mypy_django_plugin/monkeypatch/contexts.py b/mypy_django_plugin/monkeypatch/contexts.py new file mode 100644 index 0000000..7d926ca --- /dev/null +++ b/mypy_django_plugin/monkeypatch/contexts.py @@ -0,0 +1,276 @@ +from typing import Optional, List, Sequence, NamedTuple, Tuple + +from mypy import checkexpr +from mypy.argmap import map_actuals_to_formals +from mypy.checkmember import analyze_member_access +from mypy.expandtype import freshen_function_type_vars +from mypy.messages import MessageBuilder +from mypy.nodes import Expression, Context, TypeInfo, FuncDef, Decorator, RefExpr +from mypy.plugin import CheckerPluginInterface +from mypy.subtypes import is_equivalent +from mypy.types import Type, CallableType, Instance, TypeType, Overloaded, AnyType, TypeOfAny, UnionType, TypeVarType, TupleType + + +class PatchedExpressionChecker(checkexpr.ExpressionChecker): + def get_argnames_of_func_node(self, func_node: Context) -> Optional[List[str]]: + if isinstance(func_node, FuncDef): + return func_node.arg_names + if isinstance(func_node, Decorator): + return func_node.func.arg_names + if isinstance(func_node, TypeInfo): + # __init__ method + init_node = func_node.get_method('__init__') + assert isinstance(init_node, FuncDef) + return init_node.arg_names + return None + + def get_defn_arg_names(self, fullname: str, + object_type: Optional[TypeInfo]) -> Optional[List[str]]: + if object_type is not None: + method_name = fullname.rpartition('.')[-1] + sym = object_type.get(method_name) + if (sym is None or sym.node is None + or not isinstance(sym.node, (FuncDef, Decorator, TypeInfo))): + # arg_names extraction is unsupported for sym.node + return None + return self.get_argnames_of_func_node(sym.node) + + sym = self.chk.lookup_qualified(fullname) + if sym.node is None: + return None + return self.get_argnames_of_func_node(sym.node) + + def check_call(self, callee: Type, args: List[Expression], + arg_kinds: List[int], context: Context, + arg_names: Optional[Sequence[Optional[str]]] = None, + callable_node: Optional[Expression] = None, + arg_messages: Optional[MessageBuilder] = None, + callable_name: Optional[str] = None, + object_type: Optional[Type] = None) -> Tuple[Type, Type]: + """Type check a call. + + Also infer type arguments if the callee is a generic function. + + Return (result type, inferred callee type). + + Arguments: + callee: type of the called value + args: actual argument expressions + arg_kinds: contains nodes.ARG_* constant for each argument in args + describing whether the argument is positional, *arg, etc. + arg_names: names of arguments (optional) + callable_node: associate the inferred callable type to this node, + if specified + arg_messages: TODO + callable_name: Fully-qualified name of the function/method to call, + or None if unavailable (examples: 'builtins.open', 'typing.Mapping.get') + object_type: If callable_name refers to a method, the type of the object + on which the method is being called + """ + arg_messages = arg_messages or self.msg + + if isinstance(callee, CallableType): + if callable_name is None and callee.name: + callable_name = callee.name + if callee.is_type_obj() and isinstance(callee.ret_type, Instance): + callable_name = callee.ret_type.type.fullname() + if (isinstance(callable_node, RefExpr) + and callable_node.fullname in ('enum.Enum', 'enum.IntEnum', + 'enum.Flag', 'enum.IntFlag')): + # An Enum() call that failed SemanticAnalyzerPass2.check_enum_call(). + return callee.ret_type, callee + + if (callee.is_type_obj() and callee.type_object().is_abstract + # Exception for Type[...] + and not callee.from_type_type + and not callee.type_object().fallback_to_any): + type = callee.type_object() + self.msg.cannot_instantiate_abstract_class( + callee.type_object().name(), type.abstract_attributes, + context) + elif (callee.is_type_obj() and callee.type_object().is_protocol + # Exception for Type[...] + and not callee.from_type_type): + self.chk.fail('Cannot instantiate protocol class "{}"' + .format(callee.type_object().name()), context) + + formal_to_actual = map_actuals_to_formals( + arg_kinds, arg_names, + callee.arg_kinds, callee.arg_names, + lambda i: self.accept(args[i])) + + if callee.is_generic(): + callee = freshen_function_type_vars(callee) + callee = self.infer_function_type_arguments_using_context( + callee, context) + callee = self.infer_function_type_arguments( + callee, args, arg_kinds, formal_to_actual, context) + + arg_types = self.infer_arg_types_in_context( + callee, args, arg_kinds, formal_to_actual) + + self.check_argument_count(callee, arg_types, arg_kinds, + arg_names, formal_to_actual, context, self.msg) + + self.check_argument_types(arg_types, arg_kinds, callee, + formal_to_actual, context, + messages=arg_messages) + + if (callee.is_type_obj() and (len(arg_types) == 1) + and is_equivalent(callee.ret_type, self.named_type('builtins.type'))): + callee = callee.copy_modified(ret_type=TypeType.make_normalized(arg_types[0])) + + if callable_node: + # Store the inferred callable type. + self.chk.store_type(callable_node, callee) + + if (callable_name + and ((object_type is None and self.plugin.get_function_hook(callable_name)) + or (object_type is not None + and self.plugin.get_method_hook(callable_name)))): + ret_type = self.apply_function_plugin( + arg_types, callee.ret_type, arg_names, formal_to_actual, + args, len(callee.arg_types), callable_name, object_type, context) + callee = callee.copy_modified(ret_type=ret_type) + return callee.ret_type, callee + elif isinstance(callee, Overloaded): + arg_types = self.infer_arg_types_in_empty_context(args) + return self.check_overload_call(callee=callee, + args=args, + arg_types=arg_types, + arg_kinds=arg_kinds, + arg_names=arg_names, + callable_name=callable_name, + object_type=object_type, + context=context, + arg_messages=arg_messages) + elif isinstance(callee, AnyType) or not self.chk.in_checked_function(): + self.infer_arg_types_in_empty_context(args) + if isinstance(callee, AnyType): + return (AnyType(TypeOfAny.from_another_any, source_any=callee), + AnyType(TypeOfAny.from_another_any, source_any=callee)) + else: + return AnyType(TypeOfAny.special_form), AnyType(TypeOfAny.special_form) + elif isinstance(callee, UnionType): + self.msg.disable_type_names += 1 + results = [self.check_call(subtype, args, arg_kinds, context, arg_names, + arg_messages=arg_messages) + for subtype in callee.relevant_items()] + self.msg.disable_type_names -= 1 + return (UnionType.make_simplified_union([res[0] for res in results]), + callee) + elif isinstance(callee, Instance): + call_function = analyze_member_access('__call__', callee, context, + False, False, False, self.named_type, + self.not_ready_callback, self.msg, + original_type=callee, chk=self.chk) + return self.check_call(call_function, args, arg_kinds, context, arg_names, + callable_node, arg_messages) + elif isinstance(callee, TypeVarType): + return self.check_call(callee.upper_bound, args, arg_kinds, context, arg_names, + callable_node, arg_messages) + elif isinstance(callee, TypeType): + # Pass the original Type[] as context since that's where errors should go. + item = self.analyze_type_type_callee(callee.item, callee) + return self.check_call(item, args, arg_kinds, context, arg_names, + callable_node, arg_messages) + elif isinstance(callee, TupleType): + return self.check_call(callee.fallback, args, arg_kinds, context, + arg_names, callable_node, arg_messages, callable_name, + object_type) + else: + return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error) + + def apply_function_plugin(self, + arg_types: List[Type], + inferred_ret_type: Type, + arg_names: Optional[Sequence[Optional[str]]], + formal_to_actual: List[List[int]], + args: List[Expression], + num_formals: int, + fullname: str, + object_type: Optional[Type], + context: Context) -> Type: + """Use special case logic to infer the return type of a specific named function/method. + + Caller must ensure that a plugin hook exists. There are two different cases: + + - If object_type is None, the caller must ensure that a function hook exists + for fullname. + - If object_type is not None, the caller must ensure that a method hook exists + for fullname. + + Return the inferred return type. + """ + from mypy.plugin import FunctionContext, MethodContext + + formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]] + formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]] + formal_arg_names = [None for _ in range(num_formals)] # type: List[Optional[str]] + for formal, actuals in enumerate(formal_to_actual): + for actual in actuals: + formal_arg_types[formal].append(arg_types[actual]) + formal_arg_exprs[formal].append(args[actual]) + if arg_names: + formal_arg_names[formal] = arg_names[actual] + + num_passed_positionals = sum([1 if name is None else 0 + for name in formal_arg_names]) + if arg_names and num_passed_positionals > 0: + object_type_info = None + if object_type is not None: + if isinstance(object_type, CallableType): + # class object, convert to corresponding Instance + object_type = object_type.ret_type + if isinstance(object_type, Instance): + # skip TypedDictType and others + object_type_info = object_type.type + + defn_arg_names = self.get_defn_arg_names(fullname, object_type=object_type_info) + if defn_arg_names: + if num_formals < len(defn_arg_names): + # self/cls argument has been passed implicitly + defn_arg_names = defn_arg_names[1:] + formal_arg_names[:num_passed_positionals] = defn_arg_names[:num_passed_positionals] + + if object_type is None: + # Apply function plugin + callback = self.plugin.get_function_hook(fullname) + assert callback is not None # Assume that caller ensures this + return callback( + FunctionContext(formal_arg_names, formal_arg_types, + inferred_ret_type, formal_arg_exprs, + context, self.chk)) + else: + # Apply method plugin + method_callback = self.plugin.get_method_hook(fullname) + assert method_callback is not None # Assume that caller ensures this + return method_callback( + MethodContext(object_type, formal_arg_names, formal_arg_types, + inferred_ret_type, formal_arg_exprs, + context, self.chk)) + + +def replace_apply_function_plugin_method(): + from mypy import plugin + + plugin.FunctionContext = NamedTuple( + 'FunctionContext', [ + ('arg_names', Sequence[Optional[str]]), # List of actual argument names + ('arg_types', List[List[Type]]), # List of actual caller types for each formal argument + ('default_return_type', Type), # Return type inferred from signature + ('args', List[List[Expression]]), # Actual expressions for each formal argument + ('context', Context), + ('api', CheckerPluginInterface)]) + + plugin.MethodContext = NamedTuple( + 'MethodContext', [ + ('type', Type), # Base object type for method call + ('arg_names', Sequence[Optional[str]]), # List of actual argument names + ('arg_types', List[List[Type]]), + ('default_return_type', Type), + ('args', List[List[Expression]]), + ('context', Context), + ('api', CheckerPluginInterface)]) + + checkexpr.ExpressionChecker = PatchedExpressionChecker diff --git a/mypy_django_plugin/monkeypatch/dependencies.py b/mypy_django_plugin/monkeypatch/dependencies.py new file mode 100644 index 0000000..5044e8c --- /dev/null +++ b/mypy_django_plugin/monkeypatch/dependencies.py @@ -0,0 +1,52 @@ +from typing import List, Optional + +from mypy.build import BuildManager, Graph, State +from mypy.modulefinder import BuildSource + + +def is_module_present_in_sources(module_name: str, sources: List[BuildSource]): + return any([source.module == module_name for source in sources]) + + +from mypy import build + +old_load_graph = build.load_graph +OldState = build.State + + +def load_graph_to_add_settings_file_as_a_source_seed(settings_module: str): + def patched_load_graph(sources: List[BuildSource], manager: BuildManager, + old_graph: Optional[Graph] = None, + new_modules: Optional[List[State]] = None): + if not is_module_present_in_sources(settings_module, sources): + sources.append(BuildSource(None, settings_module, None)) + + return old_load_graph(sources=sources, manager=manager, + old_graph=old_graph, + new_modules=new_modules) + + build.load_graph = patched_load_graph + + +def restore_original_load_graph(): + from mypy import build + + build.load_graph = old_load_graph + + +def inject_dependencies(settings_module: str): + from mypy import build + + class PatchedState(build.State): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.id == 'django.conf': + self.dependencies.append(settings_module) + + build.State = PatchedState + + +def restore_original_dependencies_handling(): + from mypy import build + + build.State = OldState diff --git a/mypy_django_plugin/plugins/related_fields.py b/mypy_django_plugin/plugins/related_fields.py index 7abb97c..4aab943 100644 --- a/mypy_django_plugin/plugins/related_fields.py +++ b/mypy_django_plugin/plugins/related_fields.py @@ -1,126 +1,95 @@ import typing -from typing import Optional, cast, Tuple, Any +from typing import Optional, cast -from django.apps.registry import Apps from django.conf import Settings -from django.db import models from mypy.checker import TypeChecker -from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, StrExpr +from mypy.nodes import SymbolTable, MDEF, AssignmentStmt from mypy.plugin import FunctionContext, ClassDefContext -from mypy.types import Type, CallableType, Instance, AnyType +from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny from mypy_django_plugin import helpers -def get_instance_type_for_class(klass: typing.Type[models.Model], - api: TypeChecker) -> Optional[Instance]: - model_qualname = helpers.get_obj_type_name(klass) - module_name, _, class_name = model_qualname.rpartition('.') - module = api.modules.get(module_name) - if not module or class_name not in module.names: - return - - sym = module.names[class_name] - return Instance(sym.node, []) - - -def extract_to_value_type(ctx: FunctionContext, - apps: Optional[Apps]) -> Tuple[Optional[Instance], bool]: - api = cast(TypeChecker, ctx.api) - - if 'to' not in ctx.arg_names: - return None, False - arg = ctx.args[ctx.arg_names.index('to')][0] - arg_type = ctx.arg_types[ctx.arg_names.index('to')][0] - - if isinstance(arg_type, CallableType): - return arg_type.ret_type, False - - if apps: - if isinstance(arg, StrExpr): - arg_value = arg.value - if '.' not in arg_value: - return None, False - - app_label, modelname = arg_value.lower().split('.') - try: - model_cls = apps.get_model(app_label, modelname) - except LookupError: - # no model class found - return None, False - try: - instance = get_instance_type_for_class(model_cls, api=api) - if not instance: - return None, False - return instance, True - - except AssertionError: - pass - - return None, False - - def extract_related_name_value(ctx: FunctionContext) -> str: - return ctx.context.args[ctx.context.arg_names.index('related_name')].value + return ctx.context.args[ctx.arg_names.index('related_name')].value -def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instance: Instance) -> None: - klass_typeinfo.names[name] = helpers.create_new_symtable_node(name, - kind=MDEF, - instance=new_member_instance) +def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]): + return Instance(instance.type, args=new_typevars) + + +def fill_typevars_with_any(instance: Instance) -> Type: + return reparametrize_with(instance, [AnyType(TypeOfAny.unannotated)]) + + +def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: + if 'to' not in ctx.arg_names: + # shouldn't happen, invalid code + ctx.api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}', + context=ctx.context) + return None + + arg_type = ctx.arg_types[ctx.arg_names.index('to')][0] + if not isinstance(arg_type, CallableType): + ctx.api.msg.warn(f'to= parameter type {arg_type.__class__.__name__} is not supported', + context=ctx.context) + return None + + referred_to_type = arg_type.ret_type + for base in referred_to_type.type.bases: + if base.type.fullname() == helpers.MODEL_CLASS_FULLNAME: + break + else: + ctx.api.msg.fail(f'to= parameter value must be ' + f'a subclass of {helpers.MODEL_CLASS_FULLNAME}', + context=ctx.context) + return None + + return referred_to_type class ForeignKeyHook(object): - def __init__(self, settings: Settings, apps: Apps): + def __init__(self, settings: Settings): self.settings = settings - self.apps = apps def __call__(self, ctx: FunctionContext) -> Type: api = cast(TypeChecker, ctx.api) outer_class_info = api.tscope.classes[-1] - referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps) - if not referred_to: - return ctx.default_return_type + referred_to_type = get_valid_to_value_or_none(ctx) + if referred_to_type is None: + return fill_typevars_with_any(ctx.default_return_type) - if 'related_name' in ctx.context.arg_names: + if 'related_name' in ctx.arg_names: related_name = extract_related_name_value(ctx) queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME, args=[Instance(outer_class_info, [])]) - if isinstance(referred_to, AnyType): - return ctx.default_return_type + sym = helpers.create_new_symtable_node(related_name, MDEF, + instance=queryset_type) + referred_to_type.type.names[related_name] = sym - add_new_class_member(referred_to.type, - related_name, queryset_type) - if is_string_based: - return referred_to - - return ctx.default_return_type + return reparametrize_with(ctx.default_return_type, [referred_to_type]) class OneToOneFieldHook(object): - def __init__(self, settings: Optional[Settings], apps: Optional[Apps]): + def __init__(self, settings: Optional[Settings]): self.settings = settings - self.apps = apps def __call__(self, ctx: FunctionContext) -> Type: - if 'related_name' not in ctx.context.arg_names: - return ctx.default_return_type + api = cast(TypeChecker, ctx.api) + outer_class_info = api.tscope.classes[-1] - referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps) - if referred_to is None: - return ctx.default_return_type + referred_to_type = get_valid_to_value_or_none(ctx) + if referred_to_type is None: + return fill_typevars_with_any(ctx.default_return_type) - if 'related_name' in ctx.context.arg_names: + if 'related_name' in ctx.arg_names: related_name = extract_related_name_value(ctx) - outer_class_info = ctx.api.tscope.classes[-1] - add_new_class_member(referred_to.type, related_name, - new_member_instance=Instance(outer_class_info, [])) + sym = helpers.create_new_symtable_node(related_name, MDEF, + instance=Instance(outer_class_info, [])) + referred_to_type.type.names[related_name] = sym - if is_string_based: - return referred_to - - return ctx.default_return_type + return reparametrize_with(ctx.default_return_type, [referred_to_type]) def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None: diff --git a/mypy_django_plugin/plugins/setup_settings.py b/mypy_django_plugin/plugins/setup_settings.py index c866b61..2aaf4ec 100644 --- a/mypy_django_plugin/plugins/setup_settings.py +++ b/mypy_django_plugin/plugins/setup_settings.py @@ -1,32 +1,42 @@ -from typing import cast +from typing import Optional, Any, cast -from django.conf import Settings -from mypy.nodes import MDEF +from mypy.nodes import Var, Context, GDEF +from mypy.options import Options from mypy.plugin import ClassDefContext from mypy.semanal import SemanticAnalyzerPass2 -from mypy.types import Instance, AnyType, TypeOfAny +from mypy.types import Instance -from mypy_django_plugin import helpers + +def add_settings_to_django_conf_object(ctx: ClassDefContext, + settings_module: str) -> Optional[Any]: + api = cast(SemanticAnalyzerPass2, ctx.api) + if settings_module not in api.modules: + return None + + settings_file = api.modules[settings_module] + for name, sym in settings_file.names.items(): + if name.isupper(): + if not isinstance(sym.node, Var) or not isinstance(sym.type, Instance): + error_context = Context() + error_context.set_line(sym.node) + api.msg.fail("Need type annotation for '{}'".format(sym.node.name()), + context=error_context, + file=settings_file.path, + origin=Context()) + continue + + sym_copy = sym.copy() + sym_copy.node.info = sym_copy.type.type + sym_copy.kind = GDEF + ctx.cls.info.names[name] = sym_copy class DjangoConfSettingsInitializerHook(object): - def __init__(self, settings: Settings): - self.settings = settings + def __init__(self, settings_module: str): + self.settings_module = settings_module def __call__(self, ctx: ClassDefContext) -> None: - api = cast(SemanticAnalyzerPass2, ctx.api) - if self.settings: - for name, value in self.settings.__dict__.items(): - if name.isupper(): - if value is None: - # TODO: change to Optional[Any] later - ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF, - instance=api.builtin_type('builtins.object')) - continue + if not self.settings_module: + return - type_fullname = helpers.get_obj_type_name(type(value)) - sym = api.lookup_fully_qualified_or_none(type_fullname) - if sym is not None: - args = len(sym.node.type_vars) * [AnyType(TypeOfAny.from_omitted_generics)] - ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF, - instance=Instance(sym.node, args)) + add_settings_to_django_conf_object(ctx, self.settings_module) diff --git a/test/helpers.py b/test/helpers.py index b4a2916..fb8c07f 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -132,13 +132,11 @@ def assert_string_arrays_equal(expected: List[str], actual: List[str]) -> None: num_skip_end = _num_skipped_suffix_lines(expected, actual) error_message += 'Expected:\n' - # sys.stderr.write('Expected:\n') # If omit some lines at the beginning, indicate it by displaying a line # with '...'. if num_skip_start > 0: error_message += ' ...\n' - # sys.stderr.write(' ...\n') # Keep track of the first different line. first_diff = -1 @@ -151,51 +149,37 @@ def assert_string_arrays_equal(expected: List[str], actual: List[str]) -> None: if first_diff < 0: first_diff = i error_message += ' {:<45} (diff)'.format(expected[i]) - # sys.stderr.write(' {:<45} (diff)'.format(expected[i])) else: e = expected[i] error_message += ' ' + e[:width] - # sys.stderr.write(' ' + e[:width]) if len(e) > width: error_message += '...' - # sys.stderr.write('...') error_message += '\n' - # sys.stderr.write('\n') if num_skip_end > 0: error_message += ' ...\n' - # sys.stderr.write(' ...\n') error_message += 'Actual:\n' - # sys.stderr.write('Actual:\n') if num_skip_start > 0: error_message += ' ...\n' - # sys.stderr.write(' ...\n') for j in range(num_skip_start, len(actual) - num_skip_end): if j >= len(expected) or expected[j] != actual[j]: error_message += ' {:<45} (diff)'.format(actual[j]) - # sys.stderr.write(' {:<45} (diff)'.format(actual[j])) else: a = actual[j] error_message += ' ' + a[:width] - # sys.stderr.write(' ' + a[:width]) if len(a) > width: error_message += '...' - # sys.stderr.write('...') error_message += '\n' - # sys.stderr.write('\n') if actual == []: error_message += ' (empty)\n' - # sys.stderr.write(' (empty)\n') if num_skip_end > 0: error_message += ' ...\n' - # sys.stderr.write(' ...\n') error_message += '\n' - # sys.stderr.write('\n') - if first_diff >= 0 and first_diff < len(actual) and ( + if 0 <= first_diff < len(actual) and ( len(expected[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT or len(actual[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT): # Display message that helps visualize the differences between two diff --git a/test/pytest_plugin.py b/test/pytest_plugin.py index 26bd4d3..57ada78 100644 --- a/test/pytest_plugin.py +++ b/test/pytest_plugin.py @@ -189,9 +189,9 @@ class MypyTypecheckItem(pytest.Item): main_fpath.write_text(self.source_code) mypy_cmd_options.append(str(main_fpath)) - stdout, _, _ = mypy_api.run(mypy_cmd_options) + stdout, stderr, returncode = mypy_api.run(mypy_cmd_options) output_lines = [] - for line in stdout.splitlines(): + for line in (stdout + stderr).splitlines(): if ':' not in line: continue out_fpath, res_line = line.split(':', 1) @@ -199,15 +199,18 @@ class MypyTypecheckItem(pytest.Item): output_lines.append(line.strip().replace('.py', '')) for module in test_specific_modules: - if module in sys.modules: - del sys.modules[module] - raise ValueError + parts = module.split('.') + for i in range(len(parts)): + parent_module = '.'.join(parts[:i + 1]) + if parent_module in sys.modules: + del sys.modules[parent_module] + assert_string_arrays_equal(expected=self.expected_output_lines, actual=output_lines) def prepare_mypy_cmd_options(self, config_file_path: Path) -> List[str]: mypy_cmd_options = [ - '--show-traceback', + '--raise-exceptions', '--no-silence-site-packages' ] python_version = '.'.join([str(part) for part in sys.version_info[:2]]) @@ -238,7 +241,7 @@ class MypyTypecheckItem(pytest.Item): exception_repr.reprtraceback.reprentries = [repr_tb_entry] return exception_repr else: - return super().repr_failure(excinfo, style='short') + return super().repr_failure(excinfo, style='native') def reportinfo(self): return self.fspath, None, get_class_qualname(self.klass) + '::' + self.name @@ -266,7 +269,7 @@ class MypyTestsCollector(pytest.Class): current_testcase = cast(MypyTypecheckTestCase, self.obj()) ini_file_contents = self.get_ini_file_contents(current_testcase.ini_file()) for attr_name in dir(current_testcase): - if attr_name.startswith('_test_'): + if attr_name.startswith('test_'): attr = getattr(self.obj, attr_name) if inspect.isfunction(attr): first_line_lnum, source_lines = get_func_first_lnum(attr) diff --git a/test/pytest_tests/test_parse_settings.py b/test/pytest_tests/test_parse_settings.py index 3436681..e333857 100644 --- a/test/pytest_tests/test_parse_settings.py +++ b/test/pytest_tests/test_parse_settings.py @@ -9,7 +9,7 @@ class TestParseSettingsFromFile(BaseDjangoPluginTestCase): reveal_type(settings.ROOT_DIR) # E: Revealed type is 'builtins.str' reveal_type(settings.OBJ) # E: Revealed type is 'django.utils.functional.LazyObject' - reveal_type(settings.NUMBERS) # E: Revealed type is 'builtins.list[Any]' + reveal_type(settings.NUMBERS) # E: Revealed type is 'builtins.list[builtins.str]' reveal_type(settings.DICT) # E: Revealed type is 'builtins.dict[Any, Any]' @file('mysettings.py') @@ -34,4 +34,4 @@ class TestSettingInitializableToNone(BaseDjangoPluginTestCase): @file('mysettings.py') def mysettings_py_file(self): SECRET_KEY = 112233 - NONE_SETTING = None + NONE_SETTING: object = None diff --git a/test/pytest_tests/test_to_attr_as_string.py b/test/pytest_tests/test_to_attr_as_string.py deleted file mode 100644 index 20cac2d..0000000 --- a/test/pytest_tests/test_to_attr_as_string.py +++ /dev/null @@ -1,74 +0,0 @@ -from test.pytest_plugin import file, reveal_type, env -from test.pytest_tests.base import BaseDjangoPluginTestCase - - -class TestForeignKey(BaseDjangoPluginTestCase): - @env(DJANGO_SETTINGS_MODULE='mysettings') - def _test_to_parameter_could_be_specified_as_string(self): - from apps.myapp.models import Publisher - - publisher = Publisher() - reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[apps.myapp2.models.Book]' - - # @env(DJANGO_SETTINGS_MODULE='mysettings') - # def _test_creates_underscore_id_attr(self): - # from apps.myapp2.models import Book - # - # book = Book() - # reveal_type(book.publisher) # E: Revealed type is 'apps.myapp.models.Publisher' - # reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int' - - @file('mysettings.py') - def mysettings(self): - SECRET_KEY = '112233' - ROOT_DIR = '' - APPS_DIR = '/apps' - - INSTALLED_APPS = ('apps.myapp', 'apps.myapp2') - - @file('apps/myapp/models.py', make_parent_packages=True) - def apps_myapp_models(self): - from django.db import models - - class Publisher(models.Model): - pass - - @file('apps/myapp2/models.py', make_parent_packages=True) - def apps_myapp2_models(self): - from django.db import models - - class Book(models.Model): - publisher = models.ForeignKey(to='myapp.Publisher', on_delete=models.CASCADE, - related_name='books') - - -class TestOneToOneField(BaseDjangoPluginTestCase): - @env(DJANGO_SETTINGS_MODULE='mysettings') - def test_to_parameter_could_be_specified_as_string(self): - from apps.myapp.models import User - - user = User() - reveal_type(user.profile) # E: Revealed type is 'apps.myapp2.models.Profile' - - @file('mysettings.py') - def mysettings(self): - SECRET_KEY = '112233' - ROOT_DIR = '' - APPS_DIR = '/apps' - - INSTALLED_APPS = ('apps.myapp', 'apps.myapp2') - - @file('apps/myapp/models.py', make_parent_packages=True) - def apps_myapp_models(self): - from django.db import models - - class User(models.Model): - pass - - @file('apps/myapp2/models.py', make_parent_packages=True) - def apps_myapp2_models(self): - from django.db import models - - class Profile(models.Model): - user = models.OneToOneField(to='myapp.User', on_delete=models.CASCADE, - related_name='profile')