preliminary support for strict_optional

This commit is contained in:
Maxim Kurnikov
2019-02-17 18:07:53 +03:00
parent 6763217a80
commit e9f9202ed1
23 changed files with 614 additions and 343 deletions

View File

@@ -0,0 +1,199 @@
from typing import Optional, cast
from mypy.checker import TypeChecker
from mypy.nodes import ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo, Var
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, TupleType, Type, TypeOfAny, UnionType
from mypy_django_plugin import helpers
from mypy_django_plugin.transformers.models import iter_over_assignments
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.callee_arg_names:
api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
context=ctx.context)
return None
arg_type = ctx.arg_types[ctx.callee_arg_names.index('to')][0]
if not isinstance(arg_type, CallableType):
to_arg_expr = ctx.args[ctx.callee_arg_names.index('to')][0]
if not isinstance(to_arg_expr, StrExpr):
# not string, not supported
return None
try:
model_fullname = helpers.get_model_fullname_from_string(to_arg_expr.value,
all_modules=api.modules)
except helpers.SelfReference:
model_fullname = api.tscope.classes[-1].fullname()
if model_fullname is None:
return None
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
all_modules=api.modules)
if model_info is None or not isinstance(model_info, TypeInfo):
return None
return Instance(model_info, [])
referred_to_type = arg_type.ret_type
if not isinstance(referred_to_type, Instance):
return None
if not referred_to_type.type.has_base(helpers.MODEL_CLASS_FULLNAME):
ctx.api.msg.fail(f'to= parameter value must be '
f'a subclass of {helpers.MODEL_CLASS_FULLNAME}',
context=ctx.context)
return None
return referred_to_type
def convert_any_to_type(typ: Type, referred_to_type: Type) -> Type:
if isinstance(typ, UnionType):
converted_items = []
for item in typ.items:
converted_items.append(convert_any_to_type(item, referred_to_type))
return UnionType.make_simplified_union(converted_items,
line=typ.line, column=typ.column)
if isinstance(typ, Instance):
args = []
for default_arg in typ.args:
if isinstance(default_arg, AnyType):
args.append(referred_to_type)
else:
args.append(default_arg)
return helpers.reparametrize_instance(typ, args)
if isinstance(typ, AnyType):
return referred_to_type
return typ
def _extract_referred_to_type(ctx: FunctionContext) -> Optional[Type]:
try:
referred_to_type = get_valid_to_value_or_none(ctx)
except helpers.InvalidModelString as exc:
ctx.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', ctx.context)
return None
return referred_to_type
def fill_descriptor_types_for_related_field(ctx: FunctionContext) -> Type:
default_return_type = set_descriptor_types_for_field(ctx)
referred_to_type = _extract_referred_to_type(ctx)
if referred_to_type is None:
return default_return_type
# replace Any with referred_to_type
args = []
for default_arg in default_return_type.args:
args.append(convert_any_to_type(default_arg, referred_to_type))
return helpers.reparametrize_instance(ctx.default_return_type, new_args=args)
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> Type:
node = type_info.get(private_field_name).node
if isinstance(node, Var):
descriptor_type = node.type
if is_nullable:
descriptor_type = helpers.make_optional(descriptor_type)
return descriptor_type
return AnyType(TypeOfAny.unannotated)
def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
default_return_type = cast(Instance, ctx.default_return_type)
is_nullable = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'null'))
set_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_set_type',
is_nullable=is_nullable)
get_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_get_type',
is_nullable=is_nullable)
return helpers.reparametrize_instance(default_return_type, [set_type, get_type])
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
default_return_type = set_descriptor_types_for_field(ctx)
base_field_arg_type = helpers.get_argument_type_by_name(ctx, 'base_field')
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
return default_return_type
base_type = base_field_arg_type.args[1] # extract __get__ type
args = []
for default_arg in default_return_type.args:
args.append(convert_any_to_type(default_arg, base_type))
return helpers.reparametrize_instance(default_return_type, args)
def transform_into_proper_return_type(ctx: FunctionContext) -> Type:
default_return_type = ctx.default_return_type
if not isinstance(default_return_type, Instance):
return default_return_type
if helpers.has_any_of_bases(default_return_type.type, (helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME,
helpers.MANYTOMANY_FIELD_FULLNAME)):
return fill_descriptor_types_for_related_field(ctx)
if default_return_type.type.has_base(helpers.ARRAY_FIELD_FULLNAME):
return determine_type_of_array_field(ctx)
return set_descriptor_types_for_field(ctx)
def adjust_return_type_of_field_instantiation(ctx: FunctionContext) -> Type:
record_field_properties_into_outer_model_class(ctx)
return transform_into_proper_return_type(ctx)
def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> None:
api = cast(TypeChecker, ctx.api)
outer_model = api.scope.active_class()
if outer_model is None or not outer_model.has_base(helpers.MODEL_CLASS_FULLNAME):
# outside models.Model class, undetermined
return
field_name = None
for name_expr, stmt in iter_over_assignments(outer_model.defn):
if stmt == ctx.context and isinstance(name_expr, NameExpr):
field_name = name_expr.name
break
if field_name is None:
return
fields_metadata = outer_model.metadata.setdefault('django', {}).setdefault('fields', {})
# primary key
is_primary_key = False
primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key')
if primary_key_arg:
is_primary_key = helpers.parse_bool(primary_key_arg)
fields_metadata[field_name] = {'primary_key': is_primary_key}
# choices
choices_arg = helpers.get_argument_by_name(ctx, 'choices')
if choices_arg and isinstance(choices_arg, (TupleExpr, ListExpr)):
# iterable of 2 element tuples of two kinds
_, analyzed_choices = api.analyze_iterable_item_type(choices_arg)
if isinstance(analyzed_choices, TupleType):
first_element_type = analyzed_choices.items[0]
if isinstance(first_element_type, Instance):
fields_metadata[field_name]['choices'] = first_element_type.type.fullname()
# nullability
null_arg = helpers.get_argument_by_name(ctx, 'null')
is_nullable = False
if null_arg:
is_nullable = helpers.parse_bool(null_arg)
fields_metadata[field_name]['null'] = is_nullable
# is_blankable
blank_arg = helpers.get_argument_by_name(ctx, 'blank')
is_blankable = False
if blank_arg:
is_blankable = helpers.parse_bool(blank_arg)
fields_metadata[field_name]['blank'] = is_blankable

