diff --git a/mypy_django_plugin/__init__.py b/mypy_django_plugin/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mypy_django_plugin/lib/__init__.py b/mypy_django_plugin/lib/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mypy_django_plugin/lib/config.py b/mypy_django_plugin/lib/config.py deleted file mode 100644 index d632174..0000000 --- a/mypy_django_plugin/lib/config.py +++ /dev/null @@ -1,54 +0,0 @@ -import os -from configparser import ConfigParser -from typing import Dict, List, Optional - -import dataclasses -from dataclasses import dataclass -from pytest_mypy.utils import temp_environ - - -@dataclass -class Config: - django_settings_module: Optional[str] = None - installed_apps: List[str] = dataclasses.field(default_factory=list) - - ignore_missing_settings: bool = False - ignore_missing_model_attributes: bool = False - - @classmethod - def from_config_file(cls, fpath: str) -> 'Config': - ini_config = ConfigParser() - ini_config.read(fpath) - if not ini_config.has_section('mypy_django_plugin'): - raise ValueError('Invalid config file: no [mypy_django_plugin] section') - - django_settings = ini_config.get('mypy_django_plugin', 'django_settings', - fallback=None) - if django_settings: - django_settings = django_settings.strip() - - return Config(django_settings_module=django_settings, - ignore_missing_settings=bool(ini_config.get('mypy_django_plugin', - 'ignore_missing_settings', - fallback=False)), - ignore_missing_model_attributes=bool(ini_config.get('mypy_django_plugin', - 'ignore_missing_model_attributes', - fallback=False))) - - -def extract_app_model_aliases(settings_module: str) -> Dict[str, str]: - with temp_environ(): - os.environ['DJANGO_SETTINGS_MODULE'] = settings_module - import django - django.setup() - - app_model_mapping: Dict[str, str] = {} - - from django.apps import apps - - for name, app_config in apps.app_configs.items(): - app_label = app_config.label - for model_name, model_class in app_config.models.items(): - app_model_mapping[app_label + '.' + model_class.__name__] = model_class.__module__ + '.' + model_class.__name__ - - return app_model_mapping diff --git a/mypy_django_plugin/lib/fullnames.py b/mypy_django_plugin/lib/fullnames.py deleted file mode 100644 index 0821122..0000000 --- a/mypy_django_plugin/lib/fullnames.py +++ /dev/null @@ -1,35 +0,0 @@ - -MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' -FIELD_FULLNAME = 'django.db.models.fields.Field' -CHAR_FIELD_FULLNAME = 'django.db.models.fields.CharField' -ARRAY_FIELD_FULLNAME = 'django.contrib.postgres.fields.array.ArrayField' -AUTO_FIELD_FULLNAME = 'django.db.models.fields.AutoField' -GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey' -FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey' -ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField' -MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField' -DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject' - -QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet' -BASE_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.BaseManager' -MANAGER_CLASS_FULLNAME = 'django.db.models.manager.Manager' -RELATED_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.RelatedManager' - -BASEFORM_CLASS_FULLNAME = 'django.forms.forms.BaseForm' -FORM_CLASS_FULLNAME = 'django.forms.forms.Form' -MODELFORM_CLASS_FULLNAME = 'django.forms.models.ModelForm' - -FORM_MIXIN_CLASS_FULLNAME = 'django.views.generic.edit.FormMixin' - -MANAGER_CLASSES = { - MANAGER_CLASS_FULLNAME, - RELATED_MANAGER_CLASS_FULLNAME, - BASE_MANAGER_CLASS_FULLNAME, - QUERYSET_CLASS_FULLNAME -} - -RELATED_FIELDS_CLASSES = { - FOREIGN_KEY_FULLNAME, - ONETOONE_FIELD_FULLNAME, - MANYTOMANY_FIELD_FULLNAME -} \ No newline at end of file diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py deleted file mode 100644 index e4ba50e..0000000 --- a/mypy_django_plugin/lib/helpers.py +++ /dev/null @@ -1,452 +0,0 @@ -from collections import OrderedDict -from typing import Dict, Iterator, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast - -from mypy.mro import calculate_mro -from mypy.nodes import (AssignmentStmt, Block, CallExpr, ClassDef, Expression, FakeInfo, GDEF, ImportedName, Lvalue, MDEF, - MemberExpr, MypyFile, NameExpr, SymbolNode, SymbolTable, SymbolTableNode, TypeInfo, Var) -from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext -from mypy.types import (AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypeVarType, TypedDictType, - UnionType) - -from mypy_django_plugin.lib import fullnames, metadata - -if TYPE_CHECKING: - from mypy.checker import TypeChecker - - -def get_models_file(app_name: str, all_modules: Dict[str, MypyFile]) -> Optional[MypyFile]: - models_module = '.'.join([app_name, 'models']) - return all_modules.get(models_module) - - -def get_model_fullname(app_name: str, model_name: str, - all_modules: Dict[str, MypyFile]) -> Optional[str]: - models_file = get_models_file(app_name, all_modules) - if models_file is None: - # not imported so far, not supported - return None - sym = models_file.names.get(model_name) - if not sym: - return None - - if isinstance(sym.node, TypeInfo): - return sym.node.fullname() - elif isinstance(sym.node, ImportedName): - return sym.node.target_fullname - else: - return None - - -class SameFileModel(Exception): - def __init__(self, model_cls_name: str): - self.model_cls_name = model_cls_name - - -class SelfReference(ValueError): - pass - - -def get_model_fullname_from_string(model_string: str, - all_modules: Dict[str, MypyFile]) -> Optional[str]: - if model_string == 'self': - raise SelfReference() - - if '.' not in model_string: - raise SameFileModel(model_string) - - app_name, model_name = model_string.split('.') - return get_model_fullname(app_name, model_name, all_modules) - - -def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]: - if '.' not in name: - return None - module, cls_name = name.rsplit('.', 1) - - module_file = all_modules.get(module) - if module_file is None: - return None - sym = module_file.names.get(cls_name) - if sym is None: - return None - return sym.node - - -def parse_bool(expr: Expression) -> Optional[bool]: - if isinstance(expr, NameExpr): - if expr.fullname == 'builtins.True': - return True - if expr.fullname == 'builtins.False': - return False - return None - - -def reparametrize_instance(instance: Instance, new_args: List[MypyType]) -> Instance: - return Instance(instance.type, args=new_args, - line=instance.line, column=instance.column) - - -def fill_typevars_with_any(instance: Instance) -> Instance: - return reparametrize_instance(instance, [AnyType(TypeOfAny.unannotated)]) - - -def extract_typevar_value(tp: Instance, typevar_name: str) -> MypyType: - if typevar_name in {'_T', '_T_co'}: - if '_T' in tp.type.type_vars: - return tp.args[tp.type.type_vars.index('_T')] - if '_T_co' in tp.type.type_vars: - return tp.args[tp.type.type_vars.index('_T_co')] - return tp.args[tp.type.type_vars.index(typevar_name)] - - -def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance: - typevar_values: List[MypyType] = [] - for typevar_arg in type_to_fill.args: - if isinstance(typevar_arg, TypeVarType): - typevar_values.append(extract_typevar_value(tp, typevar_arg.name)) - return Instance(type_to_fill.type, typevar_values) - - -def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]: - """ - Return the expression for the specific argument. - This helper should only be used with non-star arguments. - """ - if name not in ctx.callee_arg_names: - return None - idx = ctx.callee_arg_names.index(name) - args = ctx.args[idx] - if len(args) != 1: - # Either an error or no value passed. - return None - return args[0] - - -def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]: - """Return the type for the specific argument. - - This helper should only be used with non-star arguments. - """ - if name not in ctx.callee_arg_names: - return None - idx = ctx.callee_arg_names.index(name) - arg_types = ctx.arg_types[idx] - if len(arg_types) != 1: - # Either an error or no value passed. - return None - return arg_types[0] - - -def get_setting_expr(api: 'TypeChecker', setting_name: str) -> Optional[Expression]: - try: - settings_sym = api.modules['django.conf'].names['settings'] - except KeyError: - return None - - settings_type: TypeInfo = settings_sym.type.type - auth_user_model_sym = settings_type.get(setting_name) - if not auth_user_model_sym: - return None - - module, _, name = auth_user_model_sym.fullname.rpartition('.') - if module not in api.modules: - return None - - module_file = api.modules.get(module) - for name_expr, value_expr in iter_over_module_level_assignments(module_file): - if isinstance(name_expr, NameExpr) and name_expr.name == setting_name: - return value_expr - return None - - -def iter_over_class_level_assignments(klass: ClassDef) -> Iterator[Tuple[str, Expression]]: - for stmt in klass.defs.body: - if not isinstance(stmt, AssignmentStmt): - continue - if len(stmt.lvalues) > 1: - # skip multiple assignments - continue - lvalue = stmt.lvalues[0] - if isinstance(lvalue, NameExpr): - yield lvalue.name, stmt.rvalue - - -def iter_over_module_level_assignments(module: MypyFile) -> Iterator[Tuple[str, Expression]]: - for stmt in module.defs: - if not isinstance(stmt, AssignmentStmt): - continue - if len(stmt.lvalues) > 1: - # skip multiple assignments - continue - lvalue = stmt.lvalues[0] - if isinstance(lvalue, NameExpr): - yield lvalue.name, stmt.rvalue - - -def iter_over_assignments_in_class(class_or_module: Union[ClassDef, MypyFile] - ) -> Iterator[Tuple[str, Expression]]: - if isinstance(class_or_module, ClassDef): - statements = class_or_module.defs.body - else: - statements = class_or_module.defs - - for stmt in statements: - if not isinstance(stmt, AssignmentStmt): - continue - if len(stmt.lvalues) > 1: - # not supported yet - continue - lvalue = stmt.lvalues[0] - if isinstance(lvalue, NameExpr): - yield lvalue.name, stmt.rvalue - - -def extract_field_setter_type(tp: Instance) -> Optional[MypyType]: - """ Extract __set__ value of a field. """ - if tp.type.has_base(fullnames.FIELD_FULLNAME): - return tp.args[0] - # GenericForeignKey - if tp.type.has_base(fullnames.GENERIC_FOREIGN_KEY_FULLNAME): - return AnyType(TypeOfAny.special_form) - return None - - -def extract_field_getter_type(tp: MypyType) -> Optional[MypyType]: - """ Extract return type of __get__ of subclass of Field""" - if not isinstance(tp, Instance): - return None - if tp.type.has_base(fullnames.FIELD_FULLNAME): - return tp.args[1] - # GenericForeignKey - if tp.type.has_base(fullnames.GENERIC_FOREIGN_KEY_FULLNAME): - return AnyType(TypeOfAny.special_form) - return None - - -def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[MypyType]: - """ - If field with primary_key=True is set on the model, extract its __set__ type. - """ - for field_name, props in metadata.get_fields_metadata(model).items(): - is_primary_key = props.get('primary_key', False) - if is_primary_key: - return extract_field_setter_type(model.names[field_name].type) - return None - - -def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[MypyType]: - for field_name, props in metadata.get_fields_metadata(model).items(): - is_primary_key = props.get('primary_key', False) - if is_primary_key: - return extract_field_getter_type(model.names[field_name].type) - return None - - -def make_optional(typ: MypyType) -> MypyType: - return UnionType.make_union([typ, NoneTyp()]) - - -def make_required(typ: MypyType) -> MypyType: - if not isinstance(typ, UnionType): - return typ - items = [item for item in typ.items if not isinstance(item, NoneTyp)] - # will reduce to Instance, if only one item - return UnionType.make_union(items) - - -def is_optional(typ: MypyType) -> bool: - if not isinstance(typ, UnionType): - return False - - return any([isinstance(item, NoneTyp) for item in typ.items]) - - -def has_any_of_bases(info: TypeInfo, bases: Set[str]) -> bool: - for base_fullname in bases: - if info.has_base(base_fullname): - return True - return False - - -def is_none_expr(expr: Expression) -> bool: - return isinstance(expr, NameExpr) and expr.fullname == 'builtins.None' - - -def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]: - metaclass_sym = info.names.get('Meta') - if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo): - return metaclass_sym.node - return None - - -def get_assignment_stmt_by_name(type_info: TypeInfo, name: str) -> Optional[Expression]: - for assignment_name, call_expr in iter_over_class_level_assignments(type_info.defn): - if assignment_name == name: - return call_expr - return None - - -def is_field_nullable(model: TypeInfo, field_name: str) -> bool: - return metadata.get_fields_metadata(model).get(field_name, {}).get('null', False) - - -def is_foreign_key_like(t: MypyType) -> bool: - if not isinstance(t, Instance): - return False - return has_any_of_bases(t.type, {fullnames.FOREIGN_KEY_FULLNAME, fullnames.ONETOONE_FIELD_FULLNAME}) - - -def build_class_with_annotated_fields(api: 'TypeChecker', base: MypyType, fields: 'OrderedDict[str, MypyType]', - name: str) -> Instance: - """Build an Instance with `name` that contains the specified `fields` as attributes and extends `base`.""" - # Credit: This code is largely copied/modified from TypeChecker.intersect_instance_callable and - # NamedTupleAnalyzer.build_namedtuple_typeinfo - from mypy.checker import gen_unique_name - - cur_module = cast(MypyFile, api.scope.stack[0]) - gen_name = gen_unique_name(name, cur_module.names) - - cdef = ClassDef(name, Block([])) - cdef.fullname = cur_module.fullname() + '.' + gen_name - info = TypeInfo(SymbolTable(), cdef, cur_module.fullname()) - cdef.info = info - info.bases = [base] - - def add_field(var: Var, is_initialized_in_class: bool = False, - is_property: bool = False) -> None: - var.info = info - var.is_initialized_in_class = is_initialized_in_class - var.is_property = is_property - var._fullname = '%s.%s' % (info.fullname(), var.name()) - info.names[var.name()] = SymbolTableNode(MDEF, var) - - vars = [Var(item, typ) for item, typ in fields.items()] - for var in vars: - add_field(var, is_property=True) - - calculate_mro(info) - info.calculate_metaclass_type() - - cur_module.names[gen_name] = SymbolTableNode(GDEF, info, plugin_generated=True) - return Instance(info, []) - - -def make_named_tuple(api: 'TypeChecker', fields: 'OrderedDict[str, MypyType]', name: str) -> MypyType: - if not fields: - # No fields specified, so fallback to a subclass of NamedTuple that allows - # __getattr__ / __setattr__ for any attribute name. - fallback = api.named_generic_type('django._NamedTupleAnyAttr', []) - else: - fallback = build_class_with_annotated_fields( - api=api, - base=api.named_generic_type('NamedTuple', []), - fields=fields, - name=name - ) - return TupleType(list(fields.values()), fallback=fallback) - - -def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyType]', - required_keys: Set[str]) -> TypedDictType: - object_type = api.named_generic_type('mypy_extensions._TypedDict', []) - typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type) - return typed_dict_type - - -def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType: - implicit_any = AnyType(TypeOfAny.special_form) - fallback = api.named_generic_type('builtins.tuple', [implicit_any]) - return TupleType(fields, fallback=fallback) - - -def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType: - node = type_info.get(private_field_name).node - if isinstance(node, Var): - descriptor_type = node.type - if is_nullable: - descriptor_type = make_optional(descriptor_type) - return descriptor_type - return AnyType(TypeOfAny.unannotated) - - -class IncompleteDefnException(Exception): - pass - - -def iter_over_toplevel_classes(module_file: MypyFile) -> Iterator[ClassDef]: - for defn in module_file.defs: - if isinstance(defn, ClassDef): - yield defn - - -def iter_call_assignments_in_class(klass: ClassDef) -> Iterator[Tuple[str, CallExpr]]: - for name, expression in iter_over_assignments_in_class(klass): - if isinstance(expression, CallExpr): - yield name, expression - - -def iter_over_field_inits_in_class(klass: ClassDef) -> Iterator[Tuple[str, CallExpr]]: - for lvalue, rvalue in iter_over_assignments_in_class(klass): - if isinstance(lvalue, NameExpr) and isinstance(rvalue, CallExpr): - field_name = lvalue.name - if isinstance(rvalue.callee, MemberExpr) and isinstance(rvalue.callee.node, TypeInfo): - if isinstance(rvalue.callee.node, FakeInfo): - raise IncompleteDefnException() - - field_info = rvalue.callee.node - if field_info.has_base(fullnames.FIELD_FULLNAME): - yield field_name, rvalue - - -def get_related_manager_type_from_metadata(model_info: TypeInfo, related_manager_name: str, - api: CheckerPluginInterface) -> Optional[Instance]: - related_manager_metadata = metadata.get_related_managers_metadata(model_info) - if not related_manager_metadata: - return None - - if related_manager_name not in related_manager_metadata: - return None - - manager_class_name = related_manager_metadata[related_manager_name]['manager'] - of = related_manager_metadata[related_manager_name]['of'] - of_types = [] - for of_type_name in of: - if of_type_name == 'any': - of_types.append(AnyType(TypeOfAny.implementation_artifact)) - else: - try: - of_type = api.named_generic_type(of_type_name, []) - except AssertionError: - # Internal error: attempted lookup of unknown name - of_type = AnyType(TypeOfAny.implementation_artifact) - - of_types.append(of_type) - - return api.named_generic_type(manager_class_name, of_types) - - -def get_primary_key_field_name(model_info: TypeInfo) -> Optional[str]: - for base in model_info.mro: - fields = metadata.get_fields_metadata(base) - for field_name, field_props in fields.items(): - is_primary_key = field_props.get('primary_key', False) - if is_primary_key: - return field_name - return None - - -def _get_app_models_file(app_name: str, all_modules: Dict[str, MypyFile]) -> Optional[MypyFile]: - models_module = '.'.join([app_name, 'models']) - return all_modules.get(models_module) - - -def get_model_info(app_name_dot_model_name: str, all_modules: Dict[str, MypyFile]) -> Optional[TypeInfo]: - """ Resolve app_name.ModelName into model fullname """ - app_name, model_name = app_name_dot_model_name.split('.') - models_file = _get_app_models_file(app_name, all_modules) - if models_file is None: - return None - - sym = models_file.names.get(model_name) - if sym and isinstance(sym.node, TypeInfo): - return sym.node diff --git a/mypy_django_plugin/lib/lookups.py b/mypy_django_plugin/lib/lookups.py deleted file mode 100644 index 27035ba..0000000 --- a/mypy_django_plugin/lib/lookups.py +++ /dev/null @@ -1,159 +0,0 @@ -from typing import List, Union - -import dataclasses -from mypy.nodes import TypeInfo -from mypy.plugin import CheckerPluginInterface -from mypy.types import Instance, Type - -from mypy_django_plugin.lib import metadata, helpers - - -@dataclasses.dataclass -class RelatedModelNode: - typ: Instance - is_nullable: bool - - -@dataclasses.dataclass -class FieldNode: - typ: Type - - -LookupNode = Union[RelatedModelNode, FieldNode] - - -class LookupException(Exception): - pass - - -def resolve_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, lookup: str) -> List[LookupNode]: - """Resolve a lookup str to a list of LookupNodes. - - Each node represents a part of the lookup (separated by "__"), in order. - Each node is the Model or Field that was resolved. - - Raises LookupException if there were any issues resolving the lookup. - """ - lookup_parts = lookup.split("__") - - nodes = [] - while lookup_parts: - lookup_part = lookup_parts.pop(0) - - if not nodes: - current_node = None - else: - current_node = nodes[-1] - - if current_node is None: - new_node = resolve_model_lookup(api, model_type_info, lookup_part) - elif isinstance(current_node, RelatedModelNode): - new_node = resolve_model_lookup(api, current_node.typ.type, lookup_part) - elif isinstance(current_node, FieldNode): - raise LookupException(f"Field lookups not yet supported for lookup {lookup}") - else: - raise LookupException(f"Unsupported node type: {type(current_node)}") - nodes.append(new_node) - return nodes - - -def resolve_model_pk_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo) -> LookupNode: - # Primary keys are special-cased - primary_key_type = helpers.extract_primary_key_type_for_get(model_type_info) - if primary_key_type: - return FieldNode(primary_key_type) - else: - # No PK, use the get type for AutoField as PK type. - autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField') - pk_type = helpers.get_private_descriptor_type(autofield_info, '_pyi_private_get_type', - is_nullable=False) - return FieldNode(pk_type) - - -def resolve_model_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, - lookup: str) -> LookupNode: - """Resolve a lookup on the given model.""" - if lookup == 'pk': - return resolve_model_pk_lookup(api, model_type_info) - - field_name = get_actual_field_name_for_lookup_field(lookup, model_type_info) - - field_node = model_type_info.get(field_name) - if not field_node: - raise LookupException( - f'When resolving lookup "{lookup}", field "{field_name}" was not found in model {model_type_info.name()}') - - if field_name.endswith('_id'): - field_name_without_id = field_name.rstrip('_id') - foreign_key_field = model_type_info.get(field_name_without_id) - if foreign_key_field is not None and helpers.is_foreign_key_like(foreign_key_field.type): - # Hack: If field ends with '_id' and there is a model field without the '_id' suffix, then use that field. - field_node = foreign_key_field - field_name = field_name_without_id - - field_node_type = field_node.type - if field_node_type is None or not isinstance(field_node_type, Instance): - raise LookupException( - f'When resolving lookup "{lookup}", could not determine type for {model_type_info.name()}.{field_name}') - - if field_node_type.type.fullname() == 'builtins.object': - # could be related manager - related_manager_type = helpers.get_related_manager_type_from_metadata(model_type_info, field_name, api) - if related_manager_type: - model_arg = related_manager_type.args[0] - if not isinstance(model_arg, Instance): - raise LookupException( - f'When resolving lookup "{lookup}", could not determine type ' - f'for {model_type_info.name()}.{field_name}') - - return RelatedModelNode(typ=model_arg, is_nullable=False) - - if helpers.is_foreign_key_like(field_node_type): - field_type = helpers.extract_field_getter_type(field_node_type) - is_nullable = helpers.is_optional(field_type) - if is_nullable: - # type is always non-optional - field_type = helpers.make_required(field_type) - - if isinstance(field_type, Instance): - return RelatedModelNode(typ=field_type, is_nullable=is_nullable) - else: - raise LookupException(f"Not an instance for field {field_type} lookup {lookup}") - - field_type = helpers.extract_field_getter_type(field_node_type) - if field_type: - return FieldNode(typ=field_type) - - # Not a Field - if field_name == 'id': - # If no 'id' field was found, use an int - return FieldNode(api.named_generic_type('builtins.int', [])) - - raise LookupException( - f'When resolving lookup {lookup!r}, could not determine type for {model_type_info.name()}.{field_name}') - - -def get_actual_field_name_for_lookup_field(lookup: str, model_type_info: TypeInfo) -> str: - """Attempt to find out the real field name if this lookup is a related_query_name (for reverse relations). - - If it's not, return the original lookup. - """ - lookups_metadata = metadata.get_lookups_metadata(model_type_info) - lookup_metadata = lookups_metadata.get(lookup) - if lookup_metadata is None: - # If not found on current model, look in all bases for their lookup metadata - for base in model_type_info.mro: - lookups_metadata = metadata.get_lookups_metadata(base) - lookup_metadata = lookups_metadata.get(lookup) - if lookup_metadata: - break - if not lookup_metadata: - lookup_metadata = {} - related_name = lookup_metadata.get('related_query_name_target', None) - if related_name: - # If the lookup is a related lookup, then look at the field specified by related_name. - # This is to support if related_query_name is set and differs from. - field_name = related_name - else: - field_name = lookup - return field_name diff --git a/mypy_django_plugin/lib/metadata.py b/mypy_django_plugin/lib/metadata.py deleted file mode 100644 index 670f769..0000000 --- a/mypy_django_plugin/lib/metadata.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Any, Dict, List - -from mypy.nodes import TypeInfo - - -def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]: - return model_info.metadata.setdefault('django', {}) - - -def get_related_field_primary_key_names(base_model: TypeInfo) -> List[str]: - return get_django_metadata(base_model).setdefault('related_field_primary_keys', []) - - -def get_fields_metadata(model: TypeInfo) -> Dict[str, Any]: - return get_django_metadata(model).setdefault('fields', {}) - - -def get_lookups_metadata(model: TypeInfo) -> Dict[str, Any]: - return get_django_metadata(model).setdefault('lookups', {}) - - -def get_related_managers_metadata(model: TypeInfo) -> Dict[str, Any]: - return get_django_metadata(model).setdefault('related_managers', {}) diff --git a/mypy_django_plugin/lib/tests/__init__.py b/mypy_django_plugin/lib/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mypy_django_plugin/lib/tests/sample_django_project/manage.py b/mypy_django_plugin/lib/tests/sample_django_project/manage.py deleted file mode 100755 index 4d2075b..0000000 --- a/mypy_django_plugin/lib/tests/sample_django_project/manage.py +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env python -"""Django's command-line utility for administrative tasks.""" -import os -import sys - - -def main(): - os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'sample_django_project.settings') - try: - from django.core.management import execute_from_command_line - except ImportError as exc: - raise ImportError( - "Couldn't import Django. Are you sure it's installed and " - "available on your PYTHONPATH environment variable? Did you " - "forget to activate a virtual environment?" - ) from exc - execute_from_command_line(sys.argv) - - -if __name__ == '__main__': - main() diff --git a/mypy_django_plugin/lib/tests/sample_django_project/myapp/__init__.py b/mypy_django_plugin/lib/tests/sample_django_project/myapp/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mypy_django_plugin/lib/tests/sample_django_project/myapp/admin.py b/mypy_django_plugin/lib/tests/sample_django_project/myapp/admin.py deleted file mode 100644 index 8c38f3f..0000000 --- a/mypy_django_plugin/lib/tests/sample_django_project/myapp/admin.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.contrib import admin - -# Register your models here. diff --git a/mypy_django_plugin/lib/tests/sample_django_project/myapp/apps.py b/mypy_django_plugin/lib/tests/sample_django_project/myapp/apps.py deleted file mode 100644 index 2370402..0000000 --- a/mypy_django_plugin/lib/tests/sample_django_project/myapp/apps.py +++ /dev/null @@ -1,5 +0,0 @@ -from django.apps import AppConfig - - -class MyappConfig(AppConfig): - label = 'myapp22' diff --git a/mypy_django_plugin/lib/tests/sample_django_project/myapp/migrations/__init__.py b/mypy_django_plugin/lib/tests/sample_django_project/myapp/migrations/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mypy_django_plugin/lib/tests/sample_django_project/myapp/models.py b/mypy_django_plugin/lib/tests/sample_django_project/myapp/models.py deleted file mode 100644 index 4e0f4de..0000000 --- a/mypy_django_plugin/lib/tests/sample_django_project/myapp/models.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.db import models - - -# Create your models here. -class MyModel(models.Model): - pass diff --git a/mypy_django_plugin/lib/tests/sample_django_project/myapp/tests.py b/mypy_django_plugin/lib/tests/sample_django_project/myapp/tests.py deleted file mode 100644 index 7ce503c..0000000 --- a/mypy_django_plugin/lib/tests/sample_django_project/myapp/tests.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.test import TestCase - -# Create your tests here. diff --git a/mypy_django_plugin/lib/tests/sample_django_project/myapp/views.py b/mypy_django_plugin/lib/tests/sample_django_project/myapp/views.py deleted file mode 100644 index 91ea44a..0000000 --- a/mypy_django_plugin/lib/tests/sample_django_project/myapp/views.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.shortcuts import render - -# Create your views here. diff --git a/mypy_django_plugin/lib/tests/sample_django_project/root/__init__.py b/mypy_django_plugin/lib/tests/sample_django_project/root/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mypy_django_plugin/lib/tests/sample_django_project/root/settings.py b/mypy_django_plugin/lib/tests/sample_django_project/root/settings.py deleted file mode 100644 index e04938c..0000000 --- a/mypy_django_plugin/lib/tests/sample_django_project/root/settings.py +++ /dev/null @@ -1,121 +0,0 @@ -""" -Django settings for sample_django_project project. - -Generated by 'django-admin startproject' using Django 2.2.3. - -For more information on this file, see -https://docs.djangoproject.com/en/2.2/topics/settings/ - -For the full list of settings and their values, see -https://docs.djangoproject.com/en/2.2/ref/settings/ -""" - -import os - -# Build paths inside the project like this: os.path.join(BASE_DIR, ...) -BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - - -# Quick-start development settings - unsuitable for production -# See https://docs.djangoproject.com/en/2.2/howto/deployment/checklist/ - -# SECURITY WARNING: keep the secret key used in production secret! -SECRET_KEY = 'e6gj!2x(*odqwmjafrn7#35%)&rnn&^*0x-f&j0prgr--&xf+%' - -# SECURITY WARNING: don't run with debug turned on in production! -DEBUG = True - -ALLOWED_HOSTS = [] - - -# Application definition - -INSTALLED_APPS = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'mypy_django_plugin.lib.tests.sample_django_project.myapp' -] - -MIDDLEWARE = [ - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', -] - -ROOT_URLCONF = 'sample_django_project.urls' - -TEMPLATES = [ - { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', - ], - }, - }, -] - -WSGI_APPLICATION = 'sample_django_project.wsgi.application' - - -# Database -# https://docs.djangoproject.com/en/2.2/ref/settings/#databases - -DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), - } -} - - -# Password validation -# https://docs.djangoproject.com/en/2.2/ref/settings/#auth-password-validators - -AUTH_PASSWORD_VALIDATORS = [ - { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', - }, -] - - -# Internationalization -# https://docs.djangoproject.com/en/2.2/topics/i18n/ - -LANGUAGE_CODE = 'en-us' - -TIME_ZONE = 'UTC' - -USE_I18N = True - -USE_L10N = True - -USE_TZ = True - - -# Static files (CSS, JavaScript, Images) -# https://docs.djangoproject.com/en/2.2/howto/static-files/ - -STATIC_URL = '/static/' diff --git a/mypy_django_plugin/lib/tests/sample_django_project/root/urls.py b/mypy_django_plugin/lib/tests/sample_django_project/root/urls.py deleted file mode 100644 index d5e74b1..0000000 --- a/mypy_django_plugin/lib/tests/sample_django_project/root/urls.py +++ /dev/null @@ -1,21 +0,0 @@ -"""sample_django_project URL Configuration - -The `urlpatterns` list routes URLs to views. For more information please see: - https://docs.djangoproject.com/en/2.2/topics/http/urls/ -Examples: -Function views - 1. Add an import: from my_app import views - 2. Add a URL to urlpatterns: path('', views.home, name='home') -Class-based views - 1. Add an import: from other_app.views import Home - 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') -Including another URLconf - 1. Import the include() function: from django.urls import include, path - 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) -""" -from django.contrib import admin -from django.urls import path - -urlpatterns = [ - path('admin/', admin.site.urls), -] diff --git a/mypy_django_plugin/lib/tests/sample_django_project/root/wsgi.py b/mypy_django_plugin/lib/tests/sample_django_project/root/wsgi.py deleted file mode 100644 index 6ae29e8..0000000 --- a/mypy_django_plugin/lib/tests/sample_django_project/root/wsgi.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -WSGI config for sample_django_project project. - -It exposes the WSGI callable as a module-level variable named ``application``. - -For more information on this file, see -https://docs.djangoproject.com/en/2.2/howto/deployment/wsgi/ -""" - -import os - -from django.core.wsgi import get_wsgi_application - -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'sample_django_project.settings') - -application = get_wsgi_application() diff --git a/mypy_django_plugin/lib/tests/test_get_app_configs.py b/mypy_django_plugin/lib/tests/test_get_app_configs.py deleted file mode 100644 index f2e4c00..0000000 --- a/mypy_django_plugin/lib/tests/test_get_app_configs.py +++ /dev/null @@ -1,14 +0,0 @@ -from mypy.options import Options - -from mypy_django_plugin.lib.config import extract_app_model_aliases -from mypy_django_plugin.main import DjangoPlugin - - -def test_parse_django_settings(): - app_model_mapping = extract_app_model_aliases('mypy_django_plugin.lib.tests.sample_django_project.root.settings') - assert app_model_mapping['myapp.MyModel'] == 'mypy_django_plugin.lib.tests.sample_django_project.myapp.models.MyModel' - - -def test_instantiate_plugin_with_config(): - plugin = DjangoPlugin(Options()) - diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py deleted file mode 100644 index 809b5e4..0000000 --- a/mypy_django_plugin/main.py +++ /dev/null @@ -1,313 +0,0 @@ -import os -from functools import partial -from typing import Callable, Dict, List, Optional, Tuple, cast - -import toml -from mypy.nodes import MypyFile, NameExpr, TypeInfo -from mypy.options import Options -from mypy.plugin import ( - AnalyzeTypeContext, AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin, -) -from mypy.types import AnyType, Instance, Type, TypeOfAny - -from mypy_django_plugin.lib import metadata, fullnames, helpers -from mypy_django_plugin.lib.config import Config, extract_app_model_aliases -from mypy_django_plugin.transformers import fields, init_create -from mypy_django_plugin.transformers.forms import ( - extract_proper_type_for_get_form, extract_proper_type_for_get_form_class, make_meta_nested_class_inherit_from_any, -) -from mypy_django_plugin.transformers.migrations import ( - determine_model_cls_from_string_for_migrations, -) -from mypy_django_plugin.transformers.models import process_model_class -from mypy_django_plugin.transformers.queryset import ( - extract_proper_type_for_queryset_values, extract_proper_type_queryset_values_list, - set_first_generic_param_as_default_for_second, -) -from mypy_django_plugin.transformers.related import ( - determine_type_of_related_manager, extract_and_return_primary_key_of_bound_related_field_parameter, -) -from mypy_django_plugin.transformers.settings import ( - get_type_of_setting, return_user_model_hook, -) - - -def transform_model_class(ctx: ClassDefContext, - ignore_missing_model_attributes: bool, - app_models_mapping: Optional[Dict[str, str]]) -> None: - try: - sym = ctx.api.lookup_fully_qualified(fullnames.MODEL_CLASS_FULLNAME) - except KeyError: - # models.Model is not loaded, skip metadata model write - pass - else: - if sym is not None and isinstance(sym.node, TypeInfo): - metadata.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1 - - process_model_class(ctx, ignore_missing_model_attributes, app_models_mapping) - - -def transform_manager_class(ctx: ClassDefContext) -> None: - sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME) - if sym is not None and isinstance(sym.node, TypeInfo): - metadata.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1 - - -def transform_form_class(ctx: ClassDefContext) -> None: - sym = ctx.api.lookup_fully_qualified_or_none(fullnames.BASEFORM_CLASS_FULLNAME) - if sym is not None and isinstance(sym.node, TypeInfo): - metadata.get_django_metadata(sym.node)['baseform_bases'][ctx.cls.fullname] = 1 - - make_meta_nested_class_inherit_from_any(ctx) - - -def determine_proper_manager_type(ctx: FunctionContext) -> Type: - from mypy.checker import TypeChecker - - api = cast(TypeChecker, ctx.api) - ret = ctx.default_return_type - if not api.tscope.classes: - # not in class - return ret - outer_model_info = api.tscope.classes[0] - if not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME): - return ret - if not isinstance(ret, Instance): - return ret - - has_manager_base = False - for i, base in enumerate(ret.type.bases): - if base.type.fullname() in {fullnames.MANAGER_CLASS_FULLNAME, - fullnames.RELATED_MANAGER_CLASS_FULLNAME, - fullnames.BASE_MANAGER_CLASS_FULLNAME}: - has_manager_base = True - break - - if has_manager_base: - # Fill in the manager's type argument from the outer model - new_type_args = [Instance(outer_model_info, [])] - return helpers.reparametrize_instance(ret, new_type_args) - else: - return ret - - -def return_type_for_id_field(ctx: AttributeContext) -> Type: - if not isinstance(ctx.type, Instance): - return AnyType(TypeOfAny.from_error) - - model_info = ctx.type.type # type: TypeInfo - primary_key_field_name = helpers.get_primary_key_field_name(model_info) - if not primary_key_field_name: - # no field with primary_key=True, just return id as int - return ctx.api.named_generic_type('builtins.int', []) - - if primary_key_field_name != 'id': - # there's field with primary_key=True, but it's name is not 'id', fail - ctx.api.fail("Default primary key 'id' is not defined", ctx.context) - return AnyType(TypeOfAny.from_error) - - primary_key_sym = model_info.get(primary_key_field_name) - if primary_key_sym and isinstance(primary_key_sym.type, Instance): - pass - - # try to parse field type out of primary key field - field_type = helpers.extract_field_getter_type(primary_key_sym.type) - if field_type: - return field_type - - return primary_key_sym.type - - -def transform_form_view(ctx: ClassDefContext) -> None: - form_class_value = helpers.get_assignment_stmt_by_name(ctx.cls.info, 'form_class') - if isinstance(form_class_value, NameExpr): - metadata.get_django_metadata(ctx.cls.info)['form_class'] = form_class_value.fullname - - -class DjangoPlugin(Plugin): - def __init__(self, options: Options) -> None: - super().__init__(options) - - django_plugin_config = None - if os.path.exists('pyproject.toml'): - with open('pyproject.toml', 'r') as f: - pyproject_toml = toml.load(f) - django_plugin_config = pyproject_toml.get('tool', {}).get('django-stubs') - - if django_plugin_config and 'django_settings_module' in django_plugin_config: - self.app_models_mapping = extract_app_model_aliases(django_plugin_config['django_settings_module']) - else: - self.app_models_mapping = None - - config_fpath = os.environ.get('MYPY_DJANGO_CONFIG', 'mypy_django.ini') - if config_fpath and os.path.exists(config_fpath): - self.config = Config.from_config_file(config_fpath) - self.django_settings_module = self.config.django_settings_module - else: - self.config = Config() - self.django_settings_module = None - - if 'DJANGO_SETTINGS_MODULE' in os.environ: - self.django_settings_module = os.environ['DJANGO_SETTINGS_MODULE'] - - def _get_current_model_bases(self) -> Dict[str, int]: - model_sym = self.lookup_fully_qualified(fullnames.MODEL_CLASS_FULLNAME) - if model_sym is not None and isinstance(model_sym.node, TypeInfo): - return (metadata.get_django_metadata(model_sym.node) - .setdefault('model_bases', {fullnames.MODEL_CLASS_FULLNAME: 1})) - else: - return {} - - def _get_current_manager_bases(self) -> Dict[str, int]: - model_sym = self.lookup_fully_qualified(fullnames.MANAGER_CLASS_FULLNAME) - if model_sym is not None and isinstance(model_sym.node, TypeInfo): - return (metadata.get_django_metadata(model_sym.node) - .setdefault('manager_bases', {fullnames.MANAGER_CLASS_FULLNAME: 1})) - else: - return {} - - def _get_current_form_bases(self) -> Dict[str, int]: - model_sym = self.lookup_fully_qualified(fullnames.BASEFORM_CLASS_FULLNAME) - if model_sym is not None and isinstance(model_sym.node, TypeInfo): - return (metadata.get_django_metadata(model_sym.node) - .setdefault('baseform_bases', {fullnames.BASEFORM_CLASS_FULLNAME: 1, - fullnames.FORM_CLASS_FULLNAME: 1, - fullnames.MODELFORM_CLASS_FULLNAME: 1})) - else: - return {} - - def _get_current_queryset_bases(self) -> Dict[str, int]: - model_sym = self.lookup_fully_qualified(fullnames.QUERYSET_CLASS_FULLNAME) - if model_sym is not None and isinstance(model_sym.node, TypeInfo): - return (metadata.get_django_metadata(model_sym.node) - .setdefault('queryset_bases', {fullnames.QUERYSET_CLASS_FULLNAME: 1})) - else: - return {} - - def _get_settings_modules_in_order_of_priority(self) -> List[str]: - settings_modules = [] - if self.django_settings_module: - settings_modules.append(self.django_settings_module) - - settings_modules.append('django.conf.global_settings') - return settings_modules - - def _get_typeinfo_or_none(self, class_name: str) -> Optional[TypeInfo]: - sym = self.lookup_fully_qualified(class_name) - if sym is not None and isinstance(sym.node, TypeInfo): - return sym.node - return None - - def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]: - if file.fullname() == 'django.conf' and self.django_settings_module: - return [(10, self.django_settings_module, -1)] - - if file.fullname() == 'django.db.models.query': - return [(10, 'mypy_extensions', -1)] - - return [] - - def get_function_hook(self, fullname: str - ) -> Optional[Callable[[FunctionContext], Type]]: - # if fullname == 'django.contrib.auth.get_user_model': - # return partial(return_user_model_hook, - # settings_modules=self._get_settings_modules_in_order_of_priority()) - # - manager_bases = self._get_current_manager_bases() - if fullname in manager_bases: - return determine_proper_manager_type - - info = self._get_typeinfo_or_none(fullname) - if info: - if info.has_base(fullnames.FIELD_FULLNAME): - return fields.process_field_instantiation - - # if metadata.get_django_metadata(info).get('generated_init'): - # return init_create.redefine_and_typecheck_model_init - - # def get_method_hook(self, fullname: str - # ) -> Optional[Callable[[MethodContext], Type]]: - # class_name, _, method_name = fullname.rpartition('.') - # - # if method_name == 'get_form_class': - # info = self._get_typeinfo_or_none(class_name) - # if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME): - # return extract_proper_type_for_get_form_class - # - # if method_name == 'get_form': - # info = self._get_typeinfo_or_none(class_name) - # if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME): - # return extract_proper_type_for_get_form - # - # if method_name == 'values': - # model_info = self._get_typeinfo_or_none(class_name) - # if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): - # return extract_proper_type_for_queryset_values - # - # if method_name == 'values_list': - # model_info = self._get_typeinfo_or_none(class_name) - # if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): - # return extract_proper_type_queryset_values_list - # - # if fullname in {'django.apps.registry.Apps.get_model', - # 'django.db.migrations.state.StateApps.get_model'}: - # return determine_model_cls_from_string_for_migrations - # - # manager_classes = self._get_current_manager_bases() - # class_fullname, _, method_name = fullname.rpartition('.') - # if class_fullname in manager_classes and method_name == 'create': - # return init_create.redefine_and_typecheck_model_create - - def get_base_class_hook(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: - if fullname in self._get_current_model_bases(): - return partial(transform_model_class, - ignore_missing_model_attributes=self.config.ignore_missing_model_attributes, - app_models_mapping=self.app_models_mapping) - - if fullname in self._get_current_manager_bases(): - return transform_manager_class - - # if fullname in self._get_current_form_bases(): - # return transform_form_class - - # info = self._get_typeinfo_or_none(fullname) - # if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME): - # return transform_form_view - - return None - - def get_attribute_hook(self, fullname: str - ) -> Optional[Callable[[AttributeContext], Type]]: - class_name, _, attr_name = fullname.rpartition('.') - # if class_name == fullnames.DUMMY_SETTINGS_BASE_CLASS: - # return partial(get_type_of_setting, - # setting_name=attr_name, - # settings_modules=self._get_settings_modules_in_order_of_priority(), - # ignore_missing_settings=self.config.ignore_missing_settings) - - if class_name in self._get_current_model_bases(): - if attr_name == 'id': - return return_type_for_id_field - - model_info = self._get_typeinfo_or_none(class_name) - if model_info: - related_managers = metadata.get_related_managers_metadata(model_info) - if attr_name in related_managers: - return partial(determine_type_of_related_manager, - related_manager_name=attr_name) - - if attr_name.endswith('_id'): - return extract_and_return_primary_key_of_bound_related_field_parameter - - def get_type_analyze_hook(self, fullname: str - ) -> Optional[Callable[[AnalyzeTypeContext], Type]]: - queryset_bases = self._get_current_queryset_bases() - if fullname in queryset_bases: - return partial(set_first_generic_param_as_default_for_second, fullname) - - return None - - -def plugin(version): - return DjangoPlugin diff --git a/mypy_django_plugin/transformers/__init__.py b/mypy_django_plugin/transformers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py deleted file mode 100644 index 886c835..0000000 --- a/mypy_django_plugin/transformers/fields.py +++ /dev/null @@ -1,231 +0,0 @@ -from typing import Optional, cast - -from mypy.checker import TypeChecker -from mypy.nodes import ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo, Expression -from mypy.plugin import FunctionContext -from mypy.types import ( - AnyType, CallableType, Instance, TupleType, Type, UnionType, -) - -from mypy_django_plugin.lib import fullnames, helpers, metadata - - -def extract_referred_to_type(ctx: FunctionContext) -> Optional[Instance]: - api = cast(TypeChecker, ctx.api) - if 'to' not in ctx.callee_arg_names: - api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}', - context=ctx.context) - return None - - arg_type = ctx.arg_types[ctx.callee_arg_names.index('to')][0] - if not isinstance(arg_type, CallableType): - to_arg_expr = ctx.args[ctx.callee_arg_names.index('to')][0] - if not isinstance(to_arg_expr, StrExpr): - # not string, not supported - return None - try: - model_fullname = helpers.get_model_fullname_from_string(to_arg_expr.value, - all_modules=api.modules) - except helpers.SelfReference: - model_fullname = api.tscope.classes[-1].fullname() - - except helpers.SameFileModel as exc: - model_fullname = api.tscope.classes[-1].module_name + '.' + exc.model_cls_name - - if model_fullname is None: - return None - model_info = helpers.lookup_fully_qualified_generic(model_fullname, - all_modules=api.modules) - if model_info is None or not isinstance(model_info, TypeInfo): - return None - return Instance(model_info, []) - - referred_to_type = arg_type.ret_type - if not isinstance(referred_to_type, Instance): - return None - if not referred_to_type.type.has_base(fullnames.MODEL_CLASS_FULLNAME): - ctx.api.msg.fail(f'to= parameter value must be ' - f'a subclass of {fullnames.MODEL_CLASS_FULLNAME!r}', - context=ctx.context) - return None - - return referred_to_type - - -def convert_any_to_type(typ: Type, referred_to_type: Type) -> Type: - if isinstance(typ, UnionType): - converted_items = [] - for item in typ.items: - converted_items.append(convert_any_to_type(item, referred_to_type)) - return UnionType.make_union(converted_items, - line=typ.line, column=typ.column) - if isinstance(typ, Instance): - args = [] - for default_arg in typ.args: - if isinstance(default_arg, AnyType): - args.append(referred_to_type) - else: - args.append(default_arg) - return helpers.reparametrize_instance(typ, args) - - if isinstance(typ, AnyType): - return referred_to_type - - return typ - - -def fill_descriptor_types_for_related_field(ctx: FunctionContext) -> Type: - default_return_type = set_descriptor_types_for_field(ctx) - referred_to_type = extract_referred_to_type(ctx) - if referred_to_type is None: - return default_return_type - - # replace Any with referred_to_type - args = [] - for default_arg in default_return_type.args: - args.append(convert_any_to_type(default_arg, referred_to_type)) - - return helpers.reparametrize_instance(ctx.default_return_type, new_args=args) - - -def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance: - default_return_type = cast(Instance, ctx.default_return_type) - is_nullable = helpers.parse_bool(helpers.get_call_argument_by_name(ctx, 'null')) - set_type = helpers.get_private_descriptor_type(default_return_type.type, '_pyi_private_set_type', - is_nullable=is_nullable) - get_type = helpers.get_private_descriptor_type(default_return_type.type, '_pyi_private_get_type', - is_nullable=is_nullable) - return helpers.reparametrize_instance(default_return_type, [set_type, get_type]) - - -def determine_type_of_array_field(ctx: FunctionContext) -> Type: - default_return_type = set_descriptor_types_for_field(ctx) - - base_field_arg_type = helpers.get_call_argument_type_by_name(ctx, 'base_field') - if not base_field_arg_type or not isinstance(base_field_arg_type, Instance): - return default_return_type - - base_type = base_field_arg_type.args[1] # extract __get__ type - args = [] - for default_arg in default_return_type.args: - args.append(convert_any_to_type(default_arg, base_type)) - - return helpers.reparametrize_instance(default_return_type, args) - - -def transform_into_proper_return_type(ctx: FunctionContext) -> Type: - default_return_type = ctx.default_return_type - if not isinstance(default_return_type, Instance): - return default_return_type - - if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES): - return fill_descriptor_types_for_related_field(ctx) - - if default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME): - return determine_type_of_array_field(ctx) - - return set_descriptor_types_for_field(ctx) - - -def process_field_instantiation(ctx: FunctionContext) -> Type: - # Parse __init__ parameters of field into corresponding Model's metadata - parse_field_init_arguments_into_model_metadata(ctx) - return transform_into_proper_return_type(ctx) - - -def _parse_choices_type(ctx: FunctionContext, choices_arg: Expression) -> Optional[str]: - if isinstance(choices_arg, (TupleExpr, ListExpr)): - # iterable of 2 element tuples of two kinds - _, analyzed_choices = ctx.api.analyze_iterable_item_type(choices_arg) - if isinstance(analyzed_choices, TupleType): - first_element_type = analyzed_choices.items[0] - if isinstance(first_element_type, Instance): - return first_element_type.type.fullname() - - -def _parse_referenced_model(ctx: FunctionContext, to_arg: Expression) -> Optional[TypeInfo]: - if isinstance(to_arg, NameExpr) and isinstance(to_arg.node, TypeInfo): - # reference to the model class - return to_arg.node - - elif isinstance(to_arg, StrExpr): - referenced_model_info = helpers.get_model_info(to_arg.value, ctx.api.modules) - if referenced_model_info is not None: - return referenced_model_info - - -def parse_field_init_arguments_into_model_metadata(ctx: FunctionContext) -> None: - outer_model = ctx.api.scope.active_class() - if outer_model is None or not outer_model.has_base(fullnames.MODEL_CLASS_FULLNAME): - # outside models.Model class, undetermined - return - - # Determine name of the current field - for attr_name, stmt in helpers.iter_over_class_level_assignments(outer_model.defn): - if stmt == ctx.context: - field_name = attr_name - break - else: - return - - model_fields_metadata = metadata.get_fields_metadata(outer_model) - - # primary key - is_primary_key = False - primary_key_arg = helpers.get_call_argument_by_name(ctx, 'primary_key') - if primary_key_arg: - is_primary_key = helpers.parse_bool(primary_key_arg) - model_fields_metadata[field_name] = {'primary_key': is_primary_key} - - # choices - choices_arg = helpers.get_call_argument_by_name(ctx, 'choices') - if choices_arg: - choices_type_fullname = _parse_choices_type(ctx.api, choices_arg) - if choices_type_fullname: - model_fields_metadata[field_name]['choices_type'] = choices_type_fullname - - # nullability - null_arg = helpers.get_call_argument_by_name(ctx, 'null') - is_nullable = False - if null_arg: - is_nullable = helpers.parse_bool(null_arg) - model_fields_metadata[field_name]['null'] = is_nullable - - # is_blankable - blank_arg = helpers.get_call_argument_by_name(ctx, 'blank') - is_blankable = False - if blank_arg: - is_blankable = helpers.parse_bool(blank_arg) - model_fields_metadata[field_name]['blank'] = is_blankable - - # default - default_arg = helpers.get_call_argument_by_name(ctx, 'default') - if default_arg and not helpers.is_none_expr(default_arg): - model_fields_metadata[field_name]['default_specified'] = True - - if helpers.has_any_of_bases(ctx.default_return_type.type, fullnames.RELATED_FIELDS_CLASSES): - # to - to_arg = helpers.get_call_argument_by_name(ctx, 'to') - if to_arg: - referenced_model = _parse_referenced_model(ctx, to_arg) - if referenced_model is not None: - model_fields_metadata[field_name]['to'] = referenced_model.fullname() - else: - model_fields_metadata[field_name]['to'] = to_arg.value - # referenced_model = to_arg.value - # raise helpers.IncompleteDefnException() - - # model_fields_metadata[field_name]['to'] = referenced_model.fullname() - # if referenced_model is not None: - # model_fields_metadata[field_name]['to'] = referenced_model.fullname() - # else: - # assert isinstance(to_arg, StrExpr) - # model_fields_metadata[field_name]['to'] = to_arg.value - - # related_name - related_name_arg = helpers.get_call_argument_by_name(ctx, 'related_name') - if related_name_arg: - if isinstance(related_name_arg, StrExpr): - model_fields_metadata[field_name]['related_name'] = related_name_arg.value - else: - model_fields_metadata[field_name]['related_name'] = outer_model.name().lower() + '_set' diff --git a/mypy_django_plugin/transformers/forms.py b/mypy_django_plugin/transformers/forms.py deleted file mode 100644 index ccd0dbd..0000000 --- a/mypy_django_plugin/transformers/forms.py +++ /dev/null @@ -1,46 +0,0 @@ -from mypy.plugin import ClassDefContext, MethodContext -from mypy.types import CallableType, Instance, NoneTyp, Type, TypeType - -from mypy_django_plugin.lib import metadata, helpers - - -def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None: - meta_node = helpers.get_nested_meta_node_for_current_class(ctx.cls.info) - if meta_node is None: - return None - meta_node.fallback_to_any = True - - -def extract_proper_type_for_get_form(ctx: MethodContext) -> Type: - object_type = ctx.type - if not isinstance(object_type, Instance): - return ctx.default_return_type - - form_class_type = helpers.get_call_argument_type_by_name(ctx, 'form_class') - if form_class_type is None or isinstance(form_class_type, NoneTyp): - # extract from specified form_class in metadata - form_class_fullname = metadata.get_django_metadata(object_type.type).get('form_class', None) - if not form_class_fullname: - return ctx.default_return_type - - return ctx.api.named_generic_type(form_class_fullname, []) - - if isinstance(form_class_type, TypeType) and isinstance(form_class_type.item, Instance): - return form_class_type.item - - if isinstance(form_class_type, CallableType) and isinstance(form_class_type.ret_type, Instance): - return form_class_type.ret_type - - return ctx.default_return_type - - -def extract_proper_type_for_get_form_class(ctx: MethodContext) -> Type: - object_type = ctx.type - if not isinstance(object_type, Instance): - return ctx.default_return_type - - form_class_fullname = metadata.get_django_metadata(object_type.type).get('form_class', None) - if not form_class_fullname: - return ctx.default_return_type - - return TypeType(ctx.api.named_generic_type(form_class_fullname, [])) diff --git a/mypy_django_plugin/transformers/init_create.py b/mypy_django_plugin/transformers/init_create.py deleted file mode 100644 index 54f4458..0000000 --- a/mypy_django_plugin/transformers/init_create.py +++ /dev/null @@ -1,194 +0,0 @@ -from typing import Dict, Optional, Set, cast - -from mypy.checker import TypeChecker -from mypy.nodes import TypeInfo, Var -from mypy.plugin import FunctionContext, MethodContext -from mypy.types import AnyType, Instance, Type, TypeOfAny - -from mypy_django_plugin.lib import metadata, fullnames, helpers - - -def extract_base_pointer_args(model: TypeInfo) -> Set[str]: - pointer_args: Set[str] = set() - for base in model.bases: - if base.type.has_base(fullnames.MODEL_CLASS_FULLNAME): - parent_name = base.type.name().lower() - pointer_args.add(f'{parent_name}_ptr') - pointer_args.add(f'{parent_name}_ptr_id') - return pointer_args - - -def redefine_and_typecheck_model_init(ctx: FunctionContext) -> Type: - assert isinstance(ctx.default_return_type, Instance) - - api = cast(TypeChecker, ctx.api) - model: TypeInfo = ctx.default_return_type.type - - expected_types = extract_expected_types(ctx, model, is_init=True) - - # order is preserved, can be used for positionals - positional_names = list(expected_types.keys()) - positional_names.remove('pk') - - visited_positionals = set() - # check positionals - for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])): - actual_pos_name = positional_names[i] - api.check_subtype(actual_pos_type, expected_types[actual_pos_name], - ctx.context, - 'Incompatible type for "{}" of "{}"'.format(actual_pos_name, - model.name()), - 'got', 'expected') - visited_positionals.add(actual_pos_name) - - # extract name of base models for _ptr - base_pointer_args = extract_base_pointer_args(model) - - # check kwargs - for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])): - if actual_name in base_pointer_args: - # parent_ptr args are not supported - continue - if actual_name in visited_positionals: - continue - if actual_name is None: - # unpacked dict as kwargs is not supported - continue - if actual_name not in expected_types: - ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name, - model.name()), - ctx.context) - continue - api.check_subtype(actual_type, expected_types[actual_name], - ctx.context, - 'Incompatible type for "{}" of "{}"'.format(actual_name, - model.name()), - 'got', 'expected') - return ctx.default_return_type - - -def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type: - api = cast(TypeChecker, ctx.api) - if isinstance(ctx.type, Instance) and len(ctx.type.args) > 0: - model_generic_arg = ctx.type.args[0] - else: - model_generic_arg = ctx.default_return_type - - if isinstance(model_generic_arg, AnyType): - return ctx.default_return_type - - model: TypeInfo = model_generic_arg.type - - # extract name of base models for _ptr - base_pointer_args = extract_base_pointer_args(model) - expected_types = extract_expected_types(ctx, model) - - for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]): - if actual_name in base_pointer_args: - # parent_ptr args are not supported - continue - if actual_name is None: - # unpacked dict as kwargs is not supported - continue - if actual_name not in expected_types: - api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name, - model.name()), - ctx.context) - continue - api.check_subtype(actual_type, expected_types[actual_name], - ctx.context, - 'Incompatible type for "{}" of "{}"'.format(actual_name, - model.name()), - 'got', 'expected') - - return ctx.default_return_type - - -def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]: - field_metadata = metadata.get_fields_metadata(model).get(field_name, {}) - if 'choices' in field_metadata: - return field_metadata['choices'] - return None - - -def extract_expected_types(ctx: FunctionContext, model: TypeInfo, - is_init: bool = False) -> Dict[str, Type]: - api = cast(TypeChecker, ctx.api) - - expected_types: Dict[str, Type] = {} - primary_key_type = helpers.extract_explicit_set_type_of_model_primary_key(model) - if not primary_key_type: - # no explicit primary key, set pk to Any and add id - primary_key_type = AnyType(TypeOfAny.special_form) - if is_init: - expected_types['id'] = helpers.make_optional(ctx.api.named_generic_type('builtins.int', [])) - else: - expected_types['id'] = ctx.api.named_generic_type('builtins.int', []) - - expected_types['pk'] = primary_key_type - - for base in model.mro: - # extract all fields for all models in MRO - for name, sym in base.names.items(): - # do not redefine special attrs - if name in {'_meta', 'pk'}: - continue - - if isinstance(sym.node, Var): - typ = sym.node.type - if typ is None or isinstance(typ, AnyType): - # types are not ready, fallback to Any - expected_types[name] = AnyType(TypeOfAny.from_unimported_type) - expected_types[name + '_id'] = AnyType(TypeOfAny.from_unimported_type) - - elif isinstance(typ, Instance): - field_type = helpers.extract_field_setter_type(typ) - if field_type is None: - continue - - if helpers.has_any_of_bases(typ.type, (fullnames.FOREIGN_KEY_FULLNAME, - fullnames.ONETOONE_FIELD_FULLNAME)): - related_primary_key_type = AnyType(TypeOfAny.implementation_artifact) - # in case it's optional, we need Instance type - referred_to_model = typ.args[1] - is_nullable = helpers.is_optional(referred_to_model) - if is_nullable: - referred_to_model = helpers.make_required(typ.args[1]) - - if isinstance(referred_to_model, Instance) and referred_to_model.type.has_base( - fullnames.MODEL_CLASS_FULLNAME): - pk_type = helpers.extract_explicit_set_type_of_model_primary_key(referred_to_model.type) - if not pk_type: - # extract set type of AutoField - autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField') - pk_type = helpers.get_private_descriptor_type(autofield_info, '_pyi_private_set_type', - is_nullable=is_nullable) - related_primary_key_type = pk_type - - if is_init: - related_primary_key_type = helpers.make_optional(related_primary_key_type) - - expected_types[name + '_id'] = related_primary_key_type - - field_metadata = metadata.get_fields_metadata(model).get(name, {}) - if field_type: - # related fields could be None in __init__ (but should be specified before save()) - if helpers.has_any_of_bases(typ.type, (fullnames.FOREIGN_KEY_FULLNAME, - fullnames.ONETOONE_FIELD_FULLNAME)) and is_init: - field_type = helpers.make_optional(field_type) - - # if primary_key=True and default specified - elif field_metadata.get('primary_key', False) and field_metadata.get('default_specified', - False): - field_type = helpers.make_optional(field_type) - - # if CharField(blank=True,...) and not nullable, then field can be None in __init__ - elif ( - helpers.has_any_of_bases(typ.type, (fullnames.CHAR_FIELD_FULLNAME,)) and is_init and - field_metadata.get('blank', False) and not field_metadata.get('null', False) - ): - field_type = helpers.make_optional(field_type) - - expected_types[name] = field_type - - return expected_types diff --git a/mypy_django_plugin/transformers/migrations.py b/mypy_django_plugin/transformers/migrations.py deleted file mode 100644 index 462178a..0000000 --- a/mypy_django_plugin/transformers/migrations.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import Optional, cast - -from mypy.checker import TypeChecker -from mypy.nodes import Expression, StrExpr, TypeInfo -from mypy.plugin import MethodContext -from mypy.types import Instance, Type, TypeType - -from mypy_django_plugin.lib import helpers - - -def get_string_value_from_expr(expr: Expression) -> Optional[str]: - if isinstance(expr, StrExpr): - return expr.value - # TODO: somehow figure out other cases - return None - - -def determine_model_cls_from_string_for_migrations(ctx: MethodContext) -> Type: - app_label_expr = ctx.args[ctx.callee_arg_names.index('app_label')][0] - app_label = get_string_value_from_expr(app_label_expr) - if app_label is None: - return ctx.default_return_type - - if 'model_name' not in ctx.callee_arg_names: - return ctx.default_return_type - - model_name_expr_tuple = ctx.args[ctx.callee_arg_names.index('model_name')] - if not model_name_expr_tuple: - return ctx.default_return_type - - model_name = get_string_value_from_expr(model_name_expr_tuple[0]) - if model_name is None: - return ctx.default_return_type - - api = cast(TypeChecker, ctx.api) - model_fullname = helpers.get_model_fullname(app_label, model_name, all_modules=api.modules) - - if model_fullname is None: - return ctx.default_return_type - model_info = helpers.lookup_fully_qualified_generic(model_fullname, - all_modules=api.modules) - if model_info is None or not isinstance(model_info, TypeInfo): - return ctx.default_return_type - return TypeType(Instance(model_info, [])) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py deleted file mode 100644 index 4654f7a..0000000 --- a/mypy_django_plugin/transformers/models.py +++ /dev/null @@ -1,388 +0,0 @@ -from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Iterator, List, Optional, Tuple, cast - -import dataclasses -from mypy.nodes import ( - ARG_POS, ARG_STAR, ARG_STAR2, MDEF, Argument, CallExpr, ClassDef, Expression, IndexExpr, MemberExpr, MypyFile, - NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var, -) -from mypy.plugin import ClassDefContext -from mypy.plugins.common import add_method -from mypy.semanal import SemanticAnalyzerPass2 -from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny - -from mypy_django_plugin.lib import metadata, fullnames, helpers - - -@dataclasses.dataclass -class ModelClassInitializer(metaclass=ABCMeta): - api: SemanticAnalyzerPass2 - model_classdef: ClassDef - app_models_mapping: Optional[Dict[str, str]] = None - - @classmethod - def from_ctx(cls, ctx: ClassDefContext, app_models_mapping: Optional[Dict[str, str]]): - return cls(api=cast(SemanticAnalyzerPass2, ctx.api), - model_classdef=ctx.cls, - app_models_mapping=app_models_mapping) - - def get_meta_attribute(self, name: str) -> Optional[Expression]: - meta_node = helpers.get_nested_meta_node_for_current_class(self.model_classdef.info) - if meta_node is None: - return None - - return helpers.get_assignment_stmt_by_name(meta_node, name) - - def is_abstract_model(self) -> bool: - is_abstract_expr = self.get_meta_attribute('abstract') - if is_abstract_expr is None: - return False - return self.api.parse_bool(is_abstract_expr) - - def add_new_node_to_model_class(self, name: str, typ: Instance) -> None: - # type=: type of the variable itself - var = Var(name=name, type=typ) - # var.info: type of the object variable is bound to - var.info = self.model_classdef.info - var._fullname = self.model_classdef.info.fullname() + '.' + name - var.is_inferred = True - var.is_initialized_in_class = True - self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True) - - def model_has_name_defined(self, name: str) -> bool: - return name in self.model_classdef.info.names - - @abstractmethod - def run(self) -> None: - raise NotImplementedError() - - -def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExpr, CallExpr]]: - for field_name, field_init in helpers.iter_over_field_inits_in_class(klass): - field_info = field_init.callee.node - assert isinstance(field_info, TypeInfo) - - if helpers.has_any_of_bases(field_init.callee.node, {fullnames.FOREIGN_KEY_FULLNAME, - fullnames.ONETOONE_FIELD_FULLNAME}): - yield field_name, field_init - - -class AddReferencesToRelatedModels(ModelClassInitializer): - """ - For every - attr1 = models.ForeignKey(to=MyModel) - sets `attr1_id` attribute to the current model. - """ - - def run(self) -> None: - for field_name, field_init_expr in helpers.iter_over_field_inits_in_class(self.model_classdef): - ref_id_name = field_name + '_id' - field_info = field_init_expr.callee.node - assert isinstance(field_info, TypeInfo) - - if not self.model_has_name_defined(ref_id_name): - if helpers.has_any_of_bases(field_info, {fullnames.FOREIGN_KEY_FULLNAME, - fullnames.ONETOONE_FIELD_FULLNAME}): - self.add_new_node_to_model_class(name=ref_id_name, - typ=self.api.builtin_type('builtins.int')) - - # field_init_expr.callee.node - # - # for field_name, field_init_expr in helpers.iter_call_assignments_in_class(self.model_classdef): - # ref_id_name = field_name + '_id' - # if not self.model_has_name_defined(ref_id_name): - # field_class_info = field_init_expr.callee.node - # if not field_class_info: - # - # if not field_init_expr.callee.node: - # - # if isinstance(field_init_expr.callee.node, TypeInfo) \ - # and helpers.has_any_of_bases(field_init_expr.callee.node, - # {fullnames.FOREIGN_KEY_FULLNAME, - # fullnames.ONETOONE_FIELD_FULLNAME}): - # self.add_new_node_to_model_class(name=ref_id_name, - # typ=self.api.builtin_type('builtins.int')) - - -class InjectAnyAsBaseForNestedMeta(ModelClassInitializer): - """ - Replaces - class MyModel(models.Model): - class Meta: - pass - with - class MyModel(models.Model): - class Meta(Any): - pass - to get around incompatible Meta inner classes for different models. - """ - - def run(self) -> None: - meta_node = helpers.get_nested_meta_node_for_current_class(self.model_classdef.info) - if meta_node is None: - return None - meta_node.fallback_to_any = True - - -class AddDefaultObjectsManager(ModelClassInitializer): - def _add_new_manager(self, name: str, manager_type: Optional[Instance]) -> None: - if manager_type is None: - return None - self.add_new_node_to_model_class(name, manager_type) - - def _add_private_default_manager(self, manager_type: Optional[Instance]) -> None: - if manager_type is None: - return None - self.add_new_node_to_model_class('_default_manager', manager_type) - - def _get_existing_managers(self) -> List[Tuple[str, TypeInfo]]: - managers = [] - for base in self.model_classdef.info.mro: - for manager_name, call_expr in helpers.iter_call_assignments_in_class(base.defn): - callee_expr = call_expr.callee - if isinstance(callee_expr, IndexExpr): - callee_expr = callee_expr.analyzed.expr - - if isinstance(callee_expr, (MemberExpr, NameExpr)) \ - and isinstance(callee_expr.node, TypeInfo) \ - and callee_expr.node.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME): - managers.append((manager_name, callee_expr.node)) - return managers - - def run(self) -> None: - existing_managers = self._get_existing_managers() - if existing_managers: - first_manager_type = None - for manager_name, manager_type_info in existing_managers: - manager_type = Instance(manager_type_info, args=[Instance(self.model_classdef.info, [])]) - self._add_new_manager(name=manager_name, manager_type=manager_type) - if first_manager_type is None: - first_manager_type = manager_type - else: - if self.is_abstract_model(): - # abstract models do not need 'objects' queryset - return None - - first_manager_type = self.api.named_type_or_none(fullnames.MANAGER_CLASS_FULLNAME, - args=[Instance(self.model_classdef.info, [])]) - self._add_new_manager('objects', manager_type=first_manager_type) - - if self.is_abstract_model(): - return None - default_manager_name_expr = self.get_meta_attribute('default_manager_name') - if isinstance(default_manager_name_expr, StrExpr): - self._add_private_default_manager(self.model_classdef.info.get(default_manager_name_expr.value).type) - else: - self._add_private_default_manager(first_manager_type) - - -class AddDefaultPrimaryKey(ModelClassInitializer): - """ - Sets default integer `id` attribute, if: - * model is not abstract (abstract = False) - * there's no field with primary_key=True - """ - - def run(self) -> None: - if self.is_abstract_model(): - # abstract models cannot be instantiated, and do not need `id` attribute - return None - - for _, field_init_expr in helpers.iter_over_field_inits_in_class(self.model_classdef): - if ('primary_key' in field_init_expr.arg_names - and self.api.parse_bool(field_init_expr.args[field_init_expr.arg_names.index('primary_key')])): - break - else: - self.add_new_node_to_model_class('id', self.api.builtin_type('builtins.int')) - - -def _get_to_expr(field_init_expr) -> Expression: - if 'to' in field_init_expr.arg_names: - return field_init_expr.args[field_init_expr.arg_names.index('to')] - else: - return field_init_expr.args[0] - - -class AddRelatedManagers(ModelClassInitializer): - def _add_related_manager_variable(self, manager_name: str, related_field_type_data: Dict[str, Any]) -> None: - # add dummy related manager for use later - self.add_new_node_to_model_class(manager_name, self.api.builtin_type('builtins.object')) - - # save name in metadata for use in get_attribute_hook later - related_managers_metadata = metadata.get_related_managers_metadata(self.model_classdef.info) - related_managers_metadata[manager_name] = related_field_type_data - - def run(self) -> None: - for module_name, module_file in self.api.modules.items(): - for model_classdef in helpers.iter_over_toplevel_classes(module_file): - for field_name, field_init in helpers.iter_over_field_inits_in_class(model_classdef): - field_info = field_init.callee.node - assert isinstance(field_info, TypeInfo) - - if helpers.has_any_of_bases(field_info, fullnames.RELATED_FIELDS_CLASSES): - # try: - to_arg_expr = _get_to_expr(field_init) - if isinstance(to_arg_expr, NameExpr): - referenced_model_fullname = module_file.names[to_arg_expr.name].fullname - else: - assert isinstance(to_arg_expr, StrExpr) - value = to_arg_expr.value - if value == 'self': - # reference to the same model class - referenced_model_fullname = model_classdef.fullname - elif '.' not in value: - # reference to class in the current module - referenced_model_fullname = module_name + '.' + value - else: - referenced_model_fullname = self.app_models_mapping[value] - - # referenced_model_fullname = extract_referenced_model_fullname(field_init, - # module_file=module_file, - # all_modules=self.api.modules) - # if not referenced_model_fullname: - # raise helpers.IncompleteDefnException('Cannot parse referenced model fullname') - - # except helpers.SelfReference: - # referenced_model_fullname = model_classdef.fullname - # - # except helpers.SameFileModel as exc: - # referenced_model_fullname = module_name + '.' + exc.model_cls_name - - if self.model_classdef.fullname == referenced_model_fullname: - if 'related_name' in field_init.arg_names: - related_name_expr = field_init.args[field_init.arg_names.index('related_name')] - if not isinstance(related_name_expr, StrExpr): - # not string 'related_name=' not yet supported - continue - related_name = related_name_expr.value - if related_name == '+': - # No backwards relation is desired - continue - else: - related_name = model_classdef.name.lower() + '_set' - - # Default related_query_name to related_name - if 'related_query_name' in field_init.arg_names: - related_query_name_expr = field_init.args[field_init.arg_names.index('related_query_name')] - if isinstance(related_query_name_expr, StrExpr): - related_query_name = related_query_name_expr.value - else: - # not string 'related_query_name=' is not yet supported - related_query_name = None - # TODO: Handle defaulting to model name if related_name is not set - else: - related_query_name = related_name - - # if helpers.has_any_of_bases(field_info, {fullnames.FOREIGN_KEY_FULLNAME, - # fullnames.MANYTOMANY_FIELD_FULLNAME}): - # # as long as Model is not a Generic, one level depth is fine - # field_type_data = { - # 'manager': fullnames.RELATED_MANAGER_CLASS_FULLNAME, - # 'of': [model_classdef.info.fullname()] - # } - # else: - # field_type_data = { - # 'manager': model_classdef.info.fullname(), - # 'of': [] - # } - self.add_new_node_to_model_class(related_name, self.api.builtin_type('builtins.object')) - - # self._add_related_manager_variable(related_name, related_field_type_data=field_type_data) - - if related_query_name is not None: - # Only create related_query_name if it is a string literal - metadata.get_lookups_metadata(self.model_classdef.info)[related_query_name] = { - 'related_query_name_target': related_name - } - - -def get_related_field_type(rvalue: CallExpr, related_model_typ: TypeInfo) -> Dict[str, Any]: - if rvalue.callee.name in {'ForeignKey', 'ManyToManyField'}: - return { - 'manager': fullnames.RELATED_MANAGER_CLASS_FULLNAME, - 'of': [related_model_typ.fullname()] - } - else: - return { - 'manager': related_model_typ.fullname(), - 'of': [] - } - - -def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool: - """ Checks whether current CallExpr represents any supported RelatedField subclass""" - if isinstance(expr.callee, MemberExpr) and isinstance(expr.callee.expr, NameExpr): - module = module_file.names.get(expr.callee.expr.name) - if module \ - and module.fullname == 'django.db.models' \ - and expr.callee.name in {'ForeignKey', - 'OneToOneField', - 'ManyToManyField'}: - return True - return False - - -def extract_referenced_model_fullname(field_init_expr: CallExpr, - module_file: MypyFile, - all_modules: Dict[str, MypyFile]) -> Optional[str]: - """ Returns fullname of a Model referenced in "to=" argument of the CallExpr""" - if 'to' in field_init_expr.arg_names: - to_expr = field_init_expr.args[field_init_expr.arg_names.index('to')] - else: - to_expr = field_init_expr.args[0] - - if isinstance(to_expr, NameExpr): - return module_file.names[to_expr.name].fullname - - elif isinstance(to_expr, StrExpr): - typ_fullname = helpers.get_model_fullname_from_string(to_expr.value, all_modules) - if typ_fullname is None: - return None - return typ_fullname - - return None - - -def add_dummy_init_method(ctx: ClassDefContext) -> None: - any = AnyType(TypeOfAny.special_form) - - pos_arg = Argument(variable=Var('args', any), - type_annotation=any, initializer=None, kind=ARG_STAR) - kw_arg = Argument(variable=Var('kwargs', any), - type_annotation=any, initializer=None, kind=ARG_STAR2) - - add_method(ctx, '__init__', [pos_arg, kw_arg], NoneTyp()) - - # mark as model class - ctx.cls.info.metadata.setdefault('django', {})['generated_init'] = True - - -def add_get_set_attr_fallback_to_any(ctx: ClassDefContext): - any = AnyType(TypeOfAny.special_form) - - name_arg = Argument(variable=Var('name', any), - type_annotation=any, initializer=None, kind=ARG_POS) - add_method(ctx, '__getattr__', [name_arg], any) - - value_arg = Argument(variable=Var('value', any), - type_annotation=any, initializer=None, kind=ARG_POS) - add_method(ctx, '__setattr__', [name_arg, value_arg], any) - - -def process_model_class(ctx: ClassDefContext, - ignore_unknown_attributes: bool, - app_models_mapping: Optional[Dict[str, str]]) -> None: - initializers = [ - InjectAnyAsBaseForNestedMeta, - AddDefaultPrimaryKey, - AddReferencesToRelatedModels, - AddDefaultObjectsManager, - AddRelatedManagers, - ] - for initializer_cls in initializers: - initializer_cls.from_ctx(ctx, app_models_mapping).run() - - add_dummy_init_method(ctx) - - if ignore_unknown_attributes: - add_get_set_attr_fallback_to_any(ctx) diff --git a/mypy_django_plugin/transformers/queryset.py b/mypy_django_plugin/transformers/queryset.py deleted file mode 100644 index b452a95..0000000 --- a/mypy_django_plugin/transformers/queryset.py +++ /dev/null @@ -1,209 +0,0 @@ -from collections import OrderedDict -from typing import List, Optional, cast - -from mypy.checker import TypeChecker -from mypy.nodes import StrExpr, TypeInfo -from mypy.plugin import ( - AnalyzeTypeContext, CheckerPluginInterface, MethodContext, -) -from mypy.types import AnyType, Instance, Type, TypeOfAny - -from mypy_django_plugin.lib import helpers -from mypy_django_plugin.lib.lookups import ( - LookupException, RelatedModelNode, resolve_lookup, -) - - -def get_queryset_model_arg(ret_type: Instance) -> Type: - if ret_type.args: - return ret_type.args[0] - else: - return AnyType(TypeOfAny.implementation_artifact) - - -def extract_proper_type_for_queryset_values(ctx: MethodContext) -> Type: - object_type = ctx.type - if not isinstance(object_type, Instance): - return ctx.default_return_type - - fields_arg_expr = ctx.args[ctx.callee_arg_names.index('fields')] - if len(fields_arg_expr) == 0: - # values_list/values with no args is not yet supported, so default to Any types for field types - # It should in the future include all model fields, "extra" fields and "annotated" fields - return ctx.default_return_type - - model_arg = get_queryset_model_arg(ctx.default_return_type) - if isinstance(model_arg, Instance): - model_type_info = model_arg.type - else: - model_type_info = None - - column_types: OrderedDict[str, Type] = OrderedDict() - - # parse *fields - for field_expr in fields_arg_expr: - if isinstance(field_expr, StrExpr): - field_name = field_expr.value - # Default to any type - column_types[field_name] = AnyType(TypeOfAny.implementation_artifact) - - if model_type_info: - resolved_lookup_type = resolve_values_lookup(ctx.api, model_type_info, field_name) - if resolved_lookup_type is not None: - column_types[field_name] = resolved_lookup_type - else: - return ctx.default_return_type - - # parse **expressions - expression_arg_names = ctx.arg_names[ctx.callee_arg_names.index('expressions')] - for expression_name in expression_arg_names: - # Arbitrary additional annotation expressions are supported, but they all have type Any for now - column_types[expression_name] = AnyType(TypeOfAny.implementation_artifact) - - row_arg = helpers.make_typeddict(ctx.api, fields=column_types, - required_keys=set()) - return helpers.reparametrize_instance(ctx.default_return_type, [model_arg, row_arg]) - - -def extract_proper_type_queryset_values_list(ctx: MethodContext) -> Type: - object_type = ctx.type - if not isinstance(object_type, Instance): - return ctx.default_return_type - - ret = ctx.default_return_type - - model_arg = get_queryset_model_arg(ctx.default_return_type) - # model_arg: Union[AnyType, Type] = ret.args[0] if len(ret.args) > 0 else any_type - - column_names: List[Optional[str]] = [] - column_types: OrderedDict[str, Type] = OrderedDict() - - fields_arg_expr = ctx.args[ctx.callee_arg_names.index('fields')] - fields_param_is_specified = True - if len(fields_arg_expr) == 0: - # values_list/values with no args is not yet supported, so default to Any types for field types - # It should in the future include all model fields, "extra" fields and "annotated" fields - fields_param_is_specified = False - - if isinstance(model_arg, Instance): - model_type_info = model_arg.type - else: - model_type_info = None - - any_type = AnyType(TypeOfAny.implementation_artifact) - - # Figure out each field name passed to fields - only_strings_as_fields_expressions = True - for field_expr in fields_arg_expr: - if isinstance(field_expr, StrExpr): - field_name = field_expr.value - column_names.append(field_name) - # Default to any type - column_types[field_name] = any_type - - if model_type_info: - resolved_lookup_type = resolve_values_lookup(ctx.api, model_type_info, field_name) - if resolved_lookup_type is not None: - column_types[field_name] = resolved_lookup_type - else: - # Dynamic field names are partially supported for values_list, but not values - column_names.append(None) - only_strings_as_fields_expressions = False - - flat = helpers.parse_bool(helpers.get_call_argument_by_name(ctx, 'flat')) - named = helpers.parse_bool(helpers.get_call_argument_by_name(ctx, 'named')) - - api = cast(TypeChecker, ctx.api) - if named and flat: - api.fail("'flat' and 'named' can't be used together.", ctx.context) - return ret - - elif named: - # named=True, flat=False -> List[NamedTuple] - if fields_param_is_specified and only_strings_as_fields_expressions: - row_arg = helpers.make_named_tuple(api, fields=column_types, name="Row") - else: - # fallback to catch-all NamedTuple - row_arg = helpers.make_named_tuple(api, fields=OrderedDict(), name="Row") - - elif flat: - # named=False, flat=True -> List of elements - if len(ctx.args[0]) > 1: - api.fail("'flat' is not valid when values_list is called with more than one field.", - ctx.context) - return ctx.default_return_type - - if fields_param_is_specified and only_strings_as_fields_expressions: - # Grab first element - row_arg = column_types[column_names[0]] - else: - row_arg = any_type - - else: - # named=False, flat=False -> List[Tuple] - if fields_param_is_specified: - args = [ - # Fallback to Any if the column name is unknown (e.g. dynamic) - column_types.get(column_name, any_type) if column_name is not None else any_type - for column_name in column_names - ] - else: - args = [any_type] - row_arg = helpers.make_tuple(api, fields=args) - - new_type_args = [model_arg, row_arg] - return helpers.reparametrize_instance(ret, new_type_args) - - -def resolve_values_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, lookup: str) -> Optional[Type]: - """Resolves a values/values_list lookup if possible, to a Type.""" - try: - nodes = resolve_lookup(api, model_type_info, lookup) - except LookupException: - nodes = [] - - if not nodes: - return None - - make_optional = False - - for node in nodes: - if isinstance(node, RelatedModelNode) and node.is_nullable: - # All lookups following a relation which is nullable should be optional - make_optional = True - - node = nodes[-1] - - node_type = node.typ - if isinstance(node, RelatedModelNode): - # Related models used in values/values_list get resolved to the primary key of the related model. - # So, we lookup the pk of that model. - pk_lookup_nodes = resolve_lookup(api, node_type.type, "pk") - if not pk_lookup_nodes: - return None - node_type = pk_lookup_nodes[0].typ - if make_optional: - return helpers.make_optional(node_type) - else: - return node_type - - -def set_first_generic_param_as_default_for_second(fullname: str, ctx: AnalyzeTypeContext) -> Type: - if not ctx.type.args: - try: - return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit), - AnyType(TypeOfAny.explicit)]) - except KeyError: - # really should never happen - return AnyType(TypeOfAny.explicit) - - args = ctx.type.args - if len(args) == 1: - args = [args[0], args[0]] - - analyzed_args = [ctx.api.analyze_type(arg) for arg in args] - try: - return ctx.api.named_type(fullname, analyzed_args) - except KeyError: - # really should never happen - return AnyType(TypeOfAny.explicit) diff --git a/mypy_django_plugin/transformers/related.py b/mypy_django_plugin/transformers/related.py deleted file mode 100644 index cb0c962..0000000 --- a/mypy_django_plugin/transformers/related.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import Optional, Union - -from mypy.checkmember import AttributeContext -from mypy.nodes import TypeInfo -from mypy.types import AnyType, Instance, Type, TypeOfAny, UnionType - -from mypy_django_plugin.lib import fullnames, helpers - - -def _extract_referred_to_type_info(typ: Union[UnionType, Instance]) -> Optional[TypeInfo]: - if isinstance(typ, Instance): - return typ.type - else: - # should be Union[TYPE, None] - typ = helpers.make_required(typ) - if isinstance(typ, Instance): - return typ.type - return None - - -def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: AttributeContext) -> Type: - if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'): - return ctx.default_attr_type - - if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(fullnames.MODEL_CLASS_FULLNAME): - return ctx.default_attr_type - - field_name = ctx.context.name.split('_')[0] - sym = ctx.type.type.get(field_name) - if sym and isinstance(sym.type, Instance) and len(sym.type.args) > 0: - referred_to = sym.type.args[1] - if isinstance(referred_to, AnyType): - return AnyType(TypeOfAny.implementation_artifact) - - model_type = _extract_referred_to_type_info(referred_to) - if model_type is None: - return AnyType(TypeOfAny.implementation_artifact) - - primary_key_type = helpers.extract_primary_key_type_for_get(model_type) - if primary_key_type: - return primary_key_type - - is_nullable = helpers.is_field_nullable(ctx.type.type, field_name) - if is_nullable: - return helpers.make_optional(ctx.default_attr_type) - - return ctx.default_attr_type - - -def determine_type_of_related_manager(ctx: AttributeContext, related_manager_name: str) -> Type: - if not isinstance(ctx.type, Instance): - return ctx.default_attr_type - - related_manager_type = helpers.get_related_manager_type_from_metadata(ctx.type.type, - related_manager_name, ctx.api) - if not related_manager_type: - return ctx.default_attr_type - - return related_manager_type diff --git a/mypy_django_plugin/transformers/settings.py b/mypy_django_plugin/transformers/settings.py deleted file mode 100644 index 750d561..0000000 --- a/mypy_django_plugin/transformers/settings.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import TYPE_CHECKING, List, Optional, cast - -from mypy.checkexpr import FunctionContext -from mypy.checkmember import AttributeContext -from mypy.nodes import NameExpr, StrExpr, SymbolTableNode, TypeInfo -from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType - -from mypy_django_plugin.lib import helpers - -if TYPE_CHECKING: - from mypy.checker import TypeChecker - - -def get_setting_sym(name: str, api: 'TypeChecker', settings_modules: List[str]) -> Optional[SymbolTableNode]: - for settings_mod_name in settings_modules: - if settings_mod_name not in api.modules: - continue - - file = api.modules[settings_mod_name] - sym = file.names.get(name) - if sym is not None: - return sym - - return None - - -def get_type_of_setting(ctx: AttributeContext, setting_name: str, - settings_modules: List[str], ignore_missing_settings: bool) -> Type: - setting_sym = get_setting_sym(setting_name, ctx.api, settings_modules) - if setting_sym: - if setting_sym.type is None: - # TODO: defer till setting_sym.type is not None - return AnyType(TypeOfAny.implementation_artifact) - - return setting_sym.type - - if not ignore_missing_settings: - ctx.api.fail(f"'Settings' object has no attribute {setting_name!r}", ctx.context) - - return ctx.default_attr_type - - -def return_user_model_hook(ctx: FunctionContext, settings_modules: List[str]) -> Type: - from mypy.checker import TypeChecker - - api = cast(TypeChecker, ctx.api) - - setting_sym = get_setting_sym('AUTH_USER_MODEL', api, settings_modules) - if setting_sym is None: - return ctx.default_return_type - - setting_module_name, _, _ = setting_sym.fullname.rpartition('.') - setting_module = api.modules[setting_module_name] - - model_path = None - for name_expr, rvalue_expr in helpers.iter_over_assignments_in_class(setting_module): - if isinstance(name_expr, NameExpr) and isinstance(rvalue_expr, StrExpr): - if name_expr.name == 'AUTH_USER_MODEL': - model_path = rvalue_expr.value - break - - if not model_path: - return ctx.default_return_type - - app_label, _, model_class_name = model_path.rpartition('.') - if app_label is None: - return ctx.default_return_type - - model_fullname = helpers.get_model_fullname(app_label, model_class_name, - all_modules=api.modules) - if model_fullname is None: - api.fail(f'"{app_label}.{model_class_name}" model class is not imported so far. Try to import it ' - f'(under if TYPE_CHECKING) at the beginning of the current file', - context=ctx.context) - return ctx.default_return_type - - model_info = helpers.lookup_fully_qualified_generic(model_fullname, - all_modules=api.modules) - if model_info is None or not isinstance(model_info, TypeInfo): - return ctx.default_return_type - return TypeType(Instance(model_info, []))