split helpers into smaller files

This commit is contained in:
Maxim Kurnikov
2019-07-12 15:09:51 +03:00
parent a9c1bcbbc6
commit 9c5a6be9a7
18 changed files with 199 additions and 200 deletions

View File

View File

@@ -0,0 +1,29 @@
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
FIELD_FULLNAME = 'django.db.models.fields.Field'
CHAR_FIELD_FULLNAME = 'django.db.models.fields.CharField'
ARRAY_FIELD_FULLNAME = 'django.contrib.postgres.fields.array.ArrayField'
AUTO_FIELD_FULLNAME = 'django.db.models.fields.AutoField'
GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey'
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField'
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
BASE_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.BaseManager'
MANAGER_CLASS_FULLNAME = 'django.db.models.manager.Manager'
RELATED_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.RelatedManager'
BASEFORM_CLASS_FULLNAME = 'django.forms.forms.BaseForm'
FORM_CLASS_FULLNAME = 'django.forms.forms.Form'
MODELFORM_CLASS_FULLNAME = 'django.forms.models.ModelForm'
FORM_MIXIN_CLASS_FULLNAME = 'django.views.generic.edit.FormMixin'
MANAGER_CLASSES = {
MANAGER_CLASS_FULLNAME,
RELATED_MANAGER_CLASS_FULLNAME,
BASE_MANAGER_CLASS_FULLNAME,
QUERYSET_CLASS_FULLNAME
}

View File