View File

@@ -0,0 +1,166 @@
from typing import Dict, Optional, Set, cast
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo, Var
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance, Type, TypeOfAny
from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import extract_field_setter_type, extract_explicit_set_type_of_model_primary_key, get_fields_metadata
from mypy_django_plugin.transformers.fields import get_private_descriptor_type
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
pointer_args: Set[str] = set()
for base in model.bases:
if base.type.has_base(helpers.MODEL_CLASS_FULLNAME):
parent_name = base.type.name().lower()
pointer_args.add(f'{parent_name}_ptr')
pointer_args.add(f'{parent_name}_ptr_id')
return pointer_args
def redefine_and_typecheck_model_init(ctx: FunctionContext) -> Type:
assert isinstance(ctx.default_return_type, Instance)
api = cast(TypeChecker, ctx.api)
model: TypeInfo = ctx.default_return_type.type
expected_types = extract_expected_types(ctx, model)
# order is preserved, can use for positionals
positional_names = list(expected_types.keys())
positional_names.remove('pk')
visited_positionals = set()
# check positionals
for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
actual_pos_name = positional_names[i]
api.check_subtype(actual_pos_type, expected_types[actual_pos_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_pos_name,
model.name()),
'got', 'expected')
visited_positionals.add(actual_pos_name)
# extract name of base models for _ptr
base_pointer_args = extract_base_pointer_args(model)
# check kwargs
for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])):
if actual_name in base_pointer_args:
# parent_ptr args are not supported
continue
if actual_name in visited_positionals:
continue
if actual_name is None:
# unpacked dict as kwargs is not supported
continue
if actual_name not in expected_types:
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
model.name()),
ctx.context)
continue
api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model.name()),
'got', 'expected')
return ctx.default_return_type
def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type:
api = cast(TypeChecker, ctx.api)
if isinstance(ctx.type, Instance) and len(ctx.type.args) > 0:
model_generic_arg = ctx.type.args[0]
else:
model_generic_arg = ctx.default_return_type
if isinstance(model_generic_arg, AnyType):
return ctx.default_return_type
model: TypeInfo = model_generic_arg.type
# extract name of base models for _ptr
base_pointer_args = extract_base_pointer_args(model)
expected_types = extract_expected_types(ctx, model)
for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]):
if actual_name in base_pointer_args:
# parent_ptr args are not supported
continue
if actual_name is None:
# unpacked dict as kwargs is not supported
continue
if actual_name not in expected_types:
api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
model.name()),
ctx.context)
continue
api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model.name()),
'got', 'expected')
return ctx.default_return_type
def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
field_metadata = get_fields_metadata(model).get(field_name, {})
if 'choices' in field_metadata:
return field_metadata['choices']
return None
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
api = cast(TypeChecker, ctx.api)
expected_types: Dict[str, Type] = {}
primary_key_type = extract_explicit_set_type_of_model_primary_key(model)
if not primary_key_type:
# no explicit primary key, set pk to Any and add id
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
expected_types['pk'] = primary_key_type
for base in model.mro:
# extract all fields for all models in MRO
for name, sym in base.names.items():
# do not redefine special attrs
if name in {'_meta', 'pk'}:
continue
if isinstance(sym.node, Var):
typ = sym.node.type
if typ is None or isinstance(typ, AnyType):
# types are not ready, fallback to Any
expected_types[name] = AnyType(TypeOfAny.from_unimported_type)
expected_types[name + '_id'] = AnyType(TypeOfAny.from_unimported_type)
elif isinstance(typ, Instance):
field_type = extract_field_setter_type(typ)
if field_type is None:
continue
if typ.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
primary_key_type = AnyType(TypeOfAny.implementation_artifact)
# in case it's optional, we need Instance type
referred_to_model = typ.args[1]
is_nullable = helpers.is_optional(referred_to_model)
if is_nullable:
referred_to_model = helpers.make_required(typ.args[1])
if isinstance(referred_to_model, Instance) and referred_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME):
pk_type = extract_explicit_set_type_of_model_primary_key(referred_to_model.type)
if not pk_type:
# extract set type of AutoField
autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField')
pk_type = get_private_descriptor_type(autofield_info, '_pyi_private_set_type',
is_nullable=is_nullable)
primary_key_type = pk_type
expected_types[name + '_id'] = primary_key_type
if field_type:
expected_types[name] = field_type
return expected_types

