From fcd659837e82fc4f7372e5ca4e321fbf3b3aa921 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Sat, 1 Dec 2018 16:26:53 +0300 Subject: [PATCH] cleanups, fix settings --- .../contrib/postgres/fields/array.pyi | 2 +- external/mypy | 1 - mypy_django_plugin/helpers.py | 12 +- mypy_django_plugin/main.py | 76 +---- mypy_django_plugin/monkeypatch/contexts.py | 3 +- mypy_django_plugin/plugins/fields.py | 34 +- .../plugins/meta_inner_class.py | 12 - mypy_django_plugin/plugins/models.py | 180 +++++++++++ .../plugins/objects_queryset.py | 36 --- mypy_django_plugin/plugins/related_fields.py | 47 +-- mypy_django_plugin/plugins/settings.py | 61 ++++ mypy_django_plugin/plugins/setup_settings.py | 42 --- pytest.ini | 5 +- setup.py | 8 +- {test => test-data}/plugins.ini | 0 .../typecheck/fields.test | 22 ++ .../typecheck/models.test | 0 .../typecheck/related_fields.test | 24 +- .../typecheck/settings.test | 10 +- test/__init__.py | 0 test/data.py | 250 --------------- test/helpers.py | 237 -------------- test/pytest_plugin.py | 303 ------------------ test/test-data/postgres-fields.test | 21 -- test/testdjango.py | 60 ---- test/vistir.py | 43 --- 26 files changed, 316 insertions(+), 1173 deletions(-) delete mode 160000 external/mypy delete mode 100644 mypy_django_plugin/plugins/meta_inner_class.py create mode 100644 mypy_django_plugin/plugins/models.py delete mode 100644 mypy_django_plugin/plugins/objects_queryset.py create mode 100644 mypy_django_plugin/plugins/settings.py delete mode 100644 mypy_django_plugin/plugins/setup_settings.py rename {test => test-data}/plugins.ini (100%) rename test/test-data/model-fields.test => test-data/typecheck/fields.test (69%) rename test/test-data/objects-queryset.test => test-data/typecheck/models.test (100%) rename test/test-data/model-relations.test => test-data/typecheck/related_fields.test (81%) rename test/test-data/parse-settings.test => test-data/typecheck/settings.test (89%) delete mode 100644 test/__init__.py delete mode 100644 test/data.py delete mode 100644 test/helpers.py delete mode 100644 test/pytest_plugin.py delete mode 100644 test/test-data/postgres-fields.test delete mode 100644 test/testdjango.py delete mode 100644 test/vistir.py diff --git a/django-stubs/contrib/postgres/fields/array.pyi b/django-stubs/contrib/postgres/fields/array.pyi index 9df579f..5d70923 100644 --- a/django-stubs/contrib/postgres/fields/array.pyi +++ b/django-stubs/contrib/postgres/fields/array.pyi @@ -16,7 +16,7 @@ class ArrayField(CheckFieldDefaultMixin, Field, Generic[_T]): from_db_value: Any = ... def __init__( - self, base_field: Field, size: None = ..., **kwargs: Any + self, base_field: _T, size: None = ..., **kwargs: Any ) -> None: ... @property def model(self): ... diff --git a/external/mypy b/external/mypy deleted file mode 160000 index b790539..0000000 --- a/external/mypy +++ /dev/null @@ -1 +0,0 @@ -Subproject commit b7905398258304bb366539776a36acd74d6f2a10 diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index 693ad49..1ef5de5 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -1,9 +1,9 @@ import typing from typing import Dict, Optional, NamedTuple -from mypy.nodes import SymbolTableNode, Var, Expression, StrExpr, MypyFile, TypeInfo +from mypy.nodes import Expression, StrExpr, MypyFile, TypeInfo from mypy.plugin import FunctionContext -from mypy.types import Type, Instance, UnionType, NoneTyp +from mypy.types import Type, UnionType, NoneTyp MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet' @@ -11,14 +11,6 @@ FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey' ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField' DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject' - -def create_new_symtable_node(name: str, kind: int, instance: Instance) -> SymbolTableNode: - new_var = Var(name, instance) - new_var.info = instance.type - return SymbolTableNode(kind, new_var, - plugin_generated=True) - - Argument = NamedTuple('Argument', fields=[ ('arg', Expression), ('arg_type', Type) diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 88151a1..10b7d24 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -1,81 +1,24 @@ import os -from typing import Callable, Optional, cast +from typing import Callable, Optional -from mypy.nodes import AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr from mypy.options import Options -from mypy.plugin import Plugin, FunctionContext, ClassDefContext -from mypy.semanal import SemanticAnalyzerPass2 -from mypy.types import Type, Instance +from mypy.plugin import Plugin, FunctionContext, ClassDefContext, AnalyzeTypeContext +from mypy.types import Type from mypy_django_plugin import helpers, monkeypatch -from mypy_django_plugin.plugins.meta_inner_class import inject_any_as_base_for_nested_class_meta -from mypy_django_plugin.plugins.objects_queryset import set_objects_queryset_to_model_class -from mypy_django_plugin.plugins.fields import determine_type_of_array_field, \ - add_int_id_attribute_if_primary_key_true_is_not_present -from mypy_django_plugin.plugins.related_fields import set_fieldname_attrs_for_related_fields, add_new_var_node_to_class, \ - extract_to_parameter_as_get_ret_type -from mypy_django_plugin.plugins.setup_settings import DjangoConfSettingsInitializerHook +from mypy_django_plugin.plugins.fields import determine_type_of_array_field +from mypy_django_plugin.plugins.models import process_model_class +from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field +from mypy_django_plugin.plugins.settings import DjangoConfSettingsInitializerHook base_model_classes = {helpers.MODEL_CLASS_FULLNAME} -def add_related_managers_from_referred_foreign_keys_to_model(ctx: ClassDefContext) -> None: - api = cast(SemanticAnalyzerPass2, ctx.api) - for stmt in ctx.cls.defs.body: - if not isinstance(stmt, AssignmentStmt): - continue - if len(stmt.lvalues) > 1: - # not supported yet - continue - rvalue = stmt.rvalue - if not isinstance(rvalue, CallExpr): - continue - if (not isinstance(rvalue.callee, MemberExpr) - or not rvalue.callee.fullname in {helpers.FOREIGN_KEY_FULLNAME, - helpers.ONETOONE_FIELD_FULLNAME}): - continue - if 'related_name' not in rvalue.arg_names: - # positional related_name is not supported yet - continue - related_name = rvalue.args[rvalue.arg_names.index('related_name')].value - - if 'to' in rvalue.arg_names: - expr = rvalue.args[rvalue.arg_names.index('to')] - else: - # first positional argument - expr = rvalue.args[0] - - if isinstance(expr, StrExpr): - model_typeinfo = helpers.get_model_type_from_string(expr, - all_modules=api.modules) - if model_typeinfo is None: - continue - elif isinstance(expr, NameExpr): - model_typeinfo = expr.node - else: - continue - - if rvalue.callee.fullname == helpers.FOREIGN_KEY_FULLNAME: - typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, - args=[Instance(ctx.cls.info, [])]) - else: - typ = Instance(ctx.cls.info, []) - - if typ is None: - continue - add_new_var_node_to_class(model_typeinfo, related_name, typ) - - class TransformModelClassHook(object): def __call__(self, ctx: ClassDefContext) -> None: base_model_classes.add(ctx.cls.fullname) - - set_fieldname_attrs_for_related_fields(ctx) - set_objects_queryset_to_model_class(ctx) - inject_any_as_base_for_nested_class_meta(ctx) - add_related_managers_from_referred_foreign_keys_to_model(ctx) - add_int_id_attribute_if_primary_key_true_is_not_present(ctx) + process_model_class(ctx) class DjangoPlugin(Plugin): @@ -89,7 +32,6 @@ class DjangoPlugin(Plugin): if self.django_settings: monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(self.django_settings) monkeypatch.inject_dependencies(self.django_settings) - # monkeypatch.process_settings_before_dependants(self.django_settings) else: monkeypatch.restore_original_load_graph() monkeypatch.restore_original_dependencies_handling() @@ -98,7 +40,7 @@ class DjangoPlugin(Plugin): ) -> Optional[Callable[[FunctionContext], Type]]: if fullname in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}: - return extract_to_parameter_as_get_ret_type + return extract_to_parameter_as_get_ret_type_for_related_field # if fullname == helpers.ONETOONE_FIELD_FULLNAME: # return OneToOneFieldHook(settings=self.django_settings) diff --git a/mypy_django_plugin/monkeypatch/contexts.py b/mypy_django_plugin/monkeypatch/contexts.py index 7d926ca..740c811 100644 --- a/mypy_django_plugin/monkeypatch/contexts.py +++ b/mypy_django_plugin/monkeypatch/contexts.py @@ -1,7 +1,7 @@ from typing import Optional, List, Sequence, NamedTuple, Tuple from mypy import checkexpr -from mypy.argmap import map_actuals_to_formals +from mypy.checkexpr import map_actuals_to_formals from mypy.checkmember import analyze_member_access from mypy.expandtype import freshen_function_type_vars from mypy.messages import MessageBuilder @@ -68,7 +68,6 @@ class PatchedExpressionChecker(checkexpr.ExpressionChecker): on which the method is being called """ arg_messages = arg_messages or self.msg - if isinstance(callee, CallableType): if callable_name is None and callee.name: callable_name = callee.name diff --git a/mypy_django_plugin/plugins/fields.py b/mypy_django_plugin/plugins/fields.py index 836d86f..25e4171 100644 --- a/mypy_django_plugin/plugins/fields.py +++ b/mypy_django_plugin/plugins/fields.py @@ -1,11 +1,5 @@ -from typing import Iterator, List, cast - -from mypy.nodes import ClassDef, AssignmentStmt, CallExpr -from mypy.plugin import FunctionContext, ClassDefContext -from mypy.semanal import SemanticAnalyzerPass2 -from mypy.types import Type, Instance - -from mypy_django_plugin.plugins.related_fields import add_new_var_node_to_class +from mypy.plugin import FunctionContext +from mypy.types import Type def determine_type_of_array_field(ctx: FunctionContext) -> Type: @@ -15,27 +9,3 @@ def determine_type_of_array_field(ctx: FunctionContext) -> Type: base_field_arg_type = ctx.arg_types[ctx.arg_names.index('base_field')][0] return ctx.api.named_generic_type(ctx.context.callee.fullname, args=[base_field_arg_type.type.names['__get__'].type.ret_type]) - - -def get_assignments(klass: ClassDef) -> List[AssignmentStmt]: - stmts = [] - for stmt in klass.defs.body: - if not isinstance(stmt, AssignmentStmt): - continue - if len(stmt.lvalues) > 1: - # not supported yet - continue - stmts.append(stmt) - return stmts - - -def add_int_id_attribute_if_primary_key_true_is_not_present(ctx: ClassDefContext) -> None: - api = cast(SemanticAnalyzerPass2, ctx.api) - for stmt in get_assignments(ctx.cls): - if (isinstance(stmt.rvalue, CallExpr) - and 'primary_key' in stmt.rvalue.arg_names - and api.parse_bool(stmt.rvalue.args[stmt.rvalue.arg_names.index('primary_key')])): - break - else: - add_new_var_node_to_class(ctx.cls.info, 'id', api.builtin_type('builtins.int')) - diff --git a/mypy_django_plugin/plugins/meta_inner_class.py b/mypy_django_plugin/plugins/meta_inner_class.py deleted file mode 100644 index b37b159..0000000 --- a/mypy_django_plugin/plugins/meta_inner_class.py +++ /dev/null @@ -1,12 +0,0 @@ -from mypy.nodes import TypeInfo -from mypy.plugin import ClassDefContext - - -def inject_any_as_base_for_nested_class_meta(ctx: ClassDefContext) -> None: - if 'Meta' not in ctx.cls.info.names: - return None - sym = ctx.cls.info.names['Meta'] - if not isinstance(sym.node, TypeInfo): - return None - - sym.node.fallback_to_any = True diff --git a/mypy_django_plugin/plugins/models.py b/mypy_django_plugin/plugins/models.py new file mode 100644 index 0000000..3859a1b --- /dev/null +++ b/mypy_django_plugin/plugins/models.py @@ -0,0 +1,180 @@ +from typing import cast, Iterator, Tuple, Optional + +from mypy.nodes import ClassDef, AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr, MDEF, TypeInfo, Var, SymbolTableNode, \ + Lvalue, Expression, Statement +from mypy.plugin import ClassDefContext +from mypy.semanal import SemanticAnalyzerPass2 +from mypy.types import Instance + +from mypy_django_plugin import helpers + + +def add_new_var_node_to_class(class_type: TypeInfo, name: str, typ: Instance) -> None: + var = Var(name=name, type=typ) + var.info = typ.type + var._fullname = class_type.fullname() + '.' + name + var.is_inferred = True + var.is_initialized_in_class = True + class_type.names[name] = SymbolTableNode(MDEF, var) + + +def iter_over_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, Expression]]: + for stmt in klass.defs.body: + if not isinstance(stmt, AssignmentStmt): + continue + if len(stmt.lvalues) > 1: + # not supported yet + continue + yield stmt.lvalues[0], stmt.rvalue + + +def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]: + for lvalue, rvalue in iter_over_assignments(klass): + if not isinstance(rvalue, CallExpr): + continue + yield lvalue, rvalue + + +def iter_over_one_to_n_related_fields(klass: ClassDef, api: SemanticAnalyzerPass2) -> Iterator[Tuple[NameExpr, CallExpr]]: + for lvalue, rvalue in iter_call_assignments(klass): + if (isinstance(lvalue, NameExpr) + and isinstance(rvalue.callee, MemberExpr)): + if rvalue.callee.fullname in {helpers.FOREIGN_KEY_FULLNAME, + helpers.ONETOONE_FIELD_FULLNAME}: + yield lvalue, rvalue + + +def get_nested_meta_class(model_type: TypeInfo) -> Optional[TypeInfo]: + metaclass_sym = model_type.names.get('Meta') + if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo): + return metaclass_sym.node + return None + + +def is_abstract_model(ctx: ClassDefContext) -> bool: + meta_node = get_nested_meta_class(ctx.cls.info) + if meta_node is None: + return False + + for lvalue, rvalue in iter_over_assignments(meta_node.defn): + if isinstance(lvalue, NameExpr) and lvalue.name == 'abstract': + is_abstract = ctx.api.parse_bool(rvalue) + if is_abstract: + # abstract model do not need 'objects' queryset + return True + return False + + +def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None: + api = ctx.api + for lvalue, rvalue in iter_over_one_to_n_related_fields(ctx.cls, api): + property_name = lvalue.name + '_id' + add_new_var_node_to_class(ctx.cls.info, property_name, + typ=api.named_type('__builtins__.int')) + + +def add_int_id_attribute_if_primary_key_true_is_not_present(ctx: ClassDefContext) -> None: + api = cast(SemanticAnalyzerPass2, ctx.api) + if is_abstract_model(ctx): + return None + + for _, rvalue in iter_call_assignments(ctx.cls): + if ('primary_key' in rvalue.arg_names and + api.parse_bool(rvalue.args[rvalue.arg_names.index('primary_key')])): + break + else: + add_new_var_node_to_class(ctx.cls.info, 'id', api.builtin_type('builtins.int')) + + +def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None: + # search over mro + objects_sym = ctx.cls.info.get('objects') + if objects_sym is not None: + return None + + # only direct Meta class + if is_abstract_model(ctx): + # abstract model do not need 'objects' queryset + return None + + api = cast(SemanticAnalyzerPass2, ctx.api) + typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, + args=[Instance(ctx.cls.info, [])]) + if not typ: + return None + add_new_var_node_to_class(ctx.cls.info, 'objects', typ=typ) + + +def inject_any_as_base_for_nested_class_meta(ctx: ClassDefContext) -> None: + meta_node = get_nested_meta_class(ctx.cls.info) + if meta_node is None: + return None + meta_node.fallback_to_any = True + + +def is_model_defn(defn: Statement, api: SemanticAnalyzerPass2) -> bool: + if not isinstance(defn, ClassDef): + return False + + for base_type_expr in defn.base_type_exprs: + # api.accept(base_type_expr) + fullname = getattr(base_type_expr, 'fullname', None) + if fullname == helpers.MODEL_CLASS_FULLNAME: + return True + return False + + +def iter_over_models(ctx: ClassDefContext) -> Iterator[ClassDef]: + for module_name, module_file in ctx.api.modules.items(): + for defn in module_file.defs: + if is_model_defn(defn, api=cast(SemanticAnalyzerPass2, ctx.api)): + yield defn + + +def extract_to_value_or_none(field_expr: CallExpr, ctx: ClassDefContext) -> Optional[TypeInfo]: + if 'to' in field_expr.arg_names: + ref_expr = field_expr.args[field_expr.arg_names.index('to')] + else: + # first positional argument + ref_expr = field_expr.args[0] + + if isinstance(ref_expr, StrExpr): + model_typeinfo = helpers.get_model_type_from_string(ref_expr, + all_modules=ctx.api.modules) + return model_typeinfo + elif isinstance(ref_expr, NameExpr): + return ref_expr.node + + +def get_related_field_type(rvalue: CallExpr, api: SemanticAnalyzerPass2, + related_model_typ: TypeInfo) -> Optional[Instance]: + if rvalue.callee.fullname == helpers.FOREIGN_KEY_FULLNAME: + return api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, + args=[Instance(related_model_typ, [])]) + else: + return Instance(related_model_typ, []) + + +def add_related_managers(ctx: ClassDefContext) -> None: + for model_defn in iter_over_models(ctx): + for _, rvalue in iter_over_one_to_n_related_fields(model_defn, ctx.api): + if 'related_name' not in rvalue.arg_names: + # positional related_name is not supported yet + return + related_name = rvalue.args[rvalue.arg_names.index('related_name')].value + ref_to_typ = extract_to_value_or_none(rvalue, ctx) + if ref_to_typ is not None: + if ref_to_typ.fullname() == ctx.cls.info.fullname(): + typ = get_related_field_type(rvalue, ctx.api, + related_model_typ=model_defn.info) + if typ is None: + return + add_new_var_node_to_class(ctx.cls.info, related_name, typ) + + +def process_model_class(ctx: ClassDefContext) -> None: + # add_related_managers(ctx) + inject_any_as_base_for_nested_class_meta(ctx) + set_fieldname_attrs_for_related_fields(ctx) + add_int_id_attribute_if_primary_key_true_is_not_present(ctx) + set_objects_queryset_to_model_class(ctx) diff --git a/mypy_django_plugin/plugins/objects_queryset.py b/mypy_django_plugin/plugins/objects_queryset.py deleted file mode 100644 index f8cf49d..0000000 --- a/mypy_django_plugin/plugins/objects_queryset.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import cast - -from mypy.nodes import MDEF, AssignmentStmt -from mypy.plugin import ClassDefContext -from mypy.semanal import SemanticAnalyzerPass2 -from mypy.types import Instance - -from mypy_django_plugin import helpers - - -def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None: - # search over mro - objects_sym = ctx.cls.info.get('objects') - if objects_sym is not None: - return None - - # only direct Meta class - metaclass_sym = ctx.cls.info.names.get('Meta') - # skip if abstract - if metaclass_sym is not None: - for stmt in metaclass_sym.node.defn.defs.body: - if (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1 - and stmt.lvalues[0].name == 'abstract'): - is_abstract = ctx.api.parse_bool(stmt.rvalue) - if is_abstract: - return None - - api = cast(SemanticAnalyzerPass2, ctx.api) - typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, - args=[Instance(ctx.cls.info, [])]) - if not typ: - return None - - ctx.cls.info.names['objects'] = helpers.create_new_symtable_node('objects', - kind=MDEF, - instance=typ) diff --git a/mypy_django_plugin/plugins/related_fields.py b/mypy_django_plugin/plugins/related_fields.py index aac0791..2a3c72c 100644 --- a/mypy_django_plugin/plugins/related_fields.py +++ b/mypy_django_plugin/plugins/related_fields.py @@ -1,18 +1,12 @@ import typing from typing import Optional, cast -from django.conf import Settings from mypy.checker import TypeChecker -from mypy.nodes import MDEF, AssignmentStmt, MypyFile, StrExpr, TypeInfo, NameExpr, Var, SymbolTableNode -from mypy.plugin import FunctionContext, ClassDefContext +from mypy.nodes import StrExpr +from mypy.plugin import FunctionContext from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny from mypy_django_plugin import helpers -from mypy_django_plugin.helpers import get_models_file - - -def extract_related_name_value(ctx: FunctionContext) -> str: - return ctx.args[ctx.arg_names.index('related_name')][0].value def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]): @@ -55,44 +49,9 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: return referred_to_type -def add_new_var_node_to_class(class_type: TypeInfo, name: str, typ: Instance) -> None: - var = Var(name=name, type=typ) - var.info = typ.type - var._fullname = class_type.fullname() + '.' + name - var.is_inferred = True - var.is_initialized_in_class = True - class_type.names[name] = SymbolTableNode(MDEF, var) - - -def extract_to_parameter_as_get_ret_type(ctx: FunctionContext) -> Type: +def extract_to_parameter_as_get_ret_type_for_related_field(ctx: FunctionContext) -> Type: referred_to_type = get_valid_to_value_or_none(ctx) if referred_to_type is None: # couldn't extract to= value return fill_typevars_with_any(ctx.default_return_type) return reparametrize_with(ctx.default_return_type, [referred_to_type]) - - -def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None: - api = ctx.api - for stmt in ctx.cls.defs.body: - if not isinstance(stmt, AssignmentStmt): - continue - if not hasattr(stmt.rvalue, 'callee'): - continue - if len(stmt.lvalues) > 1: - # multiple lvalues not supported for now - continue - - expr = stmt.lvalues[0] - if not isinstance(expr, NameExpr): - continue - name = expr.name - - rvalue_callee = stmt.rvalue.callee - if rvalue_callee.fullname in {helpers.FOREIGN_KEY_FULLNAME, - helpers.ONETOONE_FIELD_FULLNAME}: - name += '_id' - new_node = helpers.create_new_symtable_node(name, - kind=MDEF, - instance=api.named_type('__builtins__.int')) - ctx.cls.info.names[name] = new_node diff --git a/mypy_django_plugin/plugins/settings.py b/mypy_django_plugin/plugins/settings.py new file mode 100644 index 0000000..08bffa3 --- /dev/null +++ b/mypy_django_plugin/plugins/settings.py @@ -0,0 +1,61 @@ +from typing import cast, List + +from mypy.nodes import Var, Context, SymbolNode, SymbolTableNode +from mypy.plugin import ClassDefContext +from mypy.semanal import SemanticAnalyzerPass2 +from mypy.types import Instance, UnionType, NoneTyp, Type + +from mypy_django_plugin import helpers + + +def get_error_context(node: SymbolNode) -> Context: + context = Context() + context.set_line(node) + return context + + +def filter_out_nones(typ: UnionType) -> List[Type]: + return [item for item in typ.items if not isinstance(item, NoneTyp)] + + +def copy_sym_of_instance(sym: SymbolTableNode) -> SymbolTableNode: + copied = sym.copy() + copied.node.info = sym.type.type + return copied + + +def add_settings_to_django_conf_object(ctx: ClassDefContext, + settings_module: str) -> None: + api = cast(SemanticAnalyzerPass2, ctx.api) + if settings_module not in api.modules: + return None + + settings_file = api.modules[settings_module] + for name, sym in settings_file.names.items(): + if name.isupper() and isinstance(sym.node, Var): + if isinstance(sym.type, Instance): + copied = sym.copy() + copied.node.info = sym.type.type + ctx.cls.info.names[name] = copied + + elif isinstance(sym.type, UnionType): + instances = filter_out_nones(sym.type) + if len(instances) > 1: + # plain unions not supported yet + continue + typ = instances[0] + if isinstance(typ, Instance): + copied = sym.copy() + copied.node.info = typ.type + ctx.cls.info.names[name] = copied + + +class DjangoConfSettingsInitializerHook(object): + def __init__(self, settings_module: str): + self.settings_module = settings_module + + def __call__(self, ctx: ClassDefContext) -> None: + if not self.settings_module: + return + + add_settings_to_django_conf_object(ctx, self.settings_module) diff --git a/mypy_django_plugin/plugins/setup_settings.py b/mypy_django_plugin/plugins/setup_settings.py deleted file mode 100644 index 2aaf4ec..0000000 --- a/mypy_django_plugin/plugins/setup_settings.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Optional, Any, cast - -from mypy.nodes import Var, Context, GDEF -from mypy.options import Options -from mypy.plugin import ClassDefContext -from mypy.semanal import SemanticAnalyzerPass2 -from mypy.types import Instance - - -def add_settings_to_django_conf_object(ctx: ClassDefContext, - settings_module: str) -> Optional[Any]: - api = cast(SemanticAnalyzerPass2, ctx.api) - if settings_module not in api.modules: - return None - - settings_file = api.modules[settings_module] - for name, sym in settings_file.names.items(): - if name.isupper(): - if not isinstance(sym.node, Var) or not isinstance(sym.type, Instance): - error_context = Context() - error_context.set_line(sym.node) - api.msg.fail("Need type annotation for '{}'".format(sym.node.name()), - context=error_context, - file=settings_file.path, - origin=Context()) - continue - - sym_copy = sym.copy() - sym_copy.node.info = sym_copy.type.type - sym_copy.kind = GDEF - ctx.cls.info.names[name] = sym_copy - - -class DjangoConfSettingsInitializerHook(object): - def __init__(self, settings_module: str): - self.settings_module = settings_module - - def __call__(self, ctx: ClassDefContext) -> None: - if not self.settings_module: - return - - add_settings_to_django_conf_object(ctx, self.settings_module) diff --git a/pytest.ini b/pytest.ini index 4291613..3a56481 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,5 +4,6 @@ testpaths = ./test python_files = test*.py addopts = --tb=native - --ignore=./external - --mypy-ini-file=./test/plugins.ini \ No newline at end of file + --mypy-ini-file=./test-data/plugins.ini + -s + -v \ No newline at end of file diff --git a/setup.py b/setup.py index 0b69c87..36ef80a 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,9 @@ setup( author_email="maxim.kurnikov@gmail.com", version="0.1.0", license='BSD', - install_requires=['Django>=2.1.1'], - packages=['mypy_django_plugin'] - # package_data=find_stubs('django-stubs') + install_requires=[ + 'Django>=2.1.1', + 'mypy' + ], + packages=['mypy_django_plugin'], ) diff --git a/test/plugins.ini b/test-data/plugins.ini similarity index 100% rename from test/plugins.ini rename to test-data/plugins.ini diff --git a/test/test-data/model-fields.test b/test-data/typecheck/fields.test similarity index 69% rename from test/test-data/model-fields.test rename to test-data/typecheck/fields.test index 09dfa1f..fad8b1f 100644 --- a/test/test-data/model-fields.test +++ b/test-data/typecheck/fields.test @@ -1,3 +1,25 @@ +[CASE array_field_descriptor_access] +from django.db import models +from django.contrib.postgres.fields import ArrayField + +class User(models.Model): + array = ArrayField(base_field=models.Field()) + +user = User() +reveal_type(user.array) # E: Revealed type is 'builtins.list[Any]' + +[CASE array_field_base_field_parsed_into_generic_typevar] +from django.db import models +from django.contrib.postgres.fields import ArrayField + +class User(models.Model): + members = ArrayField(base_field=models.IntegerField()) + members_as_text = ArrayField(base_field=models.CharField(max_length=255)) + +user = User() +reveal_type(user.members) # E: Revealed type is 'builtins.list[builtins.int*]' +reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list[builtins.str*]' + [CASE test_model_fields_classes_present_as_primitives] from django.db import models diff --git a/test/test-data/objects-queryset.test b/test-data/typecheck/models.test similarity index 100% rename from test/test-data/objects-queryset.test rename to test-data/typecheck/models.test diff --git a/test/test-data/model-relations.test b/test-data/typecheck/related_fields.test similarity index 81% rename from test/test-data/model-relations.test rename to test-data/typecheck/related_fields.test index abe9544..97dbb85 100644 --- a/test/test-data/model-relations.test +++ b/test-data/typecheck/related_fields.test @@ -105,7 +105,12 @@ class Profile(models.Model): from django.db import models from myapp.models import App class View(models.Model): - app = models.ForeignKey(to='myapp.App', related_name='views', on_delete=models.CASCADE) + app = models.ForeignKey(to=App, related_name='views', on_delete=models.CASCADE) + +reveal_type(View().app.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' +reveal_type(View().app.unknown) # E: Revealed type is 'Any' +[out] +main:7: error: "App" has no attribute "unknown" [file myapp/__init__.py] [file myapp/models.py] @@ -113,3 +118,20 @@ from django.db import models class App(models.Model): def method(self) -> None: reveal_type(self.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' + +[case models_related_managers_work_with_direct_model_inheritance_and_with_inheritance_from_other_model] +from django.db.models import Model +from django.db import models + +class App(Model): + pass + +class View(Model): + app = models.ForeignKey(to=App, on_delete=models.CASCADE, related_name='views') + +class View2(View): + app = models.ForeignKey(to=App, on_delete=models.CASCADE, related_name='views2') + +reveal_type(App().views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' +reveal_type(App().views2) # E: Revealed type is 'django.db.models.query.QuerySet[main.View2]' +[out] \ No newline at end of file diff --git a/test/test-data/parse-settings.test b/test-data/typecheck/settings.test similarity index 89% rename from test/test-data/parse-settings.test rename to test-data/typecheck/settings.test index 8270149..9ba185e 100644 --- a/test/test-data/parse-settings.test +++ b/test-data/typecheck/settings.test @@ -11,9 +11,7 @@ SECRET_KEY = 112233 ROOT_DIR = '/etc' NUMBERS = ['one', 'two'] DICT = {} # type: ignore - from django.utils.functional import LazyObject - OBJ = LazyObject() [CASE test_settings_could_be_defined_in_different_module_and_imported_with_star] @@ -36,18 +34,18 @@ ROOT_DIR = Path(__file__) [CASE test_circular_dependency_in_settings] from django.conf import settings - class Class: pass - reveal_type(settings.MYSETTING) # E: Revealed type is 'builtins.int' -reveal_type(settings.REGISTRY) # E: Revealed type is 'Any' +reveal_type(settings.REGISTRY) # E: Revealed type is 'Union[main.Class, None]' +[env DJANGO_SETTINGS_MODULE=mysettings] [file mysettings.py] from typing import TYPE_CHECKING +from typing import Optional if TYPE_CHECKING: - from .main import Class + from main import Class MYSETTING = 1122 REGISTRY: Optional['Class'] = None diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/test/data.py b/test/data.py deleted file mode 100644 index 2e5abf0..0000000 --- a/test/data.py +++ /dev/null @@ -1,250 +0,0 @@ -import os -import posixpath -import re -import sys -import tempfile -from typing import Any, Optional, Iterator, Dict, List, Tuple, Set - -import pytest -from mypy.test.config import test_temp_dir -from mypy.test.data import DataDrivenTestCase, DataSuite, add_test_name_suffix, parse_test_data, \ - expand_errors, expand_variables, fix_win_path - - -def parse_test_case(case: 'DataDrivenTestCase') -> None: - """Parse and prepare a single case from suite with test case descriptions. - - This method is part of the setup phase, just before the test case is run. - """ - test_items = parse_test_data(case.data, case.name) - base_path = case.suite.base_path - if case.suite.native_sep: - join = os.path.join - else: - join = posixpath.join # type: ignore - - out_section_missing = case.suite.required_out_section - - files = [] # type: List[Tuple[str, str]] # path and contents - output_files = [] # type: List[Tuple[str, str]] # path and contents for output files - output = [] # type: List[str] # Regular output errors - output2 = {} # type: Dict[int, List[str]] # Output errors for incremental, runs 2+ - deleted_paths = {} # type: Dict[int, Set[str]] # from run number of paths - stale_modules = {} # type: Dict[int, Set[str]] # from run number to module names - rechecked_modules = {} # type: Dict[ int, Set[str]] # from run number module names - triggered = [] # type: List[str] # Active triggers (one line per incremental step) - - # Process the parsed items. Each item has a header of form [id args], - # optionally followed by lines of text. - item = first_item = test_items[0] - for item in test_items[1:]: - if item.id == 'file' or item.id == 'outfile': - # Record an extra file needed for the test case. - assert item.arg is not None - contents = expand_variables('\n'.join(item.data)) - file_entry = (join(base_path, item.arg), contents) - if item.id == 'file': - files.append(file_entry) - else: - output_files.append(file_entry) - elif item.id in ('builtins', 'builtins_py2'): - # Use an alternative stub file for the builtins module. - assert item.arg is not None - mpath = join(os.path.dirname(case.file), item.arg) - fnam = 'builtins.pyi' if item.id == 'builtins' else '__builtin__.pyi' - with open(mpath) as f: - files.append((join(base_path, fnam), f.read())) - elif item.id == 'typing': - # Use an alternative stub file for the typing module. - assert item.arg is not None - src_path = join(os.path.dirname(case.file), item.arg) - with open(src_path) as f: - files.append((join(base_path, 'typing.pyi'), f.read())) - elif re.match(r'stale[0-9]*$', item.id): - passnum = 1 if item.id == 'stale' else int(item.id[len('stale'):]) - assert passnum > 0 - modules = (set() if item.arg is None else {t.strip() for t in item.arg.split(',')}) - stale_modules[passnum] = modules - elif re.match(r'rechecked[0-9]*$', item.id): - passnum = 1 if item.id == 'rechecked' else int(item.id[len('rechecked'):]) - assert passnum > 0 - modules = (set() if item.arg is None else {t.strip() for t in item.arg.split(',')}) - rechecked_modules[passnum] = modules - elif item.id == 'delete': - # File to delete during a multi-step test case - assert item.arg is not None - m = re.match(r'(.*)\.([0-9]+)$', item.arg) - assert m, 'Invalid delete section: {}'.format(item.arg) - num = int(m.group(2)) - assert num >= 2, "Can't delete during step {}".format(num) - full = join(base_path, m.group(1)) - deleted_paths.setdefault(num, set()).add(full) - elif re.match(r'out[0-9]*$', item.id): - tmp_output = [expand_variables(line) for line in item.data] - if os.path.sep == '\\': - tmp_output = [fix_win_path(line) for line in tmp_output] - if item.id == 'out' or item.id == 'out1': - output = tmp_output - else: - passnum = int(item.id[len('out'):]) - assert passnum > 1 - output2[passnum] = tmp_output - out_section_missing = False - elif item.id == 'triggered' and item.arg is None: - triggered = item.data - elif item.id == 'env': - env_vars_to_set = item.arg - for env in env_vars_to_set.split(';'): - try: - name, value = env.split('=') - os.environ[name] = value - except ValueError: - continue - else: - raise ValueError( - 'Invalid section header {} in {} at line {}'.format( - item.id, case.file, item.line)) - - if out_section_missing: - raise ValueError( - '{}, line {}: Required output section not found'.format( - case.file, first_item.line)) - - for passnum in stale_modules.keys(): - if passnum not in rechecked_modules: - # If the set of rechecked modules isn't specified, make it the same as the set - # of modules with a stale public interface. - rechecked_modules[passnum] = stale_modules[passnum] - if (passnum in stale_modules - and passnum in rechecked_modules - and not stale_modules[passnum].issubset(rechecked_modules[passnum])): - raise ValueError( - ('Stale modules after pass {} must be a subset of rechecked ' - 'modules ({}:{})').format(passnum, case.file, first_item.line)) - - input = first_item.data - expand_errors(input, output, 'main') - for file_path, contents in files: - expand_errors(contents.split('\n'), output, file_path) - - case.input = input - case.output = output - case.output2 = output2 - case.lastline = item.line - case.files = files - case.output_files = output_files - case.expected_stale_modules = stale_modules - case.expected_rechecked_modules = rechecked_modules - case.deleted_paths = deleted_paths - case.triggered = triggered or [] - - -class DjangoDataDrivenTestCase(DataDrivenTestCase): - def setup(self) -> None: - self.old_environ = os.environ.copy() - - parse_test_case(case=self) - self.old_cwd = os.getcwd() - - self.tmpdir = tempfile.TemporaryDirectory(prefix='mypy-test-') - tmpdir_root = os.path.join(self.tmpdir.name, 'tmp') - - new_files = [] - for path, contents in self.files: - new_files.append((path, contents.replace('', tmpdir_root))) - self.files = new_files - - os.chdir(self.tmpdir.name) - os.mkdir(test_temp_dir) - encountered_files = set() - self.clean_up = [] - for paths in self.deleted_paths.values(): - for path in paths: - self.clean_up.append((False, path)) - encountered_files.add(path) - for path, content in self.files: - dir = os.path.dirname(path) - for d in self.add_dirs(dir): - self.clean_up.append((True, d)) - with open(path, 'w') as f: - f.write(content) - if path not in encountered_files: - self.clean_up.append((False, path)) - encountered_files.add(path) - if re.search(r'\.[2-9]$', path): - # Make sure new files introduced in the second and later runs are accounted for - renamed_path = path[:-2] - if renamed_path not in encountered_files: - encountered_files.add(renamed_path) - self.clean_up.append((False, renamed_path)) - for path, _ in self.output_files: - # Create directories for expected output and mark them to be cleaned up at the end - # of the test case. - dir = os.path.dirname(path) - for d in self.add_dirs(dir): - self.clean_up.append((True, d)) - self.clean_up.append((False, path)) - - sys.path.insert(0, tmpdir_root) - - def teardown(self): - if hasattr(self, 'old_environ'): - os.environ = self.old_environ - super().teardown() - - -def split_test_cases(parent: 'DataSuiteCollector', suite: 'DataSuite', - file: str) -> Iterator[DjangoDataDrivenTestCase]: - """Iterate over raw test cases in file, at collection time, ignoring sub items. - - The collection phase is slow, so any heavy processing should be deferred to after - uninteresting tests are filtered (when using -k PATTERN switch). - """ - with open(file, encoding='utf-8') as f: - data = f.read() - cases = re.split(r'^\[case ([a-zA-Z_0-9]+)' - r'(-writescache)?' - r'(-only_when_cache|-only_when_nocache)?' - r'(-skip)?' - r'\][ \t]*$\n', data, - flags=re.DOTALL | re.MULTILINE) - line_no = cases[0].count('\n') + 1 - - for i in range(1, len(cases), 5): - name, writescache, only_when, skip, data = cases[i:i + 5] - yield DjangoDataDrivenTestCase(parent, suite, file, - name=add_test_name_suffix(name, suite.test_name_suffix), - writescache=bool(writescache), - only_when=only_when, - skip=bool(skip), - data=data, - line=line_no) - line_no += data.count('\n') + 1 - - -class DataSuiteCollector(pytest.Class): # type: ignore # inheriting from Any - def collect(self) -> Iterator[pytest.Item]: # type: ignore - """Called by pytest on each of the object returned from pytest_pycollect_makeitem""" - - # obj is the object for which pytest_pycollect_makeitem returned self. - suite = self.obj # type: DataSuite - for f in suite.files: - yield from split_test_cases(self, suite, os.path.join(suite.data_prefix, f)) - - -# This function name is special to pytest. See -# http://doc.pytest.org/en/latest/writing_plugins.html#collection-hooks -def pytest_pycollect_makeitem(collector: Any, name: str, - obj: object) -> 'Optional[Any]': - """Called by pytest on each object in modules configured in conftest.py files. - - collector is pytest.Collector, returns Optional[pytest.Class] - """ - if isinstance(obj, type): - # Only classes derived from DataSuite contain test cases, not the DataSuite class itself - if issubclass(obj, DataSuite) and obj is not DataSuite: - # Non-None result means this obj is a test case. - # The collect method of the returned DataSuiteCollector instance will be called later, - # with self.obj being obj. - return DataSuiteCollector(name, parent=collector) - return None diff --git a/test/helpers.py b/test/helpers.py deleted file mode 100644 index b41a99a..0000000 --- a/test/helpers.py +++ /dev/null @@ -1,237 +0,0 @@ -import inspect -import os -import re -from typing import List, Callable, Optional, Tuple - -import pytest # type: ignore # no pytest in typeshed - -skip = pytest.mark.skip - -# AssertStringArraysEqual displays special line alignment helper messages if -# the first different line has at least this many characters, -MIN_LINE_LENGTH_FOR_ALIGNMENT = 5 - - -class TypecheckAssertionError(AssertionError): - def __init__(self, error_message: str, lineno: int): - self.error_message = error_message - self.lineno = lineno - - def first_line(self): - return self.__class__.__name__ + '(message="Invalid output")' - - def __str__(self): - return self.error_message - - -def _clean_up(a: List[str]) -> List[str]: - """Remove common directory prefix from all strings in a. - - This uses a naive string replace; it seems to work well enough. Also - remove trailing carriage returns. - """ - res = [] - for s in a: - prefix = os.sep - ss = s - for p in prefix, prefix.replace(os.sep, '/'): - if p != '/' and p != '//' and p != '\\' and p != '\\\\': - ss = ss.replace(p, '') - # Ignore spaces at end of line. - ss = re.sub(' +$', '', ss) - res.append(re.sub('\\r$', '', ss)) - return res - - -def _num_skipped_prefix_lines(a1: List[str], a2: List[str]) -> int: - num_eq = 0 - while num_eq < min(len(a1), len(a2)) and a1[num_eq] == a2[num_eq]: - num_eq += 1 - return max(0, num_eq - 4) - - -def _num_skipped_suffix_lines(a1: List[str], a2: List[str]) -> int: - num_eq = 0 - while (num_eq < min(len(a1), len(a2)) - and a1[-num_eq - 1] == a2[-num_eq - 1]): - num_eq += 1 - return max(0, num_eq - 4) - - -def _add_aligned_message(s1: str, s2: str, error_message: str) -> str: - """Align s1 and s2 so that the their first difference is highlighted. - - For example, if s1 is 'foobar' and s2 is 'fobar', display the - following lines: - - E: foobar - A: fobar - ^ - - If s1 and s2 are long, only display a fragment of the strings around the - first difference. If s1 is very short, do nothing. - """ - - # Seeing what went wrong is trivial even without alignment if the expected - # string is very short. In this case do nothing to simplify output. - if len(s1) < 4: - return error_message - - maxw = 72 # Maximum number of characters shown - - error_message += 'Alignment of first line difference:\n' - # sys.stderr.write('Alignment of first line difference:\n') - - trunc = False - while s1[:30] == s2[:30]: - s1 = s1[10:] - s2 = s2[10:] - trunc = True - - if trunc: - s1 = '...' + s1 - s2 = '...' + s2 - - max_len = max(len(s1), len(s2)) - extra = '' - if max_len > maxw: - extra = '...' - - # Write a chunk of both lines, aligned. - error_message += ' E: {}{}\n'.format(s1[:maxw], extra) - # sys.stderr.write(' E: {}{}\n'.format(s1[:maxw], extra)) - error_message += ' A: {}{}\n'.format(s2[:maxw], extra) - # sys.stderr.write(' A: {}{}\n'.format(s2[:maxw], extra)) - # Write an indicator character under the different columns. - error_message += ' ' - # sys.stderr.write(' ') - for j in range(min(maxw, max(len(s1), len(s2)))): - if s1[j:j + 1] != s2[j:j + 1]: - error_message += '^' - # sys.stderr.write('^') # Difference - break - else: - error_message += ' ' - # sys.stderr.write(' ') # Equal - error_message += '\n' - return error_message - # sys.stderr.write('\n') - - -def assert_string_arrays_equal(expected: List[str], actual: List[str]) -> None: - """Assert that two string arrays are equal. - - Display any differences in a human-readable form. - """ - - actual = _clean_up(actual) - error_message = '' - - if set(actual) != set(expected): - num_skip_start = _num_skipped_prefix_lines(expected, actual) - num_skip_end = _num_skipped_suffix_lines(expected, actual) - - error_message += 'Expected:\n' - - # If omit some lines at the beginning, indicate it by displaying a line - # with '...'. - if num_skip_start > 0: - error_message += ' ...\n' - - # Keep track of the first different line. - first_diff = -1 - - # Display only this many first characters of identical lines. - width = 75 - - for i in range(num_skip_start, len(expected) - num_skip_end): - if i >= len(actual) or expected[i] != actual[i]: - if first_diff < 0: - first_diff = i - error_message += ' {:<45} (diff)'.format(expected[i]) - else: - e = expected[i] - error_message += ' ' + e[:width] - if len(e) > width: - error_message += '...' - error_message += '\n' - if num_skip_end > 0: - error_message += ' ...\n' - - error_message += 'Actual:\n' - - if num_skip_start > 0: - error_message += ' ...\n' - - for j in range(num_skip_start, len(actual) - num_skip_end): - if j >= len(expected) or expected[j] != actual[j]: - error_message += ' {:<45} (diff)'.format(actual[j]) - else: - a = actual[j] - error_message += ' ' + a[:width] - if len(a) > width: - error_message += '...' - error_message += '\n' - if actual == []: - error_message += ' (empty)\n' - if num_skip_end > 0: - error_message += ' ...\n' - - error_message += '\n' - - if 0 <= first_diff < len(actual) and ( - len(expected[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT - or len(actual[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT): - # Display message that helps visualize the differences between two - # long lines. - error_message = _add_aligned_message(expected[first_diff], actual[first_diff], - error_message) - - first_failure = expected[first_diff] - if first_failure: - lineno = int(first_failure.split(' ')[0].strip(':').split(':')[1]) - raise TypecheckAssertionError(error_message=f'Invalid output: \n{error_message}', - lineno=lineno) - - -def build_output_line(fname: str, lnum: int, severity: str, message: str, col=None) -> str: - if col is None: - return f'{fname}:{lnum + 1}: {severity}: {message}' - else: - return f'{fname}:{lnum + 1}:{col}: {severity}: {message}' - - -def expand_errors(input_lines: List[str], fname: str) -> List[str]: - """Transform comments such as '# E: message' or - '# E:3: message' in input. - - The result is lines like 'fnam:line: error: message'. - """ - output_lines = [] - for lnum, line in enumerate(input_lines): - # The first in the split things isn't a comment - for possible_err_comment in line.split(' # ')[1:]: - m = re.search( - r'^([ENW]):((?P\d+):)? (?P.*)$', - possible_err_comment.strip()) - if m: - if m.group(1) == 'E': - severity = 'error' - elif m.group(1) == 'N': - severity = 'note' - elif m.group(1) == 'W': - severity = 'warning' - col = m.group('col') - output_lines.append(build_output_line(fname, lnum, severity, - message=m.group("message"), - col=col)) - return output_lines - - -def get_func_first_lnum(attr: Callable[..., None]) -> Optional[Tuple[int, List[str]]]: - lines, _ = inspect.getsourcelines(attr) - for lnum, line in enumerate(lines): - no_space_line = line.strip() - if f'def {attr.__name__}' in no_space_line: - return lnum, lines[lnum + 1:] - raise ValueError(f'No line "def {attr.__name__}" found') diff --git a/test/pytest_plugin.py b/test/pytest_plugin.py deleted file mode 100644 index 8c39572..0000000 --- a/test/pytest_plugin.py +++ /dev/null @@ -1,303 +0,0 @@ -import dataclasses -import inspect -import os -import sys -import tempfile -import textwrap -from contextlib import contextmanager -from pathlib import Path -from typing import Iterator, Any, Optional, cast, List, Type, Callable, Dict - -import pytest -from _pytest._code.code import ReprFileLocation, ReprEntry, ExceptionInfo -from decorator import decorate -from mypy import api as mypy_api - -from test import vistir -from test.helpers import assert_string_arrays_equal, TypecheckAssertionError, expand_errors, get_func_first_lnum - - -def reveal_type(obj: Any) -> None: - # noop method, just to get rid of "method is not resolved" errors - pass - - -def output(output_lines: str): - def decor(func: Callable[..., None]): - func.out = output_lines.strip() - - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - return decorate(func, wrapper) - - return decor - - -def get_class_that_defined_method(meth) -> Type['MypyTypecheckTestCase']: - if inspect.ismethod(meth): - for cls in inspect.getmro(meth.__self__.__class__): - if cls.__dict__.get(meth.__name__) is meth: - return cls - meth = meth.__func__ # fallback to __qualname__ parsing - if inspect.isfunction(meth): - cls = getattr(inspect.getmodule(meth), - meth.__qualname__.split('.', 1)[0].rsplit('.', 1)[0]) - if issubclass(cls, MypyTypecheckTestCase): - return cls - return getattr(meth, '__objclass__', None) # handle special descriptor objects - - -def file(filename: str, make_parent_packages=False): - def decor(func: Callable[..., None]): - func.filename = filename - func.make_parent_packages = make_parent_packages - return func - - return decor - - -def env(**environ): - def decor(func: Callable[..., None]): - func.env = environ - return func - - return decor - - -@dataclasses.dataclass -class CreateFile: - sources: str - make_parent_packages: bool = False - - -class MypyTypecheckMeta(type): - def __new__(mcs, name, bases, attrs): - cls = super().__new__(mcs, name, bases, attrs) - cls.files: Dict[str, CreateFile] = {} - - for name, attr in attrs.items(): - if inspect.isfunction(attr): - filename = getattr(attr, 'filename', None) - if not filename: - continue - make_parent_packages = getattr(attr, 'make_parent_packages', False) - sources = textwrap.dedent(''.join(get_func_first_lnum(attr)[1])) - if sources.strip() == 'pass': - sources = '' - cls.files[filename] = CreateFile(sources, make_parent_packages) - - return cls - - -class MypyTypecheckTestCase(metaclass=MypyTypecheckMeta): - files = None - - def ini_file(self) -> str: - return """ -[mypy] - """ - - def _get_ini_file_contents(self) -> Optional[str]: - raw_ini_file = self.ini_file() - if not raw_ini_file: - return raw_ini_file - return raw_ini_file.strip() + '\n' - - -class TraceLastReprEntry(ReprEntry): - def toterminal(self, tw): - self.reprfileloc.toterminal(tw) - for line in self.lines: - red = line.startswith("E ") - tw.line(line, bold=True, red=red) - return - - -def fname_to_module(fpath: Path, root_path: Path) -> Optional[str]: - try: - relpath = fpath.relative_to(root_path).with_suffix('') - return str(relpath).replace(os.sep, '.') - except ValueError: - return None - - -class MypyTypecheckItem(pytest.Item): - root_directory = '/tmp' - - def __init__(self, - name: str, - parent: 'MypyTestsCollector', - klass: Type[MypyTypecheckTestCase], - source_code: str, - first_lineno: int, - ini_file_contents: Optional[str] = None, - expected_output_lines: Optional[List[str]] = None, - files: Optional[Dict[str, CreateFile]] = None, - custom_environment: Optional[Dict[str, Any]] = None): - super().__init__(name=name, parent=parent) - self.klass = klass - self.source_code = source_code - self.first_lineno = first_lineno - self.ini_file_contents = ini_file_contents - self.expected_output_lines = expected_output_lines - self.files = files - self.custom_environment = custom_environment - - @contextmanager - def temp_directory(self) -> Path: - with tempfile.TemporaryDirectory(prefix='mypy-pytest-', - dir=self.root_directory) as tmpdir_name: - yield Path(self.root_directory) / tmpdir_name - - def runtest(self): - with self.temp_directory() as tmpdir_path: - if not self.source_code: - return - - if self.ini_file_contents: - mypy_ini_fpath = tmpdir_path / 'mypy.ini' - mypy_ini_fpath.write_text(self.ini_file_contents) - - test_specific_modules = [] - for fname, create_file in self.files.items(): - fpath = tmpdir_path / fname - if create_file.make_parent_packages: - fpath.parent.mkdir(parents=True, exist_ok=True) - for parent in fpath.parents: - try: - parent.relative_to(tmpdir_path) - if parent != tmpdir_path: - parent_init_file = parent / '__init__.py' - parent_init_file.write_text('') - test_specific_modules.append(fname_to_module(parent, - root_path=tmpdir_path)) - except ValueError: - break - - fpath.write_text(create_file.sources) - test_specific_modules.append(fname_to_module(fpath, - root_path=tmpdir_path)) - - with vistir.temp_environ(), vistir.temp_path(): - for key, val in (self.custom_environment or {}).items(): - os.environ[key] = val - sys.path.insert(0, str(tmpdir_path)) - - mypy_cmd_options = self.prepare_mypy_cmd_options(config_file_path=mypy_ini_fpath) - main_fpath = tmpdir_path / 'main.py' - main_fpath.write_text(self.source_code) - mypy_cmd_options.append(str(main_fpath)) - - stdout, stderr, returncode = mypy_api.run(mypy_cmd_options) - output_lines = [] - for line in (stdout + stderr).splitlines(): - if ':' not in line: - continue - out_fpath, res_line = line.split(':', 1) - line = os.path.relpath(out_fpath, start=tmpdir_path) + ':' + res_line - output_lines.append(line.strip().replace('.py', '')) - - for module in test_specific_modules: - parts = module.split('.') - for i in range(len(parts)): - parent_module = '.'.join(parts[:i + 1]) - if parent_module in sys.modules: - del sys.modules[parent_module] - - assert_string_arrays_equal(expected=self.expected_output_lines, - actual=output_lines) - - def prepare_mypy_cmd_options(self, config_file_path: Path) -> List[str]: - mypy_cmd_options = [ - '--raise-exceptions', - '--no-silence-site-packages' - ] - python_version = '.'.join([str(part) for part in sys.version_info[:2]]) - mypy_cmd_options.append(f'--python-version={python_version}') - if self.ini_file_contents: - mypy_cmd_options.append(f'--config-file={config_file_path}') - return mypy_cmd_options - - def repr_failure(self, excinfo: ExceptionInfo) -> str: - if excinfo.errisinstance(SystemExit): - # We assume that before doing exit() (which raises SystemExit) we've printed - # enough context about what happened so that a stack trace is not useful. - # In particular, uncaught exceptions during semantic analysis or type checking - # call exit() and they already print out a stack trace. - return excinfo.exconly(tryshort=True) - elif excinfo.errisinstance(TypecheckAssertionError): - # with traceback removed - exception_repr = excinfo.getrepr(style='short') - exception_repr.reprcrash.message = '' - repr_file_location = ReprFileLocation(path=inspect.getfile(self.klass), - lineno=self.first_lineno + excinfo.value.lineno, - message='') - repr_tb_entry = TraceLastReprEntry(filelocrepr=repr_file_location, - lines=exception_repr.reprtraceback.reprentries[-1].lines[1:], - style='short', - reprlocals=None, - reprfuncargs=None) - exception_repr.reprtraceback.reprentries = [repr_tb_entry] - return exception_repr - else: - return super().repr_failure(excinfo, style='native') - - def reportinfo(self): - return self.fspath, None, get_class_qualname(self.klass) + '::' + self.name - - -def get_class_qualname(klass: type) -> str: - return klass.__module__ + '.' + klass.__name__ - - -def extract_test_output(attr: Callable[..., None]) -> List[str]: - out_data: str = getattr(attr, 'out', None) - out_lines = [] - if out_data: - for line in out_data.strip().split('\n'): - if line: - line = line.strip() - out_lines.append(line) - return out_lines - - -class MypyTestsCollector(pytest.Class): - def get_ini_file_contents(self, contents: str) -> str: - return contents.strip() + '\n' - - def collect(self) -> Iterator[pytest.Item]: - current_testcase = cast(MypyTypecheckTestCase, self.obj()) - ini_file_contents = self.get_ini_file_contents(current_testcase.ini_file()) - for attr_name in dir(current_testcase): - if attr_name.startswith('test_'): - attr = getattr(self.obj, attr_name) - if inspect.isfunction(attr): - first_line_lnum, source_lines = get_func_first_lnum(attr) - func_first_line_in_file = inspect.getsourcelines(attr)[1] + first_line_lnum - - output_from_decorator = extract_test_output(attr) - output_from_comments = expand_errors(source_lines, 'main') - custom_env = getattr(attr, 'env', None) - main_source_code = textwrap.dedent(''.join(source_lines)) - yield MypyTypecheckItem(name=attr_name, - parent=self, - klass=current_testcase.__class__, - source_code=main_source_code, - first_lineno=func_first_line_in_file, - ini_file_contents=ini_file_contents, - expected_output_lines=output_from_comments - + output_from_decorator, - files=current_testcase.__class__.files, - custom_environment=custom_env) - - -def pytest_pycollect_makeitem(collector: Any, name: str, obj: Any) -> Optional[MypyTestsCollector]: - # Only classes derived from DataSuite contain test cases, not the DataSuite class itself - if (isinstance(obj, type) - and issubclass(obj, MypyTypecheckTestCase) - and obj is not MypyTypecheckTestCase): - # Non-None result means this obj is a test case. - # The collect method of the returned DataSuiteCollector instance will be called later, - # with self.obj being obj. - return MypyTestsCollector(name, parent=collector) diff --git a/test/test-data/postgres-fields.test b/test/test-data/postgres-fields.test deleted file mode 100644 index b2aeeee..0000000 --- a/test/test-data/postgres-fields.test +++ /dev/null @@ -1,21 +0,0 @@ -[CASE array_field_descriptor_access] -from django.db import models -from django.contrib.postgres.fields import ArrayField - -class User(models.Model): - array = ArrayField(base_field=models.Field()) - -user = User() -reveal_type(user.array) # E: Revealed type is 'builtins.list[Any]' - -[CASE array_field_base_field_parsed_into_generic_typevar] -from django.db import models -from django.contrib.postgres.fields import ArrayField - -class User(models.Model): - members = ArrayField(base_field=models.IntegerField()) - members_as_text = ArrayField(base_field=models.CharField(max_length=255)) - -user = User() -reveal_type(user.members) # E: Revealed type is 'builtins.list[builtins.int*]' -reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list[builtins.str*]' diff --git a/test/testdjango.py b/test/testdjango.py deleted file mode 100644 index 32e4419..0000000 --- a/test/testdjango.py +++ /dev/null @@ -1,60 +0,0 @@ -import os -import sys -from pathlib import Path - -from mypy import api -from mypy.test.config import test_temp_dir -from mypy.test.data import DataSuite, DataDrivenTestCase -from mypy.test.helpers import assert_string_arrays_equal - -ROOT_DIR = Path(__file__).parent.parent -TEST_DATA_DIR = ROOT_DIR / 'test' / 'test-data' -MYPY_INI_PATH = ROOT_DIR / 'test' / 'plugins.ini' - - -class DjangoTestSuite(DataSuite): - files = [ - # 'check-objects-queryset.test', - # 'check-model-fields.test', - # 'check-postgres-fields.test', - # 'check-model-relations.test', - # 'check-parse-settings.test', - # 'check-to-attr-as-string-one-to-one-field.test', - 'check-to-attr-as-string-foreign-key.test', - # 'check-foreign-key-as-string-creates-underscore-id-attr.test' - ] - data_prefix = str(TEST_DATA_DIR) - - def run_case(self, testcase: DataDrivenTestCase) -> None: - assert testcase.old_cwd is not None, "test was not properly set up" - - mypy_cmdline = [ - '--show-traceback', - '--no-silence-site-packages', - '--config-file={}'.format(MYPY_INI_PATH) - ] - mypy_cmdline.append('--python-version={}'.format('.'.join(map(str, - sys.version_info[:2])))) - - program_path = os.path.join(test_temp_dir, 'main.py') - mypy_cmdline.append(program_path) - - with open(program_path, 'w') as file: - for s in testcase.input: - file.write('{}\n'.format(s)) - - output = [] - # Type check the program. - out, err, returncode = api.run(mypy_cmdline) - # split lines, remove newlines, and remove directory of test case - for line in (out + err).splitlines(): - if line.startswith(test_temp_dir + os.sep): - output.append(line[len(test_temp_dir + os.sep):].rstrip("\r\n").replace('.py', '')) - else: - output.append(line.rstrip("\r\n")) - # Remove temp file. - os.remove(program_path) - - assert_string_arrays_equal(testcase.output, output, - 'Invalid output ({}, line {})'.format( - testcase.file, testcase.line)) diff --git a/test/vistir.py b/test/vistir.py deleted file mode 100644 index 4023716..0000000 --- a/test/vistir.py +++ /dev/null @@ -1,43 +0,0 @@ -# Borrowed from Pew. -# See https://github.com/berdario/pew/blob/master/pew/_utils.py#L82 -import os -import sys -from pathlib import Path - -from decorator import contextmanager - - -@contextmanager -def temp_environ(): - """Allow the ability to set os.environ temporarily""" - environ = dict(os.environ) - try: - yield - finally: - os.environ.clear() - os.environ.update(environ) - - -@contextmanager -def temp_path(): - """A context manager which allows the ability to set sys.path temporarily""" - path = [p for p in sys.path] - try: - yield - finally: - sys.path = [p for p in path] - - -@contextmanager -def cd(path): - """Context manager to temporarily change working directories""" - if not path: - return - prev_cwd = Path.cwd().as_posix() - if isinstance(path, Path): - path = path.as_posix() - os.chdir(str(path)) - try: - yield - finally: - os.chdir(prev_cwd)