@@ -0,0 +1,396 @@
import typing
from collections import OrderedDict
from typing import Dict, Optional, cast
from mypy.mro import calculate_mro
from mypy.nodes import (
GDEF, MDEF, AssignmentStmt, Block, CallExpr, ClassDef, Expression, ImportedName, Lvalue, MypyFile, NameExpr,
SymbolNode, SymbolTable, SymbolTableNode, TypeInfo, Var,
)
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext
from mypy.types import (
AnyType, Instance, NoneTyp, TupleType, Type, TypedDictType, TypeOfAny, TypeVarType, UnionType,
)
from mypy_django_plugin.lib import metadata, fullnames
if typing.TYPE_CHECKING:
from mypy.checker import TypeChecker
def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]:
models_module = '.'.join([app_name, 'models'])
return all_modules.get(models_module)
def get_model_fullname(app_name: str, model_name: str,
all_modules: Dict[str, MypyFile]) -> Optional[str]:
models_file = get_models_file(app_name, all_modules)
if models_file is None:
# not imported so far, not supported
return None
sym = models_file.names.get(model_name)
if not sym:
return None
if isinstance(sym.node, TypeInfo):
return sym.node.fullname()
elif isinstance(sym.node, ImportedName):
return sym.node.target_fullname
else:
return None
class SameFileModel(Exception):
def __init__(self, model_cls_name: str):
self.model_cls_name = model_cls_name
class SelfReference(ValueError):
pass
def get_model_fullname_from_string(model_string: str,
all_modules: Dict[str, MypyFile]) -> Optional[str]:
if model_string == 'self':
raise SelfReference()
if '.' not in model_string:
raise SameFileModel(model_string)
app_name, model_name = model_string.split('.')
return get_model_fullname(app_name, model_name, all_modules)
def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]:
if '.' not in name:
return None
module, cls_name = name.rsplit('.', 1)
module_file = all_modules.get(module)
if module_file is None:
return None
sym = module_file.names.get(cls_name)
if sym is None:
return None
return sym.node
def parse_bool(expr: Expression) -> Optional[bool]:
if isinstance(expr, NameExpr):
if expr.fullname == 'builtins.True':
return True
if expr.fullname == 'builtins.False':
return False
return None
def reparametrize_instance(instance: Instance, new_args: typing.List[Type]) -> Instance:
return Instance(instance.type, args=new_args,
line=instance.line, column=instance.column)
def fill_typevars_with_any(instance: Instance) -> Instance:
return reparametrize_instance(instance, [AnyType(TypeOfAny.unannotated)])
def extract_typevar_value(tp: Instance, typevar_name: str) -> Type:
if typevar_name in {'_T', '_T_co'}:
if '_T' in tp.type.type_vars:
return tp.args[tp.type.type_vars.index('_T')]
if '_T_co' in tp.type.type_vars:
return tp.args[tp.type.type_vars.index('_T_co')]
return tp.args[tp.type.type_vars.index(typevar_name)]
def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance:
typevar_values: typing.List[Type] = []
for typevar_arg in type_to_fill.args:
if isinstance(typevar_arg, TypeVarType):
typevar_values.append(extract_typevar_value(tp, typevar_arg.name))
return Instance(type_to_fill.type, typevar_values)
def get_argument_by_name(ctx: typing.Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]:
"""Return the expression for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
args = ctx.args[idx]
if len(args) != 1:
# Either an error or no value passed.
return None
return args[0]
def get_argument_type_by_name(ctx: typing.Union[FunctionContext, MethodContext], name: str) -> Optional[Type]:
"""Return the type for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
arg_types = ctx.arg_types[idx]
if len(arg_types) != 1:
# Either an error or no value passed.
return None
return arg_types[0]
def get_setting_expr(api: 'TypeChecker', setting_name: str) -> Optional[Expression]:
try:
settings_sym = api.modules['django.conf'].names['settings']
except KeyError:
return None
settings_type: TypeInfo = settings_sym.type.type
auth_user_model_sym = settings_type.get(setting_name)
if not auth_user_model_sym:
return None
module, _, name = auth_user_model_sym.fullname.rpartition('.')
if module not in api.modules:
return None
module_file = api.modules.get(module)
for name_expr, value_expr in iter_over_assignments(module_file):
if isinstance(name_expr, NameExpr) and name_expr.name == setting_name:
return value_expr
return None
def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]
) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
if isinstance(class_or_module, ClassDef):
statements = class_or_module.defs.body
else:
statements = class_or_module.defs
for stmt in statements:
if not isinstance(stmt, AssignmentStmt):
continue
if len(stmt.lvalues) > 1:
# not supported yet
continue
yield stmt.lvalues[0], stmt.rvalue
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
""" Extract __set__ value of a field. """
if tp.type.has_base(fullnames.FIELD_FULLNAME):
return tp.args[0]
# GenericForeignKey
if tp.type.has_base(fullnames.GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
def extract_field_getter_type(tp: Type) -> Optional[Type]:
""" Extract return type of __get__ of subclass of Field"""
if not isinstance(tp, Instance):
return None
if tp.type.has_base(fullnames.FIELD_FULLNAME):
return tp.args[1]
# GenericForeignKey
if tp.type.has_base(fullnames.GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[Type]:
"""
If field with primary_key=True is set on the model, extract its __set__ type.
"""
for field_name, props in metadata.get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
return extract_field_setter_type(model.names[field_name].type)
return None
def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]:
for field_name, props in metadata.get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
return extract_field_getter_type(model.names[field_name].type)
return None
def make_optional(typ: Type):
return UnionType.make_union([typ, NoneTyp()])
def make_required(typ: Type) -> Type:
if not isinstance(typ, UnionType):
return typ
items = [item for item in typ.items if not isinstance(item, NoneTyp)]
# will reduce to Instance, if only one item
return UnionType.make_union(items)
def is_optional(typ: Type) -> bool:
if not isinstance(typ, UnionType):
return False
return any([isinstance(item, NoneTyp) for item in typ.items])
def has_any_of_bases(info: TypeInfo, bases: typing.Sequence[str]) -> bool:
for base_fullname in bases:
if info.has_base(base_fullname):
return True
return False
def is_none_expr(expr: Expression) -> bool:
return isinstance(expr, NameExpr) and expr.fullname == 'builtins.None'
def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]:
metaclass_sym = info.names.get('Meta')
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
return metaclass_sym.node
return None
def get_assigned_value_for_class(type_info: TypeInfo, name: str) -> Optional[Expression]:
for lvalue, rvalue in iter_over_assignments(type_info.defn):
if isinstance(lvalue, NameExpr) and lvalue.name == name:
return rvalue
return None
def is_field_nullable(model: TypeInfo, field_name: str) -> bool:
return metadata.get_fields_metadata(model).get(field_name, {}).get('null', False)
def is_foreign_key_like(t: Type) -> bool:
if not isinstance(t, Instance):
return False
return has_any_of_bases(t.type, (fullnames.FOREIGN_KEY_FULLNAME, fullnames.ONETOONE_FIELD_FULLNAME))
def build_class_with_annotated_fields(api: 'TypeChecker', base: Type, fields: 'OrderedDict[str, Type]',
name: str) -> Instance:
"""Build an Instance with `name` that contains the specified `fields` as attributes and extends `base`."""
# Credit: This code is largely copied/modified from TypeChecker.intersect_instance_callable and
# NamedTupleAnalyzer.build_namedtuple_typeinfo
from mypy.checker import gen_unique_name
cur_module = cast(MypyFile, api.scope.stack[0])
gen_name = gen_unique_name(name, cur_module.names)
cdef = ClassDef(name, Block([]))
cdef.fullname = cur_module.fullname() + '.' + gen_name
info = TypeInfo(SymbolTable(), cdef, cur_module.fullname())
cdef.info = info
info.bases = [base]
def add_field(var: Var, is_initialized_in_class: bool = False,
is_property: bool = False) -> None:
var.info = info
var.is_initialized_in_class = is_initialized_in_class
var.is_property = is_property
var._fullname = '%s.%s' % (info.fullname(), var.name())
info.names[var.name()] = SymbolTableNode(MDEF, var)
vars = [Var(item, typ) for item, typ in fields.items()]
for var in vars:
add_field(var, is_property=True)
calculate_mro(info)
info.calculate_metaclass_type()
cur_module.names[gen_name] = SymbolTableNode(GDEF, info, plugin_generated=True)
return Instance(info, [])
def make_named_tuple(api: 'TypeChecker', fields: 'OrderedDict[str, Type]', name: str) -> Type:
if not fields:
# No fields specified, so fallback to a subclass of NamedTuple that allows
# __getattr__ / __setattr__ for any attribute name.
fallback = api.named_generic_type('django._NamedTupleAnyAttr', [])
else:
fallback = build_class_with_annotated_fields(
api=api,
base=api.named_generic_type('typing.NamedTuple', []),
fields=fields,
name=name
)
return TupleType(list(fields.values()), fallback=fallback)
def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, Type]',
required_keys: typing.Set[str]) -> TypedDictType:
object_type = api.named_generic_type('mypy_extensions._TypedDict', [])
typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type)
return typed_dict_type
def make_tuple(api: 'TypeChecker', fields: typing.List[Type]) -> TupleType:
implicit_any = AnyType(TypeOfAny.special_form)
fallback = api.named_generic_type('builtins.tuple', [implicit_any])
return TupleType(fields, fallback=fallback)
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> Type:
node = type_info.get(private_field_name).node
if isinstance(node, Var):
descriptor_type = node.type
if is_nullable:
descriptor_type = make_optional(descriptor_type)
return descriptor_type
return AnyType(TypeOfAny.unannotated)
def iter_over_classdefs(module_file: MypyFile) -> typing.Iterator[ClassDef]:
for defn in module_file.defs:
if isinstance(defn, ClassDef):
yield defn
def iter_call_assignments(klass: ClassDef) -> typing.Iterator[typing.Tuple[Lvalue, CallExpr]]:
for lvalue, rvalue in iter_over_assignments(klass):
if isinstance(rvalue, CallExpr):
yield lvalue, rvalue
def get_related_manager_type_from_metadata(model_info: TypeInfo, related_manager_name: str,
api: CheckerPluginInterface) -> Optional[Instance]:
related_manager_metadata = metadata.get_related_managers_metadata(model_info)
if not related_manager_metadata:
return None
if related_manager_name not in related_manager_metadata:
return None
manager_class_name = related_manager_metadata[related_manager_name]['manager']
of = related_manager_metadata[related_manager_name]['of']
of_types = []
for of_type_name in of:
if of_type_name == 'any':
of_types.append(AnyType(TypeOfAny.implementation_artifact))
else:
try:
of_type = api.named_generic_type(of_type_name, [])
except AssertionError:
# Internal error: attempted lookup of unknown name
of_type = AnyType(TypeOfAny.implementation_artifact)
of_types.append(of_type)
return api.named_generic_type(manager_class_name, of_types)
def get_primary_key_field_name(model_info: TypeInfo) -> Optional[str]:
for base in model_info.mro:
fields = metadata.get_fields_metadata(base)
for field_name, field_props in fields.items():
is_primary_key = field_props.get('primary_key', False)
if is_primary_key:
return field_name
return None