View File

@@ -0,0 +1,43 @@
from typing import Optional, cast
from mypy.checker import TypeChecker
from mypy.nodes import Expression, StrExpr, TypeInfo
from mypy.plugin import MethodContext
from mypy.types import Instance, Type, TypeType
from mypy_django_plugin import helpers
def get_string_value_from_expr(expr: Expression) -> Optional[str]:
if isinstance(expr, StrExpr):
return expr.value
# TODO: somehow figure out other cases
return None
def determine_model_cls_from_string_for_migrations(ctx: MethodContext) -> Type:
app_label_expr = ctx.args[ctx.callee_arg_names.index('app_label')][0]
app_label = get_string_value_from_expr(app_label_expr)
if app_label is None:
return ctx.default_return_type
if 'model_name' not in ctx.callee_arg_names:
return ctx.default_return_type
model_name_expr_tuple = ctx.args[ctx.callee_arg_names.index('model_name')]
if not model_name_expr_tuple:
return ctx.default_return_type
model_name = get_string_value_from_expr(model_name_expr_tuple[0])
if model_name is None:
return ctx.default_return_type
api = cast(TypeChecker, ctx.api)
model_fullname = helpers.get_model_fullname(app_label, model_name, all_modules=api.modules)
if model_fullname is None:
return ctx.default_return_type
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
all_modules=api.modules)
if model_info is None or not isinstance(model_info, TypeInfo):
return ctx.default_return_type
return TypeType(Instance(model_info, []))

