mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-07 04:34:29 +08:00
cleanups, fix settings
This commit is contained in:
@@ -16,7 +16,7 @@ class ArrayField(CheckFieldDefaultMixin, Field, Generic[_T]):
|
||||
from_db_value: Any = ...
|
||||
|
||||
def __init__(
|
||||
self, base_field: Field, size: None = ..., **kwargs: Any
|
||||
self, base_field: _T, size: None = ..., **kwargs: Any
|
||||
) -> None: ...
|
||||
@property
|
||||
def model(self): ...
|
||||
|
||||
1
external/mypy
vendored
1
external/mypy
vendored
Submodule external/mypy deleted from b790539825
@@ -1,9 +1,9 @@
|
||||
import typing
|
||||
from typing import Dict, Optional, NamedTuple
|
||||
|
||||
from mypy.nodes import SymbolTableNode, Var, Expression, StrExpr, MypyFile, TypeInfo
|
||||
from mypy.nodes import Expression, StrExpr, MypyFile, TypeInfo
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Type, Instance, UnionType, NoneTyp
|
||||
from mypy.types import Type, UnionType, NoneTyp
|
||||
|
||||
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
|
||||
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
|
||||
@@ -11,14 +11,6 @@ FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
|
||||
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
|
||||
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
|
||||
|
||||
|
||||
def create_new_symtable_node(name: str, kind: int, instance: Instance) -> SymbolTableNode:
|
||||
new_var = Var(name, instance)
|
||||
new_var.info = instance.type
|
||||
return SymbolTableNode(kind, new_var,
|
||||
plugin_generated=True)
|
||||
|
||||
|
||||
Argument = NamedTuple('Argument', fields=[
|
||||
('arg', Expression),
|
||||
('arg_type', Type)
|
||||
|
||||
@@ -1,81 +1,24 @@
|
||||
import os
|
||||
from typing import Callable, Optional, cast
|
||||
from typing import Callable, Optional
|
||||
|
||||
from mypy.nodes import AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr
|
||||
from mypy.options import Options
|
||||
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
|
||||
from mypy.semanal import SemanticAnalyzerPass2
|
||||
from mypy.types import Type, Instance
|
||||
from mypy.plugin import Plugin, FunctionContext, ClassDefContext, AnalyzeTypeContext
|
||||
from mypy.types import Type
|
||||
|
||||
from mypy_django_plugin import helpers, monkeypatch
|
||||
from mypy_django_plugin.plugins.meta_inner_class import inject_any_as_base_for_nested_class_meta
|
||||
from mypy_django_plugin.plugins.objects_queryset import set_objects_queryset_to_model_class
|
||||
from mypy_django_plugin.plugins.fields import determine_type_of_array_field, \
|
||||
add_int_id_attribute_if_primary_key_true_is_not_present
|
||||
from mypy_django_plugin.plugins.related_fields import set_fieldname_attrs_for_related_fields, add_new_var_node_to_class, \
|
||||
extract_to_parameter_as_get_ret_type
|
||||
from mypy_django_plugin.plugins.setup_settings import DjangoConfSettingsInitializerHook
|
||||
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
|
||||
from mypy_django_plugin.plugins.models import process_model_class
|
||||
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field
|
||||
from mypy_django_plugin.plugins.settings import DjangoConfSettingsInitializerHook
|
||||
|
||||
|
||||
base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
|
||||
|
||||
|
||||
def add_related_managers_from_referred_foreign_keys_to_model(ctx: ClassDefContext) -> None:
|
||||
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||
for stmt in ctx.cls.defs.body:
|
||||
if not isinstance(stmt, AssignmentStmt):
|
||||
continue
|
||||
if len(stmt.lvalues) > 1:
|
||||
# not supported yet
|
||||
continue
|
||||
rvalue = stmt.rvalue
|
||||
if not isinstance(rvalue, CallExpr):
|
||||
continue
|
||||
if (not isinstance(rvalue.callee, MemberExpr)
|
||||
or not rvalue.callee.fullname in {helpers.FOREIGN_KEY_FULLNAME,
|
||||
helpers.ONETOONE_FIELD_FULLNAME}):
|
||||
continue
|
||||
if 'related_name' not in rvalue.arg_names:
|
||||
# positional related_name is not supported yet
|
||||
continue
|
||||
related_name = rvalue.args[rvalue.arg_names.index('related_name')].value
|
||||
|
||||
if 'to' in rvalue.arg_names:
|
||||
expr = rvalue.args[rvalue.arg_names.index('to')]
|
||||
else:
|
||||
# first positional argument
|
||||
expr = rvalue.args[0]
|
||||
|
||||
if isinstance(expr, StrExpr):
|
||||
model_typeinfo = helpers.get_model_type_from_string(expr,
|
||||
all_modules=api.modules)
|
||||
if model_typeinfo is None:
|
||||
continue
|
||||
elif isinstance(expr, NameExpr):
|
||||
model_typeinfo = expr.node
|
||||
else:
|
||||
continue
|
||||
|
||||
if rvalue.callee.fullname == helpers.FOREIGN_KEY_FULLNAME:
|
||||
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME,
|
||||
args=[Instance(ctx.cls.info, [])])
|
||||
else:
|
||||
typ = Instance(ctx.cls.info, [])
|
||||
|
||||
if typ is None:
|
||||
continue
|
||||
add_new_var_node_to_class(model_typeinfo, related_name, typ)
|
||||
|
||||
|
||||
class TransformModelClassHook(object):
|
||||
def __call__(self, ctx: ClassDefContext) -> None:
|
||||
base_model_classes.add(ctx.cls.fullname)
|
||||
|
||||
set_fieldname_attrs_for_related_fields(ctx)
|
||||
set_objects_queryset_to_model_class(ctx)
|
||||
inject_any_as_base_for_nested_class_meta(ctx)
|
||||
add_related_managers_from_referred_foreign_keys_to_model(ctx)
|
||||
add_int_id_attribute_if_primary_key_true_is_not_present(ctx)
|
||||
process_model_class(ctx)
|
||||
|
||||
|
||||
class DjangoPlugin(Plugin):
|
||||
@@ -89,7 +32,6 @@ class DjangoPlugin(Plugin):
|
||||
if self.django_settings:
|
||||
monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(self.django_settings)
|
||||
monkeypatch.inject_dependencies(self.django_settings)
|
||||
# monkeypatch.process_settings_before_dependants(self.django_settings)
|
||||
else:
|
||||
monkeypatch.restore_original_load_graph()
|
||||
monkeypatch.restore_original_dependencies_handling()
|
||||
@@ -98,7 +40,7 @@ class DjangoPlugin(Plugin):
|
||||
) -> Optional[Callable[[FunctionContext], Type]]:
|
||||
if fullname in {helpers.FOREIGN_KEY_FULLNAME,
|
||||
helpers.ONETOONE_FIELD_FULLNAME}:
|
||||
return extract_to_parameter_as_get_ret_type
|
||||
return extract_to_parameter_as_get_ret_type_for_related_field
|
||||
|
||||
# if fullname == helpers.ONETOONE_FIELD_FULLNAME:
|
||||
# return OneToOneFieldHook(settings=self.django_settings)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Optional, List, Sequence, NamedTuple, Tuple
|
||||
|
||||
from mypy import checkexpr
|
||||
from mypy.argmap import map_actuals_to_formals
|
||||
from mypy.checkexpr import map_actuals_to_formals
|
||||
from mypy.checkmember import analyze_member_access
|
||||
from mypy.expandtype import freshen_function_type_vars
|
||||
from mypy.messages import MessageBuilder
|
||||
@@ -68,7 +68,6 @@ class PatchedExpressionChecker(checkexpr.ExpressionChecker):
|
||||
on which the method is being called
|
||||
"""
|
||||
arg_messages = arg_messages or self.msg
|
||||
|
||||
if isinstance(callee, CallableType):
|
||||
if callable_name is None and callee.name:
|
||||
callable_name = callee.name
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
from typing import Iterator, List, cast
|
||||
|
||||
from mypy.nodes import ClassDef, AssignmentStmt, CallExpr
|
||||
from mypy.plugin import FunctionContext, ClassDefContext
|
||||
from mypy.semanal import SemanticAnalyzerPass2
|
||||
from mypy.types import Type, Instance
|
||||
|
||||
from mypy_django_plugin.plugins.related_fields import add_new_var_node_to_class
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Type
|
||||
|
||||
|
||||
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
|
||||
@@ -15,27 +9,3 @@ def determine_type_of_array_field(ctx: FunctionContext) -> Type:
|
||||
base_field_arg_type = ctx.arg_types[ctx.arg_names.index('base_field')][0]
|
||||
return ctx.api.named_generic_type(ctx.context.callee.fullname,
|
||||
args=[base_field_arg_type.type.names['__get__'].type.ret_type])
|
||||
|
||||
|
||||
def get_assignments(klass: ClassDef) -> List[AssignmentStmt]:
|
||||
stmts = []
|
||||
for stmt in klass.defs.body:
|
||||
if not isinstance(stmt, AssignmentStmt):
|
||||
continue
|
||||
if len(stmt.lvalues) > 1:
|
||||
# not supported yet
|
||||
continue
|
||||
stmts.append(stmt)
|
||||
return stmts
|
||||
|
||||
|
||||
def add_int_id_attribute_if_primary_key_true_is_not_present(ctx: ClassDefContext) -> None:
|
||||
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||
for stmt in get_assignments(ctx.cls):
|
||||
if (isinstance(stmt.rvalue, CallExpr)
|
||||
and 'primary_key' in stmt.rvalue.arg_names
|
||||
and api.parse_bool(stmt.rvalue.args[stmt.rvalue.arg_names.index('primary_key')])):
|
||||
break
|
||||
else:
|
||||
add_new_var_node_to_class(ctx.cls.info, 'id', api.builtin_type('builtins.int'))
|
||||
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.plugin import ClassDefContext
|
||||
|
||||
|
||||
def inject_any_as_base_for_nested_class_meta(ctx: ClassDefContext) -> None:
|
||||
if 'Meta' not in ctx.cls.info.names:
|
||||
return None
|
||||
sym = ctx.cls.info.names['Meta']
|
||||
if not isinstance(sym.node, TypeInfo):
|
||||
return None
|
||||
|
||||
sym.node.fallback_to_any = True
|
||||
180
mypy_django_plugin/plugins/models.py
Normal file
180
mypy_django_plugin/plugins/models.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from typing import cast, Iterator, Tuple, Optional
|
||||
|
||||
from mypy.nodes import ClassDef, AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr, MDEF, TypeInfo, Var, SymbolTableNode, \
|
||||
Lvalue, Expression, Statement
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.semanal import SemanticAnalyzerPass2
|
||||
from mypy.types import Instance
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def add_new_var_node_to_class(class_type: TypeInfo, name: str, typ: Instance) -> None:
|
||||
var = Var(name=name, type=typ)
|
||||
var.info = typ.type
|
||||
var._fullname = class_type.fullname() + '.' + name
|
||||
var.is_inferred = True
|
||||
var.is_initialized_in_class = True
|
||||
class_type.names[name] = SymbolTableNode(MDEF, var)
|
||||
|
||||
|
||||
def iter_over_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, Expression]]:
|
||||
for stmt in klass.defs.body:
|
||||
if not isinstance(stmt, AssignmentStmt):
|
||||
continue
|
||||
if len(stmt.lvalues) > 1:
|
||||
# not supported yet
|
||||
continue
|
||||
yield stmt.lvalues[0], stmt.rvalue
|
||||
|
||||
|
||||
def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]:
|
||||
for lvalue, rvalue in iter_over_assignments(klass):
|
||||
if not isinstance(rvalue, CallExpr):
|
||||
continue
|
||||
yield lvalue, rvalue
|
||||
|
||||
|
||||
def iter_over_one_to_n_related_fields(klass: ClassDef, api: SemanticAnalyzerPass2) -> Iterator[Tuple[NameExpr, CallExpr]]:
|
||||
for lvalue, rvalue in iter_call_assignments(klass):
|
||||
if (isinstance(lvalue, NameExpr)
|
||||
and isinstance(rvalue.callee, MemberExpr)):
|
||||
if rvalue.callee.fullname in {helpers.FOREIGN_KEY_FULLNAME,
|
||||
helpers.ONETOONE_FIELD_FULLNAME}:
|
||||
yield lvalue, rvalue
|
||||
|
||||
|
||||
def get_nested_meta_class(model_type: TypeInfo) -> Optional[TypeInfo]:
|
||||
metaclass_sym = model_type.names.get('Meta')
|
||||
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
|
||||
return metaclass_sym.node
|
||||
return None
|
||||
|
||||
|
||||
def is_abstract_model(ctx: ClassDefContext) -> bool:
|
||||
meta_node = get_nested_meta_class(ctx.cls.info)
|
||||
if meta_node is None:
|
||||
return False
|
||||
|
||||
for lvalue, rvalue in iter_over_assignments(meta_node.defn):
|
||||
if isinstance(lvalue, NameExpr) and lvalue.name == 'abstract':
|
||||
is_abstract = ctx.api.parse_bool(rvalue)
|
||||
if is_abstract:
|
||||
# abstract model do not need 'objects' queryset
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
|
||||
api = ctx.api
|
||||
for lvalue, rvalue in iter_over_one_to_n_related_fields(ctx.cls, api):
|
||||
property_name = lvalue.name + '_id'
|
||||
add_new_var_node_to_class(ctx.cls.info, property_name,
|
||||
typ=api.named_type('__builtins__.int'))
|
||||
|
||||
|
||||
def add_int_id_attribute_if_primary_key_true_is_not_present(ctx: ClassDefContext) -> None:
|
||||
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||
if is_abstract_model(ctx):
|
||||
return None
|
||||
|
||||
for _, rvalue in iter_call_assignments(ctx.cls):
|
||||
if ('primary_key' in rvalue.arg_names and
|
||||
api.parse_bool(rvalue.args[rvalue.arg_names.index('primary_key')])):
|
||||
break
|
||||
else:
|
||||
add_new_var_node_to_class(ctx.cls.info, 'id', api.builtin_type('builtins.int'))
|
||||
|
||||
|
||||
def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None:
|
||||
# search over mro
|
||||
objects_sym = ctx.cls.info.get('objects')
|
||||
if objects_sym is not None:
|
||||
return None
|
||||
|
||||
# only direct Meta class
|
||||
if is_abstract_model(ctx):
|
||||
# abstract model do not need 'objects' queryset
|
||||
return None
|
||||
|
||||
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME,
|
||||
args=[Instance(ctx.cls.info, [])])
|
||||
if not typ:
|
||||
return None
|
||||
add_new_var_node_to_class(ctx.cls.info, 'objects', typ=typ)
|
||||
|
||||
|
||||
def inject_any_as_base_for_nested_class_meta(ctx: ClassDefContext) -> None:
|
||||
meta_node = get_nested_meta_class(ctx.cls.info)
|
||||
if meta_node is None:
|
||||
return None
|
||||
meta_node.fallback_to_any = True
|
||||
|
||||
|
||||
def is_model_defn(defn: Statement, api: SemanticAnalyzerPass2) -> bool:
|
||||
if not isinstance(defn, ClassDef):
|
||||
return False
|
||||
|
||||
for base_type_expr in defn.base_type_exprs:
|
||||
# api.accept(base_type_expr)
|
||||
fullname = getattr(base_type_expr, 'fullname', None)
|
||||
if fullname == helpers.MODEL_CLASS_FULLNAME:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def iter_over_models(ctx: ClassDefContext) -> Iterator[ClassDef]:
|
||||
for module_name, module_file in ctx.api.modules.items():
|
||||
for defn in module_file.defs:
|
||||
if is_model_defn(defn, api=cast(SemanticAnalyzerPass2, ctx.api)):
|
||||
yield defn
|
||||
|
||||
|
||||
def extract_to_value_or_none(field_expr: CallExpr, ctx: ClassDefContext) -> Optional[TypeInfo]:
|
||||
if 'to' in field_expr.arg_names:
|
||||
ref_expr = field_expr.args[field_expr.arg_names.index('to')]
|
||||
else:
|
||||
# first positional argument
|
||||
ref_expr = field_expr.args[0]
|
||||
|
||||
if isinstance(ref_expr, StrExpr):
|
||||
model_typeinfo = helpers.get_model_type_from_string(ref_expr,
|
||||
all_modules=ctx.api.modules)
|
||||
return model_typeinfo
|
||||
elif isinstance(ref_expr, NameExpr):
|
||||
return ref_expr.node
|
||||
|
||||
|
||||
def get_related_field_type(rvalue: CallExpr, api: SemanticAnalyzerPass2,
|
||||
related_model_typ: TypeInfo) -> Optional[Instance]:
|
||||
if rvalue.callee.fullname == helpers.FOREIGN_KEY_FULLNAME:
|
||||
return api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME,
|
||||
args=[Instance(related_model_typ, [])])
|
||||
else:
|
||||
return Instance(related_model_typ, [])
|
||||
|
||||
|
||||
def add_related_managers(ctx: ClassDefContext) -> None:
|
||||
for model_defn in iter_over_models(ctx):
|
||||
for _, rvalue in iter_over_one_to_n_related_fields(model_defn, ctx.api):
|
||||
if 'related_name' not in rvalue.arg_names:
|
||||
# positional related_name is not supported yet
|
||||
return
|
||||
related_name = rvalue.args[rvalue.arg_names.index('related_name')].value
|
||||
ref_to_typ = extract_to_value_or_none(rvalue, ctx)
|
||||
if ref_to_typ is not None:
|
||||
if ref_to_typ.fullname() == ctx.cls.info.fullname():
|
||||
typ = get_related_field_type(rvalue, ctx.api,
|
||||
related_model_typ=model_defn.info)
|
||||
if typ is None:
|
||||
return
|
||||
add_new_var_node_to_class(ctx.cls.info, related_name, typ)
|
||||
|
||||
|
||||
def process_model_class(ctx: ClassDefContext) -> None:
|
||||
# add_related_managers(ctx)
|
||||
inject_any_as_base_for_nested_class_meta(ctx)
|
||||
set_fieldname_attrs_for_related_fields(ctx)
|
||||
add_int_id_attribute_if_primary_key_true_is_not_present(ctx)
|
||||
set_objects_queryset_to_model_class(ctx)
|
||||
@@ -1,36 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from mypy.nodes import MDEF, AssignmentStmt
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.semanal import SemanticAnalyzerPass2
|
||||
from mypy.types import Instance
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None:
|
||||
# search over mro
|
||||
objects_sym = ctx.cls.info.get('objects')
|
||||
if objects_sym is not None:
|
||||
return None
|
||||
|
||||
# only direct Meta class
|
||||
metaclass_sym = ctx.cls.info.names.get('Meta')
|
||||
# skip if abstract
|
||||
if metaclass_sym is not None:
|
||||
for stmt in metaclass_sym.node.defn.defs.body:
|
||||
if (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1
|
||||
and stmt.lvalues[0].name == 'abstract'):
|
||||
is_abstract = ctx.api.parse_bool(stmt.rvalue)
|
||||
if is_abstract:
|
||||
return None
|
||||
|
||||
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME,
|
||||
args=[Instance(ctx.cls.info, [])])
|
||||
if not typ:
|
||||
return None
|
||||
|
||||
ctx.cls.info.names['objects'] = helpers.create_new_symtable_node('objects',
|
||||
kind=MDEF,
|
||||
instance=typ)
|
||||
@@ -1,18 +1,12 @@
|
||||
import typing
|
||||
from typing import Optional, cast
|
||||
|
||||
from django.conf import Settings
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import MDEF, AssignmentStmt, MypyFile, StrExpr, TypeInfo, NameExpr, Var, SymbolTableNode
|
||||
from mypy.plugin import FunctionContext, ClassDefContext
|
||||
from mypy.nodes import StrExpr
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
from mypy_django_plugin.helpers import get_models_file
|
||||
|
||||
|
||||
def extract_related_name_value(ctx: FunctionContext) -> str:
|
||||
return ctx.args[ctx.arg_names.index('related_name')][0].value
|
||||
|
||||
|
||||
def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]):
|
||||
@@ -55,44 +49,9 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
|
||||
return referred_to_type
|
||||
|
||||
|
||||
def add_new_var_node_to_class(class_type: TypeInfo, name: str, typ: Instance) -> None:
|
||||
var = Var(name=name, type=typ)
|
||||
var.info = typ.type
|
||||
var._fullname = class_type.fullname() + '.' + name
|
||||
var.is_inferred = True
|
||||
var.is_initialized_in_class = True
|
||||
class_type.names[name] = SymbolTableNode(MDEF, var)
|
||||
|
||||
|
||||
def extract_to_parameter_as_get_ret_type(ctx: FunctionContext) -> Type:
|
||||
def extract_to_parameter_as_get_ret_type_for_related_field(ctx: FunctionContext) -> Type:
|
||||
referred_to_type = get_valid_to_value_or_none(ctx)
|
||||
if referred_to_type is None:
|
||||
# couldn't extract to= value
|
||||
return fill_typevars_with_any(ctx.default_return_type)
|
||||
return reparametrize_with(ctx.default_return_type, [referred_to_type])
|
||||
|
||||
|
||||
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
|
||||
api = ctx.api
|
||||
for stmt in ctx.cls.defs.body:
|
||||
if not isinstance(stmt, AssignmentStmt):
|
||||
continue
|
||||
if not hasattr(stmt.rvalue, 'callee'):
|
||||
continue
|
||||
if len(stmt.lvalues) > 1:
|
||||
# multiple lvalues not supported for now
|
||||
continue
|
||||
|
||||
expr = stmt.lvalues[0]
|
||||
if not isinstance(expr, NameExpr):
|
||||
continue
|
||||
name = expr.name
|
||||
|
||||
rvalue_callee = stmt.rvalue.callee
|
||||
if rvalue_callee.fullname in {helpers.FOREIGN_KEY_FULLNAME,
|
||||
helpers.ONETOONE_FIELD_FULLNAME}:
|
||||
name += '_id'
|
||||
new_node = helpers.create_new_symtable_node(name,
|
||||
kind=MDEF,
|
||||
instance=api.named_type('__builtins__.int'))
|
||||
ctx.cls.info.names[name] = new_node
|
||||
|
||||
61
mypy_django_plugin/plugins/settings.py
Normal file
61
mypy_django_plugin/plugins/settings.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import cast, List
|
||||
|
||||
from mypy.nodes import Var, Context, SymbolNode, SymbolTableNode
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.semanal import SemanticAnalyzerPass2
|
||||
from mypy.types import Instance, UnionType, NoneTyp, Type
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def get_error_context(node: SymbolNode) -> Context:
|
||||
context = Context()
|
||||
context.set_line(node)
|
||||
return context
|
||||
|
||||
|
||||
def filter_out_nones(typ: UnionType) -> List[Type]:
|
||||
return [item for item in typ.items if not isinstance(item, NoneTyp)]
|
||||
|
||||
|
||||
def copy_sym_of_instance(sym: SymbolTableNode) -> SymbolTableNode:
|
||||
copied = sym.copy()
|
||||
copied.node.info = sym.type.type
|
||||
return copied
|
||||
|
||||
|
||||
def add_settings_to_django_conf_object(ctx: ClassDefContext,
|
||||
settings_module: str) -> None:
|
||||
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||
if settings_module not in api.modules:
|
||||
return None
|
||||
|
||||
settings_file = api.modules[settings_module]
|
||||
for name, sym in settings_file.names.items():
|
||||
if name.isupper() and isinstance(sym.node, Var):
|
||||
if isinstance(sym.type, Instance):
|
||||
copied = sym.copy()
|
||||
copied.node.info = sym.type.type
|
||||
ctx.cls.info.names[name] = copied
|
||||
|
||||
elif isinstance(sym.type, UnionType):
|
||||
instances = filter_out_nones(sym.type)
|
||||
if len(instances) > 1:
|
||||
# plain unions not supported yet
|
||||
continue
|
||||
typ = instances[0]
|
||||
if isinstance(typ, Instance):
|
||||
copied = sym.copy()
|
||||
copied.node.info = typ.type
|
||||
ctx.cls.info.names[name] = copied
|
||||
|
||||
|
||||
class DjangoConfSettingsInitializerHook(object):
|
||||
def __init__(self, settings_module: str):
|
||||
self.settings_module = settings_module
|
||||
|
||||
def __call__(self, ctx: ClassDefContext) -> None:
|
||||
if not self.settings_module:
|
||||
return
|
||||
|
||||
add_settings_to_django_conf_object(ctx, self.settings_module)
|
||||
@@ -1,42 +0,0 @@
|
||||
from typing import Optional, Any, cast
|
||||
|
||||
from mypy.nodes import Var, Context, GDEF
|
||||
from mypy.options import Options
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.semanal import SemanticAnalyzerPass2
|
||||
from mypy.types import Instance
|
||||
|
||||
|
||||
def add_settings_to_django_conf_object(ctx: ClassDefContext,
|
||||
settings_module: str) -> Optional[Any]:
|
||||
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||
if settings_module not in api.modules:
|
||||
return None
|
||||
|
||||
settings_file = api.modules[settings_module]
|
||||
for name, sym in settings_file.names.items():
|
||||
if name.isupper():
|
||||
if not isinstance(sym.node, Var) or not isinstance(sym.type, Instance):
|
||||
error_context = Context()
|
||||
error_context.set_line(sym.node)
|
||||
api.msg.fail("Need type annotation for '{}'".format(sym.node.name()),
|
||||
context=error_context,
|
||||
file=settings_file.path,
|
||||
origin=Context())
|
||||
continue
|
||||
|
||||
sym_copy = sym.copy()
|
||||
sym_copy.node.info = sym_copy.type.type
|
||||
sym_copy.kind = GDEF
|
||||
ctx.cls.info.names[name] = sym_copy
|
||||
|
||||
|
||||
class DjangoConfSettingsInitializerHook(object):
|
||||
def __init__(self, settings_module: str):
|
||||
self.settings_module = settings_module
|
||||
|
||||
def __call__(self, ctx: ClassDefContext) -> None:
|
||||
if not self.settings_module:
|
||||
return
|
||||
|
||||
add_settings_to_django_conf_object(ctx, self.settings_module)
|
||||
@@ -4,5 +4,6 @@ testpaths = ./test
|
||||
python_files = test*.py
|
||||
addopts =
|
||||
--tb=native
|
||||
--ignore=./external
|
||||
--mypy-ini-file=./test/plugins.ini
|
||||
--mypy-ini-file=./test-data/plugins.ini
|
||||
-s
|
||||
-v
|
||||
8
setup.py
8
setup.py
@@ -18,7 +18,9 @@ setup(
|
||||
author_email="maxim.kurnikov@gmail.com",
|
||||
version="0.1.0",
|
||||
license='BSD',
|
||||
install_requires=['Django>=2.1.1'],
|
||||
packages=['mypy_django_plugin']
|
||||
# package_data=find_stubs('django-stubs')
|
||||
install_requires=[
|
||||
'Django>=2.1.1',
|
||||
'mypy'
|
||||
],
|
||||
packages=['mypy_django_plugin'],
|
||||
)
|
||||
|
||||
@@ -1,3 +1,25 @@
|
||||
[CASE array_field_descriptor_access]
|
||||
from django.db import models
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
|
||||
class User(models.Model):
|
||||
array = ArrayField(base_field=models.Field())
|
||||
|
||||
user = User()
|
||||
reveal_type(user.array) # E: Revealed type is 'builtins.list[Any]'
|
||||
|
||||
[CASE array_field_base_field_parsed_into_generic_typevar]
|
||||
from django.db import models
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
|
||||
class User(models.Model):
|
||||
members = ArrayField(base_field=models.IntegerField())
|
||||
members_as_text = ArrayField(base_field=models.CharField(max_length=255))
|
||||
|
||||
user = User()
|
||||
reveal_type(user.members) # E: Revealed type is 'builtins.list[builtins.int*]'
|
||||
reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list[builtins.str*]'
|
||||
|
||||
[CASE test_model_fields_classes_present_as_primitives]
|
||||
from django.db import models
|
||||
|
||||
@@ -105,7 +105,12 @@ class Profile(models.Model):
|
||||
from django.db import models
|
||||
from myapp.models import App
|
||||
class View(models.Model):
|
||||
app = models.ForeignKey(to='myapp.App', related_name='views', on_delete=models.CASCADE)
|
||||
app = models.ForeignKey(to=App, related_name='views', on_delete=models.CASCADE)
|
||||
|
||||
reveal_type(View().app.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]'
|
||||
reveal_type(View().app.unknown) # E: Revealed type is 'Any'
|
||||
[out]
|
||||
main:7: error: "App" has no attribute "unknown"
|
||||
|
||||
[file myapp/__init__.py]
|
||||
[file myapp/models.py]
|
||||
@@ -113,3 +118,20 @@ from django.db import models
|
||||
class App(models.Model):
|
||||
def method(self) -> None:
|
||||
reveal_type(self.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]'
|
||||
|
||||
[case models_related_managers_work_with_direct_model_inheritance_and_with_inheritance_from_other_model]
|
||||
from django.db.models import Model
|
||||
from django.db import models
|
||||
|
||||
class App(Model):
|
||||
pass
|
||||
|
||||
class View(Model):
|
||||
app = models.ForeignKey(to=App, on_delete=models.CASCADE, related_name='views')
|
||||
|
||||
class View2(View):
|
||||
app = models.ForeignKey(to=App, on_delete=models.CASCADE, related_name='views2')
|
||||
|
||||
reveal_type(App().views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]'
|
||||
reveal_type(App().views2) # E: Revealed type is 'django.db.models.query.QuerySet[main.View2]'
|
||||
[out]
|
||||
@@ -11,9 +11,7 @@ SECRET_KEY = 112233
|
||||
ROOT_DIR = '/etc'
|
||||
NUMBERS = ['one', 'two']
|
||||
DICT = {} # type: ignore
|
||||
|
||||
from django.utils.functional import LazyObject
|
||||
|
||||
OBJ = LazyObject()
|
||||
|
||||
[CASE test_settings_could_be_defined_in_different_module_and_imported_with_star]
|
||||
@@ -36,18 +34,18 @@ ROOT_DIR = Path(__file__)
|
||||
|
||||
[CASE test_circular_dependency_in_settings]
|
||||
from django.conf import settings
|
||||
|
||||
class Class:
|
||||
pass
|
||||
|
||||
reveal_type(settings.MYSETTING) # E: Revealed type is 'builtins.int'
|
||||
reveal_type(settings.REGISTRY) # E: Revealed type is 'Any'
|
||||
reveal_type(settings.REGISTRY) # E: Revealed type is 'Union[main.Class, None]'
|
||||
|
||||
[env DJANGO_SETTINGS_MODULE=mysettings]
|
||||
[file mysettings.py]
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .main import Class
|
||||
from main import Class
|
||||
|
||||
MYSETTING = 1122
|
||||
REGISTRY: Optional['Class'] = None
|
||||
250
test/data.py
250
test/data.py
@@ -1,250 +0,0 @@
|
||||
import os
|
||||
import posixpath
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Any, Optional, Iterator, Dict, List, Tuple, Set
|
||||
|
||||
import pytest
|
||||
from mypy.test.config import test_temp_dir
|
||||
from mypy.test.data import DataDrivenTestCase, DataSuite, add_test_name_suffix, parse_test_data, \
|
||||
expand_errors, expand_variables, fix_win_path
|
||||
|
||||
|
||||
def parse_test_case(case: 'DataDrivenTestCase') -> None:
|
||||
"""Parse and prepare a single case from suite with test case descriptions.
|
||||
|
||||
This method is part of the setup phase, just before the test case is run.
|
||||
"""
|
||||
test_items = parse_test_data(case.data, case.name)
|
||||
base_path = case.suite.base_path
|
||||
if case.suite.native_sep:
|
||||
join = os.path.join
|
||||
else:
|
||||
join = posixpath.join # type: ignore
|
||||
|
||||
out_section_missing = case.suite.required_out_section
|
||||
|
||||
files = [] # type: List[Tuple[str, str]] # path and contents
|
||||
output_files = [] # type: List[Tuple[str, str]] # path and contents for output files
|
||||
output = [] # type: List[str] # Regular output errors
|
||||
output2 = {} # type: Dict[int, List[str]] # Output errors for incremental, runs 2+
|
||||
deleted_paths = {} # type: Dict[int, Set[str]] # from run number of paths
|
||||
stale_modules = {} # type: Dict[int, Set[str]] # from run number to module names
|
||||
rechecked_modules = {} # type: Dict[ int, Set[str]] # from run number module names
|
||||
triggered = [] # type: List[str] # Active triggers (one line per incremental step)
|
||||
|
||||
# Process the parsed items. Each item has a header of form [id args],
|
||||
# optionally followed by lines of text.
|
||||
item = first_item = test_items[0]
|
||||
for item in test_items[1:]:
|
||||
if item.id == 'file' or item.id == 'outfile':
|
||||
# Record an extra file needed for the test case.
|
||||
assert item.arg is not None
|
||||
contents = expand_variables('\n'.join(item.data))
|
||||
file_entry = (join(base_path, item.arg), contents)
|
||||
if item.id == 'file':
|
||||
files.append(file_entry)
|
||||
else:
|
||||
output_files.append(file_entry)
|
||||
elif item.id in ('builtins', 'builtins_py2'):
|
||||
# Use an alternative stub file for the builtins module.
|
||||
assert item.arg is not None
|
||||
mpath = join(os.path.dirname(case.file), item.arg)
|
||||
fnam = 'builtins.pyi' if item.id == 'builtins' else '__builtin__.pyi'
|
||||
with open(mpath) as f:
|
||||
files.append((join(base_path, fnam), f.read()))
|
||||
elif item.id == 'typing':
|
||||
# Use an alternative stub file for the typing module.
|
||||
assert item.arg is not None
|
||||
src_path = join(os.path.dirname(case.file), item.arg)
|
||||
with open(src_path) as f:
|
||||
files.append((join(base_path, 'typing.pyi'), f.read()))
|
||||
elif re.match(r'stale[0-9]*$', item.id):
|
||||
passnum = 1 if item.id == 'stale' else int(item.id[len('stale'):])
|
||||
assert passnum > 0
|
||||
modules = (set() if item.arg is None else {t.strip() for t in item.arg.split(',')})
|
||||
stale_modules[passnum] = modules
|
||||
elif re.match(r'rechecked[0-9]*$', item.id):
|
||||
passnum = 1 if item.id == 'rechecked' else int(item.id[len('rechecked'):])
|
||||
assert passnum > 0
|
||||
modules = (set() if item.arg is None else {t.strip() for t in item.arg.split(',')})
|
||||
rechecked_modules[passnum] = modules
|
||||
elif item.id == 'delete':
|
||||
# File to delete during a multi-step test case
|
||||
assert item.arg is not None
|
||||
m = re.match(r'(.*)\.([0-9]+)$', item.arg)
|
||||
assert m, 'Invalid delete section: {}'.format(item.arg)
|
||||
num = int(m.group(2))
|
||||
assert num >= 2, "Can't delete during step {}".format(num)
|
||||
full = join(base_path, m.group(1))
|
||||
deleted_paths.setdefault(num, set()).add(full)
|
||||
elif re.match(r'out[0-9]*$', item.id):
|
||||
tmp_output = [expand_variables(line) for line in item.data]
|
||||
if os.path.sep == '\\':
|
||||
tmp_output = [fix_win_path(line) for line in tmp_output]
|
||||
if item.id == 'out' or item.id == 'out1':
|
||||
output = tmp_output
|
||||
else:
|
||||
passnum = int(item.id[len('out'):])
|
||||
assert passnum > 1
|
||||
output2[passnum] = tmp_output
|
||||
out_section_missing = False
|
||||
elif item.id == 'triggered' and item.arg is None:
|
||||
triggered = item.data
|
||||
elif item.id == 'env':
|
||||
env_vars_to_set = item.arg
|
||||
for env in env_vars_to_set.split(';'):
|
||||
try:
|
||||
name, value = env.split('=')
|
||||
os.environ[name] = value
|
||||
except ValueError:
|
||||
continue
|
||||
else:
|
||||
raise ValueError(
|
||||
'Invalid section header {} in {} at line {}'.format(
|
||||
item.id, case.file, item.line))
|
||||
|
||||
if out_section_missing:
|
||||
raise ValueError(
|
||||
'{}, line {}: Required output section not found'.format(
|
||||
case.file, first_item.line))
|
||||
|
||||
for passnum in stale_modules.keys():
|
||||
if passnum not in rechecked_modules:
|
||||
# If the set of rechecked modules isn't specified, make it the same as the set
|
||||
# of modules with a stale public interface.
|
||||
rechecked_modules[passnum] = stale_modules[passnum]
|
||||
if (passnum in stale_modules
|
||||
and passnum in rechecked_modules
|
||||
and not stale_modules[passnum].issubset(rechecked_modules[passnum])):
|
||||
raise ValueError(
|
||||
('Stale modules after pass {} must be a subset of rechecked '
|
||||
'modules ({}:{})').format(passnum, case.file, first_item.line))
|
||||
|
||||
input = first_item.data
|
||||
expand_errors(input, output, 'main')
|
||||
for file_path, contents in files:
|
||||
expand_errors(contents.split('\n'), output, file_path)
|
||||
|
||||
case.input = input
|
||||
case.output = output
|
||||
case.output2 = output2
|
||||
case.lastline = item.line
|
||||
case.files = files
|
||||
case.output_files = output_files
|
||||
case.expected_stale_modules = stale_modules
|
||||
case.expected_rechecked_modules = rechecked_modules
|
||||
case.deleted_paths = deleted_paths
|
||||
case.triggered = triggered or []
|
||||
|
||||
|
||||
class DjangoDataDrivenTestCase(DataDrivenTestCase):
|
||||
def setup(self) -> None:
|
||||
self.old_environ = os.environ.copy()
|
||||
|
||||
parse_test_case(case=self)
|
||||
self.old_cwd = os.getcwd()
|
||||
|
||||
self.tmpdir = tempfile.TemporaryDirectory(prefix='mypy-test-')
|
||||
tmpdir_root = os.path.join(self.tmpdir.name, 'tmp')
|
||||
|
||||
new_files = []
|
||||
for path, contents in self.files:
|
||||
new_files.append((path, contents.replace('<TMP>', tmpdir_root)))
|
||||
self.files = new_files
|
||||
|
||||
os.chdir(self.tmpdir.name)
|
||||
os.mkdir(test_temp_dir)
|
||||
encountered_files = set()
|
||||
self.clean_up = []
|
||||
for paths in self.deleted_paths.values():
|
||||
for path in paths:
|
||||
self.clean_up.append((False, path))
|
||||
encountered_files.add(path)
|
||||
for path, content in self.files:
|
||||
dir = os.path.dirname(path)
|
||||
for d in self.add_dirs(dir):
|
||||
self.clean_up.append((True, d))
|
||||
with open(path, 'w') as f:
|
||||
f.write(content)
|
||||
if path not in encountered_files:
|
||||
self.clean_up.append((False, path))
|
||||
encountered_files.add(path)
|
||||
if re.search(r'\.[2-9]$', path):
|
||||
# Make sure new files introduced in the second and later runs are accounted for
|
||||
renamed_path = path[:-2]
|
||||
if renamed_path not in encountered_files:
|
||||
encountered_files.add(renamed_path)
|
||||
self.clean_up.append((False, renamed_path))
|
||||
for path, _ in self.output_files:
|
||||
# Create directories for expected output and mark them to be cleaned up at the end
|
||||
# of the test case.
|
||||
dir = os.path.dirname(path)
|
||||
for d in self.add_dirs(dir):
|
||||
self.clean_up.append((True, d))
|
||||
self.clean_up.append((False, path))
|
||||
|
||||
sys.path.insert(0, tmpdir_root)
|
||||
|
||||
def teardown(self):
|
||||
if hasattr(self, 'old_environ'):
|
||||
os.environ = self.old_environ
|
||||
super().teardown()
|
||||
|
||||
|
||||
def split_test_cases(parent: 'DataSuiteCollector', suite: 'DataSuite',
|
||||
file: str) -> Iterator[DjangoDataDrivenTestCase]:
|
||||
"""Iterate over raw test cases in file, at collection time, ignoring sub items.
|
||||
|
||||
The collection phase is slow, so any heavy processing should be deferred to after
|
||||
uninteresting tests are filtered (when using -k PATTERN switch).
|
||||
"""
|
||||
with open(file, encoding='utf-8') as f:
|
||||
data = f.read()
|
||||
cases = re.split(r'^\[case ([a-zA-Z_0-9]+)'
|
||||
r'(-writescache)?'
|
||||
r'(-only_when_cache|-only_when_nocache)?'
|
||||
r'(-skip)?'
|
||||
r'\][ \t]*$\n', data,
|
||||
flags=re.DOTALL | re.MULTILINE)
|
||||
line_no = cases[0].count('\n') + 1
|
||||
|
||||
for i in range(1, len(cases), 5):
|
||||
name, writescache, only_when, skip, data = cases[i:i + 5]
|
||||
yield DjangoDataDrivenTestCase(parent, suite, file,
|
||||
name=add_test_name_suffix(name, suite.test_name_suffix),
|
||||
writescache=bool(writescache),
|
||||
only_when=only_when,
|
||||
skip=bool(skip),
|
||||
data=data,
|
||||
line=line_no)
|
||||
line_no += data.count('\n') + 1
|
||||
|
||||
|
||||
class DataSuiteCollector(pytest.Class): # type: ignore # inheriting from Any
|
||||
def collect(self) -> Iterator[pytest.Item]: # type: ignore
|
||||
"""Called by pytest on each of the object returned from pytest_pycollect_makeitem"""
|
||||
|
||||
# obj is the object for which pytest_pycollect_makeitem returned self.
|
||||
suite = self.obj # type: DataSuite
|
||||
for f in suite.files:
|
||||
yield from split_test_cases(self, suite, os.path.join(suite.data_prefix, f))
|
||||
|
||||
|
||||
# This function name is special to pytest. See
|
||||
# http://doc.pytest.org/en/latest/writing_plugins.html#collection-hooks
|
||||
def pytest_pycollect_makeitem(collector: Any, name: str,
|
||||
obj: object) -> 'Optional[Any]':
|
||||
"""Called by pytest on each object in modules configured in conftest.py files.
|
||||
|
||||
collector is pytest.Collector, returns Optional[pytest.Class]
|
||||
"""
|
||||
if isinstance(obj, type):
|
||||
# Only classes derived from DataSuite contain test cases, not the DataSuite class itself
|
||||
if issubclass(obj, DataSuite) and obj is not DataSuite:
|
||||
# Non-None result means this obj is a test case.
|
||||
# The collect method of the returned DataSuiteCollector instance will be called later,
|
||||
# with self.obj being obj.
|
||||
return DataSuiteCollector(name, parent=collector)
|
||||
return None
|
||||
237
test/helpers.py
237
test/helpers.py
@@ -1,237 +0,0 @@
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
from typing import List, Callable, Optional, Tuple
|
||||
|
||||
import pytest # type: ignore # no pytest in typeshed
|
||||
|
||||
skip = pytest.mark.skip
|
||||
|
||||
# AssertStringArraysEqual displays special line alignment helper messages if
|
||||
# the first different line has at least this many characters,
|
||||
MIN_LINE_LENGTH_FOR_ALIGNMENT = 5
|
||||
|
||||
|
||||
class TypecheckAssertionError(AssertionError):
|
||||
def __init__(self, error_message: str, lineno: int):
|
||||
self.error_message = error_message
|
||||
self.lineno = lineno
|
||||
|
||||
def first_line(self):
|
||||
return self.__class__.__name__ + '(message="Invalid output")'
|
||||
|
||||
def __str__(self):
|
||||
return self.error_message
|
||||
|
||||
|
||||
def _clean_up(a: List[str]) -> List[str]:
|
||||
"""Remove common directory prefix from all strings in a.
|
||||
|
||||
This uses a naive string replace; it seems to work well enough. Also
|
||||
remove trailing carriage returns.
|
||||
"""
|
||||
res = []
|
||||
for s in a:
|
||||
prefix = os.sep
|
||||
ss = s
|
||||
for p in prefix, prefix.replace(os.sep, '/'):
|
||||
if p != '/' and p != '//' and p != '\\' and p != '\\\\':
|
||||
ss = ss.replace(p, '')
|
||||
# Ignore spaces at end of line.
|
||||
ss = re.sub(' +$', '', ss)
|
||||
res.append(re.sub('\\r$', '', ss))
|
||||
return res
|
||||
|
||||
|
||||
def _num_skipped_prefix_lines(a1: List[str], a2: List[str]) -> int:
|
||||
num_eq = 0
|
||||
while num_eq < min(len(a1), len(a2)) and a1[num_eq] == a2[num_eq]:
|
||||
num_eq += 1
|
||||
return max(0, num_eq - 4)
|
||||
|
||||
|
||||
def _num_skipped_suffix_lines(a1: List[str], a2: List[str]) -> int:
|
||||
num_eq = 0
|
||||
while (num_eq < min(len(a1), len(a2))
|
||||
and a1[-num_eq - 1] == a2[-num_eq - 1]):
|
||||
num_eq += 1
|
||||
return max(0, num_eq - 4)
|
||||
|
||||
|
||||
def _add_aligned_message(s1: str, s2: str, error_message: str) -> str:
|
||||
"""Align s1 and s2 so that the their first difference is highlighted.
|
||||
|
||||
For example, if s1 is 'foobar' and s2 is 'fobar', display the
|
||||
following lines:
|
||||
|
||||
E: foobar
|
||||
A: fobar
|
||||
^
|
||||
|
||||
If s1 and s2 are long, only display a fragment of the strings around the
|
||||
first difference. If s1 is very short, do nothing.
|
||||
"""
|
||||
|
||||
# Seeing what went wrong is trivial even without alignment if the expected
|
||||
# string is very short. In this case do nothing to simplify output.
|
||||
if len(s1) < 4:
|
||||
return error_message
|
||||
|
||||
maxw = 72 # Maximum number of characters shown
|
||||
|
||||
error_message += 'Alignment of first line difference:\n'
|
||||
# sys.stderr.write('Alignment of first line difference:\n')
|
||||
|
||||
trunc = False
|
||||
while s1[:30] == s2[:30]:
|
||||
s1 = s1[10:]
|
||||
s2 = s2[10:]
|
||||
trunc = True
|
||||
|
||||
if trunc:
|
||||
s1 = '...' + s1
|
||||
s2 = '...' + s2
|
||||
|
||||
max_len = max(len(s1), len(s2))
|
||||
extra = ''
|
||||
if max_len > maxw:
|
||||
extra = '...'
|
||||
|
||||
# Write a chunk of both lines, aligned.
|
||||
error_message += ' E: {}{}\n'.format(s1[:maxw], extra)
|
||||
# sys.stderr.write(' E: {}{}\n'.format(s1[:maxw], extra))
|
||||
error_message += ' A: {}{}\n'.format(s2[:maxw], extra)
|
||||
# sys.stderr.write(' A: {}{}\n'.format(s2[:maxw], extra))
|
||||
# Write an indicator character under the different columns.
|
||||
error_message += ' '
|
||||
# sys.stderr.write(' ')
|
||||
for j in range(min(maxw, max(len(s1), len(s2)))):
|
||||
if s1[j:j + 1] != s2[j:j + 1]:
|
||||
error_message += '^'
|
||||
# sys.stderr.write('^') # Difference
|
||||
break
|
||||
else:
|
||||
error_message += ' '
|
||||
# sys.stderr.write(' ') # Equal
|
||||
error_message += '\n'
|
||||
return error_message
|
||||
# sys.stderr.write('\n')
|
||||
|
||||
|
||||
def assert_string_arrays_equal(expected: List[str], actual: List[str]) -> None:
|
||||
"""Assert that two string arrays are equal.
|
||||
|
||||
Display any differences in a human-readable form.
|
||||
"""
|
||||
|
||||
actual = _clean_up(actual)
|
||||
error_message = ''
|
||||
|
||||
if set(actual) != set(expected):
|
||||
num_skip_start = _num_skipped_prefix_lines(expected, actual)
|
||||
num_skip_end = _num_skipped_suffix_lines(expected, actual)
|
||||
|
||||
error_message += 'Expected:\n'
|
||||
|
||||
# If omit some lines at the beginning, indicate it by displaying a line
|
||||
# with '...'.
|
||||
if num_skip_start > 0:
|
||||
error_message += ' ...\n'
|
||||
|
||||
# Keep track of the first different line.
|
||||
first_diff = -1
|
||||
|
||||
# Display only this many first characters of identical lines.
|
||||
width = 75
|
||||
|
||||
for i in range(num_skip_start, len(expected) - num_skip_end):
|
||||
if i >= len(actual) or expected[i] != actual[i]:
|
||||
if first_diff < 0:
|
||||
first_diff = i
|
||||
error_message += ' {:<45} (diff)'.format(expected[i])
|
||||
else:
|
||||
e = expected[i]
|
||||
error_message += ' ' + e[:width]
|
||||
if len(e) > width:
|
||||
error_message += '...'
|
||||
error_message += '\n'
|
||||
if num_skip_end > 0:
|
||||
error_message += ' ...\n'
|
||||
|
||||
error_message += 'Actual:\n'
|
||||
|
||||
if num_skip_start > 0:
|
||||
error_message += ' ...\n'
|
||||
|
||||
for j in range(num_skip_start, len(actual) - num_skip_end):
|
||||
if j >= len(expected) or expected[j] != actual[j]:
|
||||
error_message += ' {:<45} (diff)'.format(actual[j])
|
||||
else:
|
||||
a = actual[j]
|
||||
error_message += ' ' + a[:width]
|
||||
if len(a) > width:
|
||||
error_message += '...'
|
||||
error_message += '\n'
|
||||
if actual == []:
|
||||
error_message += ' (empty)\n'
|
||||
if num_skip_end > 0:
|
||||
error_message += ' ...\n'
|
||||
|
||||
error_message += '\n'
|
||||
|
||||
if 0 <= first_diff < len(actual) and (
|
||||
len(expected[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT
|
||||
or len(actual[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT):
|
||||
# Display message that helps visualize the differences between two
|
||||
# long lines.
|
||||
error_message = _add_aligned_message(expected[first_diff], actual[first_diff],
|
||||
error_message)
|
||||
|
||||
first_failure = expected[first_diff]
|
||||
if first_failure:
|
||||
lineno = int(first_failure.split(' ')[0].strip(':').split(':')[1])
|
||||
raise TypecheckAssertionError(error_message=f'Invalid output: \n{error_message}',
|
||||
lineno=lineno)
|
||||
|
||||
|
||||
def build_output_line(fname: str, lnum: int, severity: str, message: str, col=None) -> str:
|
||||
if col is None:
|
||||
return f'{fname}:{lnum + 1}: {severity}: {message}'
|
||||
else:
|
||||
return f'{fname}:{lnum + 1}:{col}: {severity}: {message}'
|
||||
|
||||
|
||||
def expand_errors(input_lines: List[str], fname: str) -> List[str]:
|
||||
"""Transform comments such as '# E: message' or
|
||||
'# E:3: message' in input.
|
||||
|
||||
The result is lines like 'fnam:line: error: message'.
|
||||
"""
|
||||
output_lines = []
|
||||
for lnum, line in enumerate(input_lines):
|
||||
# The first in the split things isn't a comment
|
||||
for possible_err_comment in line.split(' # ')[1:]:
|
||||
m = re.search(
|
||||
r'^([ENW]):((?P<col>\d+):)? (?P<message>.*)$',
|
||||
possible_err_comment.strip())
|
||||
if m:
|
||||
if m.group(1) == 'E':
|
||||
severity = 'error'
|
||||
elif m.group(1) == 'N':
|
||||
severity = 'note'
|
||||
elif m.group(1) == 'W':
|
||||
severity = 'warning'
|
||||
col = m.group('col')
|
||||
output_lines.append(build_output_line(fname, lnum, severity,
|
||||
message=m.group("message"),
|
||||
col=col))
|
||||
return output_lines
|
||||
|
||||
|
||||
def get_func_first_lnum(attr: Callable[..., None]) -> Optional[Tuple[int, List[str]]]:
|
||||
lines, _ = inspect.getsourcelines(attr)
|
||||
for lnum, line in enumerate(lines):
|
||||
no_space_line = line.strip()
|
||||
if f'def {attr.__name__}' in no_space_line:
|
||||
return lnum, lines[lnum + 1:]
|
||||
raise ValueError(f'No line "def {attr.__name__}" found')
|
||||
@@ -1,303 +0,0 @@
|
||||
import dataclasses
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Any, Optional, cast, List, Type, Callable, Dict
|
||||
|
||||
import pytest
|
||||
from _pytest._code.code import ReprFileLocation, ReprEntry, ExceptionInfo
|
||||
from decorator import decorate
|
||||
from mypy import api as mypy_api
|
||||
|
||||
from test import vistir
|
||||
from test.helpers import assert_string_arrays_equal, TypecheckAssertionError, expand_errors, get_func_first_lnum
|
||||
|
||||
|
||||
def reveal_type(obj: Any) -> None:
|
||||
# noop method, just to get rid of "method is not resolved" errors
|
||||
pass
|
||||
|
||||
|
||||
def output(output_lines: str):
|
||||
def decor(func: Callable[..., None]):
|
||||
func.out = output_lines.strip()
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorate(func, wrapper)
|
||||
|
||||
return decor
|
||||
|
||||
|
||||
def get_class_that_defined_method(meth) -> Type['MypyTypecheckTestCase']:
|
||||
if inspect.ismethod(meth):
|
||||
for cls in inspect.getmro(meth.__self__.__class__):
|
||||
if cls.__dict__.get(meth.__name__) is meth:
|
||||
return cls
|
||||
meth = meth.__func__ # fallback to __qualname__ parsing
|
||||
if inspect.isfunction(meth):
|
||||
cls = getattr(inspect.getmodule(meth),
|
||||
meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0])
|
||||
if issubclass(cls, MypyTypecheckTestCase):
|
||||
return cls
|
||||
return getattr(meth, '__objclass__', None) # handle special descriptor objects
|
||||
|
||||
|
||||
def file(filename: str, make_parent_packages=False):
|
||||
def decor(func: Callable[..., None]):
|
||||
func.filename = filename
|
||||
func.make_parent_packages = make_parent_packages
|
||||
return func
|
||||
|
||||
return decor
|
||||
|
||||
|
||||
def env(**environ):
|
||||
def decor(func: Callable[..., None]):
|
||||
func.env = environ
|
||||
return func
|
||||
|
||||
return decor
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CreateFile:
|
||||
sources: str
|
||||
make_parent_packages: bool = False
|
||||
|
||||
|
||||
class MypyTypecheckMeta(type):
|
||||
def __new__(mcs, name, bases, attrs):
|
||||
cls = super().__new__(mcs, name, bases, attrs)
|
||||
cls.files: Dict[str, CreateFile] = {}
|
||||
|
||||
for name, attr in attrs.items():
|
||||
if inspect.isfunction(attr):
|
||||
filename = getattr(attr, 'filename', None)
|
||||
if not filename:
|
||||
continue
|
||||
make_parent_packages = getattr(attr, 'make_parent_packages', False)
|
||||
sources = textwrap.dedent(''.join(get_func_first_lnum(attr)[1]))
|
||||
if sources.strip() == 'pass':
|
||||
sources = ''
|
||||
cls.files[filename] = CreateFile(sources, make_parent_packages)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
class MypyTypecheckTestCase(metaclass=MypyTypecheckMeta):
|
||||
files = None
|
||||
|
||||
def ini_file(self) -> str:
|
||||
return """
|
||||
[mypy]
|
||||
"""
|
||||
|
||||
def _get_ini_file_contents(self) -> Optional[str]:
|
||||
raw_ini_file = self.ini_file()
|
||||
if not raw_ini_file:
|
||||
return raw_ini_file
|
||||
return raw_ini_file.strip() + '\n'
|
||||
|
||||
|
||||
class TraceLastReprEntry(ReprEntry):
|
||||
def toterminal(self, tw):
|
||||
self.reprfileloc.toterminal(tw)
|
||||
for line in self.lines:
|
||||
red = line.startswith("E ")
|
||||
tw.line(line, bold=True, red=red)
|
||||
return
|
||||
|
||||
|
||||
def fname_to_module(fpath: Path, root_path: Path) -> Optional[str]:
|
||||
try:
|
||||
relpath = fpath.relative_to(root_path).with_suffix('')
|
||||
return str(relpath).replace(os.sep, '.')
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
class MypyTypecheckItem(pytest.Item):
|
||||
root_directory = '/tmp'
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
parent: 'MypyTestsCollector',
|
||||
klass: Type[MypyTypecheckTestCase],
|
||||
source_code: str,
|
||||
first_lineno: int,
|
||||
ini_file_contents: Optional[str] = None,
|
||||
expected_output_lines: Optional[List[str]] = None,
|
||||
files: Optional[Dict[str, CreateFile]] = None,
|
||||
custom_environment: Optional[Dict[str, Any]] = None):
|
||||
super().__init__(name=name, parent=parent)
|
||||
self.klass = klass
|
||||
self.source_code = source_code
|
||||
self.first_lineno = first_lineno
|
||||
self.ini_file_contents = ini_file_contents
|
||||
self.expected_output_lines = expected_output_lines
|
||||
self.files = files
|
||||
self.custom_environment = custom_environment
|
||||
|
||||
@contextmanager
|
||||
def temp_directory(self) -> Path:
|
||||
with tempfile.TemporaryDirectory(prefix='mypy-pytest-',
|
||||
dir=self.root_directory) as tmpdir_name:
|
||||
yield Path(self.root_directory) / tmpdir_name
|
||||
|
||||
def runtest(self):
|
||||
with self.temp_directory() as tmpdir_path:
|
||||
if not self.source_code:
|
||||
return
|
||||
|
||||
if self.ini_file_contents:
|
||||
mypy_ini_fpath = tmpdir_path / 'mypy.ini'
|
||||
mypy_ini_fpath.write_text(self.ini_file_contents)
|
||||
|
||||
test_specific_modules = []
|
||||
for fname, create_file in self.files.items():
|
||||
fpath = tmpdir_path / fname
|
||||
if create_file.make_parent_packages:
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
for parent in fpath.parents:
|
||||
try:
|
||||
parent.relative_to(tmpdir_path)
|
||||
if parent != tmpdir_path:
|
||||
parent_init_file = parent / '__init__.py'
|
||||
parent_init_file.write_text('')
|
||||
test_specific_modules.append(fname_to_module(parent,
|
||||
root_path=tmpdir_path))
|
||||
except ValueError:
|
||||
break
|
||||
|
||||
fpath.write_text(create_file.sources)
|
||||
test_specific_modules.append(fname_to_module(fpath,
|
||||
root_path=tmpdir_path))
|
||||
|
||||
with vistir.temp_environ(), vistir.temp_path():
|
||||
for key, val in (self.custom_environment or {}).items():
|
||||
os.environ[key] = val
|
||||
sys.path.insert(0, str(tmpdir_path))
|
||||
|
||||
mypy_cmd_options = self.prepare_mypy_cmd_options(config_file_path=mypy_ini_fpath)
|
||||
main_fpath = tmpdir_path / 'main.py'
|
||||
main_fpath.write_text(self.source_code)
|
||||
mypy_cmd_options.append(str(main_fpath))
|
||||
|
||||
stdout, stderr, returncode = mypy_api.run(mypy_cmd_options)
|
||||
output_lines = []
|
||||
for line in (stdout + stderr).splitlines():
|
||||
if ':' not in line:
|
||||
continue
|
||||
out_fpath, res_line = line.split(':', 1)
|
||||
line = os.path.relpath(out_fpath, start=tmpdir_path) + ':' + res_line
|
||||
output_lines.append(line.strip().replace('.py', ''))
|
||||
|
||||
for module in test_specific_modules:
|
||||
parts = module.split('.')
|
||||
for i in range(len(parts)):
|
||||
parent_module = '.'.join(parts[:i + 1])
|
||||
if parent_module in sys.modules:
|
||||
del sys.modules[parent_module]
|
||||
|
||||
assert_string_arrays_equal(expected=self.expected_output_lines,
|
||||
actual=output_lines)
|
||||
|
||||
def prepare_mypy_cmd_options(self, config_file_path: Path) -> List[str]:
|
||||
mypy_cmd_options = [
|
||||
'--raise-exceptions',
|
||||
'--no-silence-site-packages'
|
||||
]
|
||||
python_version = '.'.join([str(part) for part in sys.version_info[:2]])
|
||||
mypy_cmd_options.append(f'--python-version={python_version}')
|
||||
if self.ini_file_contents:
|
||||
mypy_cmd_options.append(f'--config-file={config_file_path}')
|
||||
return mypy_cmd_options
|
||||
|
||||
def repr_failure(self, excinfo: ExceptionInfo) -> str:
|
||||
if excinfo.errisinstance(SystemExit):
|
||||
# We assume that before doing exit() (which raises SystemExit) we've printed
|
||||
# enough context about what happened so that a stack trace is not useful.
|
||||
# In particular, uncaught exceptions during semantic analysis or type checking
|
||||
# call exit() and they already print out a stack trace.
|
||||
return excinfo.exconly(tryshort=True)
|
||||
elif excinfo.errisinstance(TypecheckAssertionError):
|
||||
# with traceback removed
|
||||
exception_repr = excinfo.getrepr(style='short')
|
||||
exception_repr.reprcrash.message = ''
|
||||
repr_file_location = ReprFileLocation(path=inspect.getfile(self.klass),
|
||||
lineno=self.first_lineno + excinfo.value.lineno,
|
||||
message='')
|
||||
repr_tb_entry = TraceLastReprEntry(filelocrepr=repr_file_location,
|
||||
lines=exception_repr.reprtraceback.reprentries[-1].lines[1:],
|
||||
style='short',
|
||||
reprlocals=None,
|
||||
reprfuncargs=None)
|
||||
exception_repr.reprtraceback.reprentries = [repr_tb_entry]
|
||||
return exception_repr
|
||||
else:
|
||||
return super().repr_failure(excinfo, style='native')
|
||||
|
||||
def reportinfo(self):
|
||||
return self.fspath, None, get_class_qualname(self.klass) + '::' + self.name
|
||||
|
||||
|
||||
def get_class_qualname(klass: type) -> str:
|
||||
return klass.__module__ + '.' + klass.__name__
|
||||
|
||||
|
||||
def extract_test_output(attr: Callable[..., None]) -> List[str]:
|
||||
out_data: str = getattr(attr, 'out', None)
|
||||
out_lines = []
|
||||
if out_data:
|
||||
for line in out_data.strip().split('\n'):
|
||||
if line:
|
||||
line = line.strip()
|
||||
out_lines.append(line)
|
||||
return out_lines
|
||||
|
||||
|
||||
class MypyTestsCollector(pytest.Class):
|
||||
def get_ini_file_contents(self, contents: str) -> str:
|
||||
return contents.strip() + '\n'
|
||||
|
||||
def collect(self) -> Iterator[pytest.Item]:
|
||||
current_testcase = cast(MypyTypecheckTestCase, self.obj())
|
||||
ini_file_contents = self.get_ini_file_contents(current_testcase.ini_file())
|
||||
for attr_name in dir(current_testcase):
|
||||
if attr_name.startswith('test_'):
|
||||
attr = getattr(self.obj, attr_name)
|
||||
if inspect.isfunction(attr):
|
||||
first_line_lnum, source_lines = get_func_first_lnum(attr)
|
||||
func_first_line_in_file = inspect.getsourcelines(attr)[1] + first_line_lnum
|
||||
|
||||
output_from_decorator = extract_test_output(attr)
|
||||
output_from_comments = expand_errors(source_lines, 'main')
|
||||
custom_env = getattr(attr, 'env', None)
|
||||
main_source_code = textwrap.dedent(''.join(source_lines))
|
||||
yield MypyTypecheckItem(name=attr_name,
|
||||
parent=self,
|
||||
klass=current_testcase.__class__,
|
||||
source_code=main_source_code,
|
||||
first_lineno=func_first_line_in_file,
|
||||
ini_file_contents=ini_file_contents,
|
||||
expected_output_lines=output_from_comments
|
||||
+ output_from_decorator,
|
||||
files=current_testcase.__class__.files,
|
||||
custom_environment=custom_env)
|
||||
|
||||
|
||||
def pytest_pycollect_makeitem(collector: Any, name: str, obj: Any) -> Optional[MypyTestsCollector]:
|
||||
# Only classes derived from DataSuite contain test cases, not the DataSuite class itself
|
||||
if (isinstance(obj, type)
|
||||
and issubclass(obj, MypyTypecheckTestCase)
|
||||
and obj is not MypyTypecheckTestCase):
|
||||
# Non-None result means this obj is a test case.
|
||||
# The collect method of the returned DataSuiteCollector instance will be called later,
|
||||
# with self.obj being obj.
|
||||
return MypyTestsCollector(name, parent=collector)
|
||||
@@ -1,21 +0,0 @@
|
||||
[CASE array_field_descriptor_access]
|
||||
from django.db import models
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
|
||||
class User(models.Model):
|
||||
array = ArrayField(base_field=models.Field())
|
||||
|
||||
user = User()
|
||||
reveal_type(user.array) # E: Revealed type is 'builtins.list[Any]'
|
||||
|
||||
[CASE array_field_base_field_parsed_into_generic_typevar]
|
||||
from django.db import models
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
|
||||
class User(models.Model):
|
||||
members = ArrayField(base_field=models.IntegerField())
|
||||
members_as_text = ArrayField(base_field=models.CharField(max_length=255))
|
||||
|
||||
user = User()
|
||||
reveal_type(user.members) # E: Revealed type is 'builtins.list[builtins.int*]'
|
||||
reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list[builtins.str*]'
|
||||
@@ -1,60 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from mypy import api
|
||||
from mypy.test.config import test_temp_dir
|
||||
from mypy.test.data import DataSuite, DataDrivenTestCase
|
||||
from mypy.test.helpers import assert_string_arrays_equal
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent
|
||||
TEST_DATA_DIR = ROOT_DIR / 'test' / 'test-data'
|
||||
MYPY_INI_PATH = ROOT_DIR / 'test' / 'plugins.ini'
|
||||
|
||||
|
||||
class DjangoTestSuite(DataSuite):
|
||||
files = [
|
||||
# 'check-objects-queryset.test',
|
||||
# 'check-model-fields.test',
|
||||
# 'check-postgres-fields.test',
|
||||
# 'check-model-relations.test',
|
||||
# 'check-parse-settings.test',
|
||||
# 'check-to-attr-as-string-one-to-one-field.test',
|
||||
'check-to-attr-as-string-foreign-key.test',
|
||||
# 'check-foreign-key-as-string-creates-underscore-id-attr.test'
|
||||
]
|
||||
data_prefix = str(TEST_DATA_DIR)
|
||||
|
||||
def run_case(self, testcase: DataDrivenTestCase) -> None:
|
||||
assert testcase.old_cwd is not None, "test was not properly set up"
|
||||
|
||||
mypy_cmdline = [
|
||||
'--show-traceback',
|
||||
'--no-silence-site-packages',
|
||||
'--config-file={}'.format(MYPY_INI_PATH)
|
||||
]
|
||||
mypy_cmdline.append('--python-version={}'.format('.'.join(map(str,
|
||||
sys.version_info[:2]))))
|
||||
|
||||
program_path = os.path.join(test_temp_dir, 'main.py')
|
||||
mypy_cmdline.append(program_path)
|
||||
|
||||
with open(program_path, 'w') as file:
|
||||
for s in testcase.input:
|
||||
file.write('{}\n'.format(s))
|
||||
|
||||
output = []
|
||||
# Type check the program.
|
||||
out, err, returncode = api.run(mypy_cmdline)
|
||||
# split lines, remove newlines, and remove directory of test case
|
||||
for line in (out + err).splitlines():
|
||||
if line.startswith(test_temp_dir + os.sep):
|
||||
output.append(line[len(test_temp_dir + os.sep):].rstrip("\r\n").replace('.py', ''))
|
||||
else:
|
||||
output.append(line.rstrip("\r\n"))
|
||||
# Remove temp file.
|
||||
os.remove(program_path)
|
||||
|
||||
assert_string_arrays_equal(testcase.output, output,
|
||||
'Invalid output ({}, line {})'.format(
|
||||
testcase.file, testcase.line))
|
||||
@@ -1,43 +0,0 @@
|
||||
# Borrowed from Pew.
|
||||
# See https://github.com/berdario/pew/blob/master/pew/_utils.py#L82
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from decorator import contextmanager
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temp_environ():
|
||||
"""Allow the ability to set os.environ temporarily"""
|
||||
environ = dict(os.environ)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.environ.clear()
|
||||
os.environ.update(environ)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temp_path():
|
||||
"""A context manager which allows the ability to set sys.path temporarily"""
|
||||
path = [p for p in sys.path]
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
sys.path = [p for p in path]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cd(path):
|
||||
"""Context manager to temporarily change working directories"""
|
||||
if not path:
|
||||
return
|
||||
prev_cwd = Path.cwd().as_posix()
|
||||
if isinstance(path, Path):
|
||||
path = path.as_posix()
|
||||
os.chdir(str(path))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.chdir(prev_cwd)
|
||||
Reference in New Issue
Block a user