View File

@@ -0,0 +1,159 @@
from typing import List, Union
import dataclasses
from mypy.nodes import TypeInfo
from mypy.plugin import CheckerPluginInterface
from mypy.types import Instance, Type
from mypy_django_plugin.lib import metadata, helpers
@dataclasses.dataclass
class RelatedModelNode:
typ: Instance
is_nullable: bool
@dataclasses.dataclass
class FieldNode:
typ: Type
LookupNode = Union[RelatedModelNode, FieldNode]
class LookupException(Exception):
pass
def resolve_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, lookup: str) -> List[LookupNode]:
"""Resolve a lookup str to a list of LookupNodes.
Each node represents a part of the lookup (separated by "__"), in order.
Each node is the Model or Field that was resolved.
Raises LookupException if there were any issues resolving the lookup.
"""
lookup_parts = lookup.split("__")
nodes = []
while lookup_parts:
lookup_part = lookup_parts.pop(0)
if not nodes:
current_node = None
else:
current_node = nodes[-1]
if current_node is None:
new_node = resolve_model_lookup(api, model_type_info, lookup_part)
elif isinstance(current_node, RelatedModelNode):
new_node = resolve_model_lookup(api, current_node.typ.type, lookup_part)
elif isinstance(current_node, FieldNode):
raise LookupException(f"Field lookups not yet supported for lookup {lookup}")
else:
raise LookupException(f"Unsupported node type: {type(current_node)}")
nodes.append(new_node)
return nodes
def resolve_model_pk_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo) -> LookupNode:
# Primary keys are special-cased
primary_key_type = helpers.extract_primary_key_type_for_get(model_type_info)
if primary_key_type:
return FieldNode(primary_key_type)
else:
# No PK, use the get type for AutoField as PK type.
autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField')
pk_type = helpers.get_private_descriptor_type(autofield_info, '_pyi_private_get_type',
is_nullable=False)
return FieldNode(pk_type)
def resolve_model_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo,
lookup: str) -> LookupNode:
"""Resolve a lookup on the given model."""
if lookup == 'pk':
return resolve_model_pk_lookup(api, model_type_info)
field_name = get_actual_field_name_for_lookup_field(lookup, model_type_info)
field_node = model_type_info.get(field_name)
if not field_node:
raise LookupException(
f'When resolving lookup "{lookup}", field "{field_name}" was not found in model {model_type_info.name()}')
if field_name.endswith('_id'):
field_name_without_id = field_name.rstrip('_id')
foreign_key_field = model_type_info.get(field_name_without_id)
if foreign_key_field is not None and helpers.is_foreign_key_like(foreign_key_field.type):
# Hack: If field ends with '_id' and there is a model field without the '_id' suffix, then use that field.
field_node = foreign_key_field
field_name = field_name_without_id
field_node_type = field_node.type
if field_node_type is None or not isinstance(field_node_type, Instance):
raise LookupException(
f'When resolving lookup "{lookup}", could not determine type for {model_type_info.name()}.{field_name}')
if field_node_type.type.fullname() == 'builtins.object':
# could be related manager
related_manager_type = helpers.get_related_manager_type_from_metadata(model_type_info, field_name, api)
if related_manager_type:
model_arg = related_manager_type.args[0]
if not isinstance(model_arg, Instance):
raise LookupException(
f'When resolving lookup "{lookup}", could not determine type '
f'for {model_type_info.name()}.{field_name}')
return RelatedModelNode(typ=model_arg, is_nullable=False)
if helpers.is_foreign_key_like(field_node_type):
field_type = helpers.extract_field_getter_type(field_node_type)
is_nullable = helpers.is_optional(field_type)
if is_nullable:
# type is always non-optional
field_type = helpers.make_required(field_type)
if isinstance(field_type, Instance):
return RelatedModelNode(typ=field_type, is_nullable=is_nullable)
else:
raise LookupException(f"Not an instance for field {field_type} lookup {lookup}")
field_type = helpers.extract_field_getter_type(field_node_type)
if field_type:
return FieldNode(typ=field_type)
# Not a Field
if field_name == 'id':
# If no 'id' field was found, use an int
return FieldNode(api.named_generic_type('builtins.int', []))
raise LookupException(
f'When resolving lookup {lookup!r}, could not determine type for {model_type_info.name()}.{field_name}')
def get_actual_field_name_for_lookup_field(lookup: str, model_type_info: TypeInfo) -> str:
"""Attempt to find out the real field name if this lookup is a related_query_name (for reverse relations).
If it's not, return the original lookup.
"""
lookups_metadata = metadata.get_lookups_metadata(model_type_info)
lookup_metadata = lookups_metadata.get(lookup)
if lookup_metadata is None:
# If not found on current model, look in all bases for their lookup metadata
for base in model_type_info.mro:
lookups_metadata = metadata.get_lookups_metadata(base)
lookup_metadata = lookups_metadata.get(lookup)
if lookup_metadata:
break
if not lookup_metadata:
lookup_metadata = {}
related_name = lookup_metadata.get('related_query_name_target', None)
if related_name:
# If the lookup is a related lookup, then look at the field specified by related_name.
# This is to support if related_query_name is set and differs from.
field_name = related_name
else:
field_name = lookup
return field_name

View File

@@ -0,0 +1,23 @@
from typing import Any, Dict, List
from mypy.nodes import TypeInfo
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
return model_info.metadata.setdefault('django', {})
def get_related_field_primary_key_names(base_model: TypeInfo) -> List[str]:
return get_django_metadata(base_model).setdefault('related_field_primary_keys', [])
def get_fields_metadata(model: TypeInfo) -> Dict[str, Any]:
return get_django_metadata(model).setdefault('fields', {})
def get_lookups_metadata(model: TypeInfo) -> Dict[str, Any]:
return get_django_metadata(model).setdefault('lookups', {})
def get_related_managers_metadata(model: TypeInfo) -> Dict[str, Any]:
return get_django_metadata(model).setdefault('related_managers', {})