View File

@@ -0,0 +1,265 @@
from abc import ABCMeta, abstractmethod
from typing import Dict, Iterator, List, Optional, Tuple, cast
import dataclasses
from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, CallExpr, ClassDef, Context, Expression, IndexExpr, \
Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext
from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny
from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import iter_over_assignments
@dataclasses.dataclass
class ModelClassInitializer(metaclass=ABCMeta):
api: SemanticAnalyzerPass2
model_classdef: ClassDef
@classmethod
def from_ctx(cls, ctx: ClassDefContext):
return cls(api=cast(SemanticAnalyzerPass2, ctx.api), model_classdef=ctx.cls)
def get_nested_meta_node(self) -> Optional[TypeInfo]:
metaclass_sym = self.model_classdef.info.names.get('Meta')
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
return metaclass_sym.node
return None
def get_meta_attribute(self, name: str) -> Optional[Expression]:
meta_node = self.get_nested_meta_node()
if meta_node is None:
return None
for lvalue, rvalue in iter_over_assignments(meta_node.defn):
if isinstance(lvalue, NameExpr) and lvalue.name == name:
return rvalue
def is_abstract_model(self) -> bool:
is_abstract_expr = self.get_meta_attribute('abstract')
if is_abstract_expr is None:
return False
return self.api.parse_bool(is_abstract_expr)
def add_new_node_to_model_class(self, name: str, typ: Instance) -> None:
var = Var(name=name, type=typ)
var.info = typ.type
var._fullname = self.model_classdef.info.fullname() + '.' + name
var.is_inferred = True
var.is_initialized_in_class = True
self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var)
@abstractmethod
def run(self) -> None:
raise NotImplementedError()
def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]:
for lvalue, rvalue in iter_over_assignments(klass):
if isinstance(rvalue, CallExpr):
yield lvalue, rvalue
def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExpr, CallExpr]]:
for lvalue, rvalue in iter_call_assignments(klass):
if (isinstance(lvalue, NameExpr)
and isinstance(rvalue.callee, MemberExpr)):
if rvalue.callee.fullname in {helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME}:
yield lvalue, rvalue
class SetIdAttrsForRelatedFields(ModelClassInitializer):
def run(self) -> None:
for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef):
node_name = lvalue.name + '_id'
self.add_new_node_to_model_class(name=node_name,
typ=self.api.builtin_type('builtins.int'))
class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):
def run(self) -> None:
meta_node = self.get_nested_meta_node()
if meta_node is None:
return None
meta_node.fallback_to_any = True
def get_model_argument(manager_info: TypeInfo) -> Optional[Instance]:
for base in manager_info.bases:
if base.args:
model_arg = base.args[0]
if isinstance(model_arg, Instance) and model_arg.type.has_base(helpers.MODEL_CLASS_FULLNAME):
return model_arg
return None
class AddDefaultObjectsManager(ModelClassInitializer):
def add_new_manager(self, name: str, manager_type: Optional[Instance]) -> None:
if manager_type is None:
return None
self.add_new_node_to_model_class(name, manager_type)
def add_private_default_manager(self, manager_type: Optional[Instance]) -> None:
if manager_type is None:
return None
self.add_new_node_to_model_class('_default_manager', manager_type)
def get_existing_managers(self) -> List[Tuple[str, TypeInfo]]:
managers = []
for base in self.model_classdef.info.mro:
for name_expr, member_expr in iter_call_assignments(base.defn):
manager_name = name_expr.name
callee_expr = member_expr.callee
if isinstance(callee_expr, IndexExpr):
callee_expr = callee_expr.analyzed.expr
if isinstance(callee_expr, (MemberExpr, NameExpr)) \
and isinstance(callee_expr.node, TypeInfo) \
and callee_expr.node.has_base(helpers.BASE_MANAGER_CLASS_FULLNAME):
managers.append((manager_name, callee_expr.node))
return managers
def run(self) -> None:
existing_managers = self.get_existing_managers()
if existing_managers:
first_manager_type = None
for manager_name, manager_type_info in existing_managers:
manager_type = Instance(manager_type_info, args=[Instance(self.model_classdef.info, [])])
self.add_new_manager(name=manager_name, manager_type=manager_type)
if first_manager_type is None:
first_manager_type = manager_type
else:
if self.is_abstract_model():
# abstract models do not need 'objects' queryset
return None
first_manager_type = self.api.named_type_or_none(helpers.MANAGER_CLASS_FULLNAME,
args=[Instance(self.model_classdef.info, [])])
self.add_new_manager('objects', manager_type=first_manager_type)
if self.is_abstract_model():
return None
default_manager_name_expr = self.get_meta_attribute('default_manager_name')
if isinstance(default_manager_name_expr, StrExpr):
self.add_private_default_manager(self.model_classdef.info.get(default_manager_name_expr.value).type)
else:
self.add_private_default_manager(first_manager_type)
class AddIdAttributeIfPrimaryKeyTrueIsNotSet(ModelClassInitializer):
def run(self) -> None:
if self.is_abstract_model():
# no need for .id attr
return None
for _, rvalue in iter_call_assignments(self.model_classdef):
if ('primary_key' in rvalue.arg_names
and self.api.parse_bool(rvalue.args[rvalue.arg_names.index('primary_key')])):
break
else:
self.add_new_node_to_model_class('id', self.api.builtin_type('builtins.object'))
class AddRelatedManagers(ModelClassInitializer):
def run(self) -> None:
for module_name, module_file in self.api.modules.items():
for defn in iter_over_classdefs(module_file):
for lvalue, rvalue in iter_call_assignments(defn):
if is_related_field(rvalue, module_file):
try:
ref_to_fullname = extract_ref_to_fullname(rvalue,
module_file=module_file,
all_modules=self.api.modules)
except helpers.SelfReference:
ref_to_fullname = defn.fullname
except helpers.InvalidModelString as exc:
self.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}',
Context(line=rvalue.line))
return None
if self.model_classdef.fullname == ref_to_fullname:
related_manager_name = defn.name.lower() + '_set'
if 'related_name' in rvalue.arg_names:
related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')]
if not isinstance(related_name_expr, StrExpr):
return None
related_manager_name = related_name_expr.value
typ = get_related_field_type(rvalue, self.api, defn.info)
if typ is None:
return None
self.add_new_node_to_model_class(related_manager_name, typ)
def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]:
for defn in module_file.defs:
if isinstance(defn, ClassDef):
yield defn
def get_related_field_type(rvalue: CallExpr, api: SemanticAnalyzerPass2,
related_model_typ: TypeInfo) -> Optional[Instance]:
if rvalue.callee.name in {'ForeignKey', 'ManyToManyField'}:
return api.named_type_or_none(helpers.RELATED_MANAGER_CLASS_FULLNAME,
args=[Instance(related_model_typ, [])])
else:
return Instance(related_model_typ, [])
def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool:
if isinstance(expr.callee, MemberExpr) and isinstance(expr.callee.expr, NameExpr):
module = module_file.names.get(expr.callee.expr.name)
if module \
and module.fullname == 'django.db.models' \
and expr.callee.name in {'ForeignKey',
'OneToOneField',
'ManyToManyField'}:
return True
return False
def extract_ref_to_fullname(rvalue_expr: CallExpr,
module_file: MypyFile, all_modules: Dict[str, MypyFile]) -> Optional[str]:
if 'to' in rvalue_expr.arg_names:
to_expr = rvalue_expr.args[rvalue_expr.arg_names.index('to')]
else:
to_expr = rvalue_expr.args[0]
if isinstance(to_expr, NameExpr):
return module_file.names[to_expr.name].fullname
elif isinstance(to_expr, StrExpr):
typ_fullname = helpers.get_model_fullname_from_string(to_expr.value, all_modules)
if typ_fullname is None:
return None
return typ_fullname
return None
def add_dummy_init_method(ctx: ClassDefContext) -> None:
any = AnyType(TypeOfAny.special_form)
pos_arg = Argument(variable=Var('args', any),
type_annotation=any, initializer=None, kind=ARG_STAR)
kw_arg = Argument(variable=Var('kwargs', any),
type_annotation=any, initializer=None, kind=ARG_STAR2)
add_method(ctx, '__init__', [pos_arg, kw_arg], NoneTyp())
# mark as model class
ctx.cls.info.metadata.setdefault('django', {})['generated_init'] = True
def process_model_class(ctx: ClassDefContext) -> None:
initializers = [
InjectAnyAsBaseForNestedMeta,
AddDefaultObjectsManager,
AddIdAttributeIfPrimaryKeyTrueIsNotSet,
SetIdAttrsForRelatedFields,
AddRelatedManagers,
]
for initializer_cls in initializers:
initializer_cls.from_ctx(ctx).run()
add_dummy_init_method(ctx)
# allow unspecified attributes for now
ctx.cls.info.fallback_to_any = True

View File

@@ -0,0 +1,94 @@
from typing import Iterable, List, Optional, cast
from mypy.nodes import ClassDef, Context, ImportAll, MypyFile, SymbolNode, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import AnyType, Instance, NoneTyp, Type, TypeOfAny, UnionType
from mypy.util import correct_relative_import
def get_error_context(node: SymbolNode) -> Context:
context = Context()
context.set_line(node)
return context
def filter_out_nones(typ: UnionType) -> List[Type]:
return [item for item in typ.items if not isinstance(item, NoneTyp)]
def make_sym_copy_of_setting(sym: SymbolTableNode) -> Optional[SymbolTableNode]:
if isinstance(sym.type, Instance):
copied = sym.copy()
copied.node.info = sym.type.type
return copied
elif isinstance(sym.type, UnionType):
instances = filter_out_nones(sym.type)
if len(instances) > 1:
# plain unions not supported yet
return None
typ = instances[0]
if isinstance(typ, Instance):
copied = sym.copy()
copied.node.info = typ.type
return copied
return None
else:
return None
def get_settings_metadata(lazy_settings_info: TypeInfo):
return lazy_settings_info.metadata.setdefault('django', {}).setdefault('settings', {})
def load_settings_from_names(settings_classdef: ClassDef,
modules: Iterable[MypyFile],
api: SemanticAnalyzerPass2) -> None:
settings_metadata = get_settings_metadata(settings_classdef.info)
for module in modules:
for name, sym in module.names.items():
if name.isupper() and isinstance(sym.node, Var):
if sym.type is not None:
copied = make_sym_copy_of_setting(sym)
if copied is None:
continue
settings_classdef.info.names[name] = copied
else:
var = Var(name, AnyType(TypeOfAny.unannotated))
var.info = api.named_type('__builtins__.object').type
settings_classdef.info.names[name] = SymbolTableNode(sym.kind, var)
settings_metadata[name] = module.fullname()
def get_import_star_modules(api: SemanticAnalyzerPass2, module: MypyFile) -> List[str]:
import_star_modules = []
for module_import in module.imports:
# relative import * are not resolved by mypy
if isinstance(module_import, ImportAll) and module_import.relative:
absolute_import_path, correct = correct_relative_import(module.fullname(), module_import.relative, module_import.id,
is_cur_package_init_file=False)
if not correct:
return []
for path in [absolute_import_path] + get_import_star_modules(api, module=api.modules.get(absolute_import_path)):
if path not in import_star_modules:
import_star_modules.append(path)
return import_star_modules
class AddSettingValuesToDjangoConfObject:
def __init__(self, settings_modules: List[str], ignore_missing_settings: bool):
self.settings_modules = settings_modules
self.ignore_missing_settings = ignore_missing_settings
def __call__(self, ctx: ClassDefContext) -> None:
api = cast(SemanticAnalyzerPass2, ctx.api)
for module_name in self.settings_modules:
module = api.modules[module_name]
star_deps = [api.modules[star_dep]
for star_dep in reversed(get_import_star_modules(api, module))]
load_settings_from_names(ctx.cls, modules=star_deps + [module], api=api)
if self.ignore_missing_settings:
ctx.cls.info.fallback_to_any = True