diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index 0f003e1..693ad49 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -1,7 +1,7 @@ import typing from typing import Dict, Optional, NamedTuple -from mypy.nodes import SymbolTableNode, Var, Expression +from mypy.nodes import SymbolTableNode, Var, Expression, StrExpr, MypyFile, TypeInfo from mypy.plugin import FunctionContext from mypy.types import Type, Instance, UnionType, NoneTyp @@ -67,3 +67,23 @@ def make_required(typ: Type) -> Type: def get_obj_type_name(typ: typing.Type) -> str: return typ.__module__ + '.' + typ.__qualname__ + + +def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]: + models_module = '.'.join([app_name, 'models']) + return all_modules.get(models_module) + + +def get_model_type_from_string(expr: StrExpr, + all_modules: Dict[str, MypyFile]) -> Optional[TypeInfo]: + app_name, model_name = expr.value.split('.') + + models_file = get_models_file(app_name, all_modules) + if models_file is None: + # not imported so far, not supported + return None + sym = models_file.names.get(model_name) + if not sym or not isinstance(sym.node, TypeInfo): + # no such model found in the app / node is not a class definition + return None + return sym.node diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index e1fe731..88151a1 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -1,27 +1,81 @@ import os -from typing import Callable, Optional +from typing import Callable, Optional, cast +from mypy.nodes import AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr from mypy.options import Options from mypy.plugin import Plugin, FunctionContext, ClassDefContext -from mypy.types import Type +from mypy.semanal import SemanticAnalyzerPass2 +from mypy.types import Type, Instance from mypy_django_plugin import helpers, monkeypatch +from mypy_django_plugin.plugins.meta_inner_class import inject_any_as_base_for_nested_class_meta from mypy_django_plugin.plugins.objects_queryset import set_objects_queryset_to_model_class -from mypy_django_plugin.plugins.postgres_fields import determine_type_of_array_field -from mypy_django_plugin.plugins.related_fields import OneToOneFieldHook, \ - ForeignKeyHook, set_fieldname_attrs_for_related_fields +from mypy_django_plugin.plugins.fields import determine_type_of_array_field, \ + add_int_id_attribute_if_primary_key_true_is_not_present +from mypy_django_plugin.plugins.related_fields import set_fieldname_attrs_for_related_fields, add_new_var_node_to_class, \ + extract_to_parameter_as_get_ret_type from mypy_django_plugin.plugins.setup_settings import DjangoConfSettingsInitializerHook base_model_classes = {helpers.MODEL_CLASS_FULLNAME} +def add_related_managers_from_referred_foreign_keys_to_model(ctx: ClassDefContext) -> None: + api = cast(SemanticAnalyzerPass2, ctx.api) + for stmt in ctx.cls.defs.body: + if not isinstance(stmt, AssignmentStmt): + continue + if len(stmt.lvalues) > 1: + # not supported yet + continue + rvalue = stmt.rvalue + if not isinstance(rvalue, CallExpr): + continue + if (not isinstance(rvalue.callee, MemberExpr) + or not rvalue.callee.fullname in {helpers.FOREIGN_KEY_FULLNAME, + helpers.ONETOONE_FIELD_FULLNAME}): + continue + if 'related_name' not in rvalue.arg_names: + # positional related_name is not supported yet + continue + related_name = rvalue.args[rvalue.arg_names.index('related_name')].value + + if 'to' in rvalue.arg_names: + expr = rvalue.args[rvalue.arg_names.index('to')] + else: + # first positional argument + expr = rvalue.args[0] + + if isinstance(expr, StrExpr): + model_typeinfo = helpers.get_model_type_from_string(expr, + all_modules=api.modules) + if model_typeinfo is None: + continue + elif isinstance(expr, NameExpr): + model_typeinfo = expr.node + else: + continue + + if rvalue.callee.fullname == helpers.FOREIGN_KEY_FULLNAME: + typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, + args=[Instance(ctx.cls.info, [])]) + else: + typ = Instance(ctx.cls.info, []) + + if typ is None: + continue + add_new_var_node_to_class(model_typeinfo, related_name, typ) + + class TransformModelClassHook(object): def __call__(self, ctx: ClassDefContext) -> None: base_model_classes.add(ctx.cls.fullname) set_fieldname_attrs_for_related_fields(ctx) set_objects_queryset_to_model_class(ctx) + inject_any_as_base_for_nested_class_meta(ctx) + add_related_managers_from_referred_foreign_keys_to_model(ctx) + add_int_id_attribute_if_primary_key_true_is_not_present(ctx) class DjangoPlugin(Plugin): @@ -29,22 +83,25 @@ class DjangoPlugin(Plugin): options: Options) -> None: super().__init__(options) monkeypatch.replace_apply_function_plugin_method() + monkeypatch.make_inner_classes_with_inherit_from_any_compatible_with_each_other() self.django_settings = os.environ.get('DJANGO_SETTINGS_MODULE') if self.django_settings: monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(self.django_settings) monkeypatch.inject_dependencies(self.django_settings) + # monkeypatch.process_settings_before_dependants(self.django_settings) else: monkeypatch.restore_original_load_graph() monkeypatch.restore_original_dependencies_handling() def get_function_hook(self, fullname: str ) -> Optional[Callable[[FunctionContext], Type]]: - if fullname == helpers.FOREIGN_KEY_FULLNAME: - return ForeignKeyHook(settings=self.django_settings) + if fullname in {helpers.FOREIGN_KEY_FULLNAME, + helpers.ONETOONE_FIELD_FULLNAME}: + return extract_to_parameter_as_get_ret_type - if fullname == helpers.ONETOONE_FIELD_FULLNAME: - return OneToOneFieldHook(settings=self.django_settings) + # if fullname == helpers.ONETOONE_FIELD_FULLNAME: + # return OneToOneFieldHook(settings=self.django_settings) if fullname == 'django.contrib.postgres.fields.array.ArrayField': return determine_type_of_array_field diff --git a/mypy_django_plugin/monkeypatch/__init__.py b/mypy_django_plugin/monkeypatch/__init__.py index 60c7ea4..ffeb914 100644 --- a/mypy_django_plugin/monkeypatch/__init__.py +++ b/mypy_django_plugin/monkeypatch/__init__.py @@ -1,5 +1,7 @@ from .dependencies import (load_graph_to_add_settings_file_as_a_source_seed, inject_dependencies, restore_original_load_graph, - restore_original_dependencies_handling) -from .contexts import replace_apply_function_plugin_method \ No newline at end of file + restore_original_dependencies_handling, + process_settings_before_dependants) +from .contexts import replace_apply_function_plugin_method +from .multiple_inheritance import make_inner_classes_with_inherit_from_any_compatible_with_each_other \ No newline at end of file diff --git a/mypy_django_plugin/monkeypatch/dependencies.py b/mypy_django_plugin/monkeypatch/dependencies.py index 5044e8c..3a169d6 100644 --- a/mypy_django_plugin/monkeypatch/dependencies.py +++ b/mypy_django_plugin/monkeypatch/dependencies.py @@ -1,6 +1,6 @@ -from typing import List, Optional +from typing import List, Optional, AbstractSet, MutableSet, Set -from mypy.build import BuildManager, Graph, State +from mypy.build import BuildManager, Graph, State, PRI_ALL from mypy.modulefinder import BuildSource @@ -12,6 +12,7 @@ from mypy import build old_load_graph = build.load_graph OldState = build.State +old_sorted_components = build.sorted_components def load_graph_to_add_settings_file_as_a_source_seed(settings_module: str): @@ -50,3 +51,40 @@ def restore_original_dependencies_handling(): from mypy import build build.State = OldState + + +def _extract_dependencies(graph: Graph, state_id: str, visited_modules: Set[str]) -> Set[str]: + visited_modules.add(state_id) + dependencies = set(graph[state_id].dependencies) + for new_dep_id in dependencies.copy(): + if new_dep_id not in visited_modules: + dependencies.update(_extract_dependencies(graph, new_dep_id, visited_modules)) + return dependencies + + +def extract_module_dependencies(graph: Graph, state_id: str) -> Set[str]: + visited_modules = set() + return _extract_dependencies(graph, state_id, visited_modules=visited_modules) + + +def process_settings_before_dependants(settings_module: str): + def patched_sorted_components(graph: Graph, + vertices: Optional[AbstractSet[str]] = None, + pri_max: int = PRI_ALL) -> List[AbstractSet[str]]: + sccs = old_sorted_components(graph, + vertices=vertices, + pri_max=pri_max) + for i, scc in enumerate(sccs.copy()): + if 'django.conf' in scc: + django_conf_deps = set(extract_module_dependencies(graph, 'django.conf')).union({'django.conf'}) + old_scc_modified = scc.difference(django_conf_deps) + new_scc = scc.difference(old_scc_modified) + if not old_scc_modified: + # already processed + break + sccs[i] = frozenset(old_scc_modified) + sccs.insert(i, frozenset(new_scc)) + break + return sccs + + build.sorted_components = patched_sorted_components diff --git a/mypy_django_plugin/monkeypatch/multiple_inheritance.py b/mypy_django_plugin/monkeypatch/multiple_inheritance.py new file mode 100644 index 0000000..6bb9d7e --- /dev/null +++ b/mypy_django_plugin/monkeypatch/multiple_inheritance.py @@ -0,0 +1,74 @@ +from typing import Optional + +from mypy.checkmember import bind_self, is_final_node, type_object_type +from mypy.nodes import TypeInfo, Context, SymbolTableNode, FuncBase +from mypy.subtypes import is_subtype, is_equivalent +from mypy.types import FunctionLike, CallableType, Type + + +def make_inner_classes_with_inherit_from_any_compatible_with_each_other(): + from mypy.checker import TypeChecker + + def determine_type_of_class_member(self, sym: SymbolTableNode) -> Optional[Type]: + if sym.type is not None: + return sym.type + if isinstance(sym.node, FuncBase): + return self.function_type(sym.node) + if isinstance(sym.node, TypeInfo): + # nested class + return type_object_type(sym.node, self.named_type) + return None + + TypeChecker.determine_type_of_class_member = determine_type_of_class_member + + def check_compatibility(self, name: str, base1: TypeInfo, + base2: TypeInfo, ctx: Context) -> None: + """Check if attribute name in base1 is compatible with base2 in multiple inheritance. + Assume base1 comes before base2 in the MRO, and that base1 and base2 don't have + a direct subclass relationship (i.e., the compatibility requirement only derives from + multiple inheritance). + """ + if name in ('__init__', '__new__', '__init_subclass__'): + # __init__ and friends can be incompatible -- it's a special case. + return + first = base1[name] + second = base2[name] + first_type = self.determine_type_of_class_member(first) + second_type = self.determine_type_of_class_member(second) + + # TODO: What if some classes are generic? + if (isinstance(first_type, FunctionLike) and + isinstance(second_type, FunctionLike)): + if ((isinstance(first_type, CallableType) + and first_type.fallback.type.fullname() == 'builtins.type') + and (isinstance(second_type, CallableType) + and second_type.fallback.type.fullname() == 'builtins.type')): + # Both members are classes (not necessary nested), check if compatible + ok = is_subtype(first_type.ret_type, second_type.ret_type) + else: + # Method override + first_sig = bind_self(first_type) + second_sig = bind_self(second_type) + ok = is_subtype(first_sig, second_sig, ignore_pos_arg_names=True) + elif first_type and second_type: + ok = is_equivalent(first_type, second_type) + else: + if first_type is None: + self.msg.cannot_determine_type_in_base(name, base1.name(), ctx) + if second_type is None: + self.msg.cannot_determine_type_in_base(name, base2.name(), ctx) + ok = True + # Final attributes can never be overridden, but can override + # non-final read-only attributes. + if is_final_node(second.node): + self.msg.cant_override_final(name, base2.name(), ctx) + if is_final_node(first.node): + self.check_no_writable(name, second.node, ctx) + # __slots__ is special and the type can vary across class hierarchy. + if name == '__slots__': + ok = True + if not ok: + self.msg.base_class_definitions_incompatible(name, base1, base2, + ctx) + + TypeChecker.check_compatibility = check_compatibility diff --git a/mypy_django_plugin/plugins/fields.py b/mypy_django_plugin/plugins/fields.py new file mode 100644 index 0000000..836d86f --- /dev/null +++ b/mypy_django_plugin/plugins/fields.py @@ -0,0 +1,41 @@ +from typing import Iterator, List, cast + +from mypy.nodes import ClassDef, AssignmentStmt, CallExpr +from mypy.plugin import FunctionContext, ClassDefContext +from mypy.semanal import SemanticAnalyzerPass2 +from mypy.types import Type, Instance + +from mypy_django_plugin.plugins.related_fields import add_new_var_node_to_class + + +def determine_type_of_array_field(ctx: FunctionContext) -> Type: + if 'base_field' not in ctx.arg_names: + return ctx.default_return_type + + base_field_arg_type = ctx.arg_types[ctx.arg_names.index('base_field')][0] + return ctx.api.named_generic_type(ctx.context.callee.fullname, + args=[base_field_arg_type.type.names['__get__'].type.ret_type]) + + +def get_assignments(klass: ClassDef) -> List[AssignmentStmt]: + stmts = [] + for stmt in klass.defs.body: + if not isinstance(stmt, AssignmentStmt): + continue + if len(stmt.lvalues) > 1: + # not supported yet + continue + stmts.append(stmt) + return stmts + + +def add_int_id_attribute_if_primary_key_true_is_not_present(ctx: ClassDefContext) -> None: + api = cast(SemanticAnalyzerPass2, ctx.api) + for stmt in get_assignments(ctx.cls): + if (isinstance(stmt.rvalue, CallExpr) + and 'primary_key' in stmt.rvalue.arg_names + and api.parse_bool(stmt.rvalue.args[stmt.rvalue.arg_names.index('primary_key')])): + break + else: + add_new_var_node_to_class(ctx.cls.info, 'id', api.builtin_type('builtins.int')) + diff --git a/mypy_django_plugin/plugins/meta_inner_class.py b/mypy_django_plugin/plugins/meta_inner_class.py new file mode 100644 index 0000000..b37b159 --- /dev/null +++ b/mypy_django_plugin/plugins/meta_inner_class.py @@ -0,0 +1,12 @@ +from mypy.nodes import TypeInfo +from mypy.plugin import ClassDefContext + + +def inject_any_as_base_for_nested_class_meta(ctx: ClassDefContext) -> None: + if 'Meta' not in ctx.cls.info.names: + return None + sym = ctx.cls.info.names['Meta'] + if not isinstance(sym.node, TypeInfo): + return None + + sym.node.fallback_to_any = True diff --git a/mypy_django_plugin/plugins/postgres_fields.py b/mypy_django_plugin/plugins/postgres_fields.py deleted file mode 100644 index 25e4171..0000000 --- a/mypy_django_plugin/plugins/postgres_fields.py +++ /dev/null @@ -1,11 +0,0 @@ -from mypy.plugin import FunctionContext -from mypy.types import Type - - -def determine_type_of_array_field(ctx: FunctionContext) -> Type: - if 'base_field' not in ctx.arg_names: - return ctx.default_return_type - - base_field_arg_type = ctx.arg_types[ctx.arg_names.index('base_field')][0] - return ctx.api.named_generic_type(ctx.context.callee.fullname, - args=[base_field_arg_type.type.names['__get__'].type.ret_type]) diff --git a/mypy_django_plugin/plugins/related_fields.py b/mypy_django_plugin/plugins/related_fields.py index 7458c2b..aac0791 100644 --- a/mypy_django_plugin/plugins/related_fields.py +++ b/mypy_django_plugin/plugins/related_fields.py @@ -3,15 +3,16 @@ from typing import Optional, cast from django.conf import Settings from mypy.checker import TypeChecker -from mypy.nodes import SymbolTable, MDEF, AssignmentStmt +from mypy.nodes import MDEF, AssignmentStmt, MypyFile, StrExpr, TypeInfo, NameExpr, Var, SymbolTableNode from mypy.plugin import FunctionContext, ClassDefContext from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny from mypy_django_plugin import helpers +from mypy_django_plugin.helpers import get_models_file def extract_related_name_value(ctx: FunctionContext) -> str: - return ctx.context.args[ctx.arg_names.index('related_name')].value + return ctx.args[ctx.arg_names.index('related_name')][0].value def reparametrize_with(instance: Instance, new_typevars: typing.List[Type]): @@ -31,8 +32,15 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: arg_type = ctx.arg_types[ctx.arg_names.index('to')][0] if not isinstance(arg_type, CallableType): - # to= defined as string is not supported - return None + to_arg_expr = ctx.args[ctx.arg_names.index('to')][0] + if not isinstance(to_arg_expr, StrExpr): + # not string, not supported + return None + model_info = helpers.get_model_type_from_string(to_arg_expr, + all_modules=cast(TypeChecker, ctx.api).modules) + if model_info is None: + return None + return Instance(model_info, []) referred_to_type = arg_type.ret_type for base in referred_to_type.type.bases: @@ -47,59 +55,38 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: return referred_to_type -class ForeignKeyHook(object): - def __init__(self, settings: Settings): - self.settings = settings - - def __call__(self, ctx: FunctionContext) -> Type: - api = cast(TypeChecker, ctx.api) - outer_class_info = api.tscope.classes[-1] - - referred_to_type = get_valid_to_value_or_none(ctx) - if referred_to_type is None: - return fill_typevars_with_any(ctx.default_return_type) - - if 'related_name' in ctx.arg_names: - related_name = extract_related_name_value(ctx) - queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME, - args=[Instance(outer_class_info, [])]) - sym = helpers.create_new_symtable_node(related_name, MDEF, - instance=queryset_type) - referred_to_type.type.names[related_name] = sym - - return reparametrize_with(ctx.default_return_type, [referred_to_type]) +def add_new_var_node_to_class(class_type: TypeInfo, name: str, typ: Instance) -> None: + var = Var(name=name, type=typ) + var.info = typ.type + var._fullname = class_type.fullname() + '.' + name + var.is_inferred = True + var.is_initialized_in_class = True + class_type.names[name] = SymbolTableNode(MDEF, var) -class OneToOneFieldHook(object): - def __init__(self, settings: Optional[Settings]): - self.settings = settings - - def __call__(self, ctx: FunctionContext) -> Type: - api = cast(TypeChecker, ctx.api) - outer_class_info = api.tscope.classes[-1] - - referred_to_type = get_valid_to_value_or_none(ctx) - if referred_to_type is None: - return fill_typevars_with_any(ctx.default_return_type) - - if 'related_name' in ctx.arg_names: - related_name = extract_related_name_value(ctx) - sym = helpers.create_new_symtable_node(related_name, MDEF, - instance=Instance(outer_class_info, [])) - referred_to_type.type.names[related_name] = sym - - return reparametrize_with(ctx.default_return_type, [referred_to_type]) +def extract_to_parameter_as_get_ret_type(ctx: FunctionContext) -> Type: + referred_to_type = get_valid_to_value_or_none(ctx) + if referred_to_type is None: + # couldn't extract to= value + return fill_typevars_with_any(ctx.default_return_type) + return reparametrize_with(ctx.default_return_type, [referred_to_type]) def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None: api = ctx.api - - new_symtable_nodes = SymbolTable() - for (name, symtable_node), stmt in zip(ctx.cls.info.names.items(), ctx.cls.defs.body): + for stmt in ctx.cls.defs.body: if not isinstance(stmt, AssignmentStmt): continue if not hasattr(stmt.rvalue, 'callee'): continue + if len(stmt.lvalues) > 1: + # multiple lvalues not supported for now + continue + + expr = stmt.lvalues[0] + if not isinstance(expr, NameExpr): + continue + name = expr.name rvalue_callee = stmt.rvalue.callee if rvalue_callee.fullname in {helpers.FOREIGN_KEY_FULLNAME, @@ -108,7 +95,4 @@ def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None: new_node = helpers.create_new_symtable_node(name, kind=MDEF, instance=api.named_type('__builtins__.int')) - new_symtable_nodes[name] = new_node - - for name, node in new_symtable_nodes.items(): - ctx.cls.info.names[name] = node + ctx.cls.info.names[name] = new_node diff --git a/pytest.ini b/pytest.ini index 3f271e9..4291613 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,7 +4,5 @@ testpaths = ./test python_files = test*.py addopts = --tb=native - -v - -s --ignore=./external - --mypy-ini-file=./test/plugins.ini + --mypy-ini-file=./test/plugins.ini \ No newline at end of file diff --git a/test/test-data/model-fields.test b/test/test-data/model-fields.test index 6384a60..09dfa1f 100644 --- a/test/test-data/model-fields.test +++ b/test/test-data/model-fields.test @@ -14,3 +14,47 @@ reveal_type(user.small_int) # E: Revealed type is 'builtins.int' reveal_type(user.name) # E: Revealed type is 'builtins.str' reveal_type(user.slug) # E: Revealed type is 'builtins.str' reveal_type(user.text) # E: Revealed type is 'builtins.str' + +[CASE test_add_id_field_if_no_primary_key_defined] +from django.db import models + +class User(models.Model): + pass + +reveal_type(User().id) # E: Revealed type is 'builtins.int' + +[CASE test_do_not_add_id_if_field_with_primary_key_True_defined] +from django.db import models + +class User(models.Model): + my_pk = models.IntegerField(primary_key=True) + +reveal_type(User().my_pk) # E: Revealed type is 'builtins.int' +reveal_type(User().id) # E: Revealed type is 'Any' +[out] +main:7: error: "User" has no attribute "id" + +[CASE test_meta_nested_class_allows_subclassing_in_multiple_inheritance] +from typing import Any +from django.db import models + +class Mixin1(models.Model): + class Meta: + abstract = True + +class Mixin2(models.Model): + class Meta: + abstract = True + +class User(Mixin1, Mixin2): + pass +[out] + +[CASE test_inheritance_from_abstract_model_does_not_fail_if_field_with_id_exists] +from django.db import models +class Abstract(models.Model): + class Meta: + abstract = True +class User(Abstract): + id = models.AutoField(primary_key=True) +[out] diff --git a/test/test-data/model-relations.test b/test/test-data/model-relations.test index 40c0ced..abe9544 100644 --- a/test/test-data/model-relations.test +++ b/test/test-data/model-relations.test @@ -21,11 +21,14 @@ class Publisher(models.Model): pass class Book(models.Model): - publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE, - related_name='books') + publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) + class StylesheetError(Exception): + pass + owner = models.ForeignKey(db_column='model_id', to='db.Unknown', on_delete=models.CASCADE) book = Book() reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int' +reveal_type(book.owner_id) # E: Revealed type is 'builtins.int' [CASE test_foreign_key_field_different_order_of_params] from django.db import models @@ -36,18 +39,77 @@ class Publisher(models.Model): class Book(models.Model): publisher = models.ForeignKey(on_delete=models.CASCADE, to=Publisher, related_name='books') + publisher2 = models.ForeignKey(to=Publisher, related_name='books2', on_delete=models.CASCADE) book = Book() reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*' +reveal_type(book.publisher2) # E: Revealed type is 'main.Publisher*' publisher = Publisher() reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]' +reveal_type(publisher.books2) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]' -[CASE test_to_parameter_as_string_fallbacks_to_any] +[CASE test_to_parameter_as_string_with_application_name__model_imported] +from django.db import models +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from myapp.models import Publisher + +class Book(models.Model): + publisher = models.ForeignKey(to='myapp.Publisher', on_delete=models.CASCADE) + +book = Book() +reveal_type(book.publisher) # E: Revealed type is 'myapp.models.Publisher*' + +[file myapp/__init__.py] +[file myapp/models.py] +from django.db import models +class Publisher(models.Model): + pass + +[CASE test_to_parameter_as_string_with_application_name__fallbacks_to_any_if_model_not_present_in_dependency_graph] from django.db import models class Book(models.Model): - publisher = models.ForeignKey(to='Publisher', on_delete=models.CASCADE) + publisher = models.ForeignKey(to='myapp.Publisher', on_delete=models.CASCADE) book = Book() reveal_type(book.publisher) # E: Revealed type is 'Any' + +[file myapp/__init__.py] +[file myapp/models.py] +from django.db import models +class Publisher(models.Model): + pass + +[CASE test_circular_dependency_in_imports_with_foreign_key] +from django.db import models + +class App(models.Model): + def method(self) -> None: + reveal_type(self.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' + reveal_type(self.members) # E: Revealed type is 'django.db.models.query.QuerySet[main.Member]' + reveal_type(self.sheets) # E: Revealed type is 'django.db.models.query.QuerySet[main.Sheet]' + reveal_type(self.profile) # E: Revealed type is 'main.Profile' +class View(models.Model): + app = models.ForeignKey(to=App, related_name='views', on_delete=models.CASCADE) +class Member(models.Model): + app = models.ForeignKey(related_name='members', on_delete=models.CASCADE, to=App) +class Sheet(models.Model): + app = models.ForeignKey(App, related_name='sheets', on_delete=models.CASCADE) +class Profile(models.Model): + app = models.OneToOneField(App, related_name='profile', on_delete=models.CASCADE) + +[CASE test_circular_dependency_in_imports_with_string_based] +from django.db import models +from myapp.models import App +class View(models.Model): + app = models.ForeignKey(to='myapp.App', related_name='views', on_delete=models.CASCADE) + +[file myapp/__init__.py] +[file myapp/models.py] +from django.db import models +class App(models.Model): + def method(self) -> None: + reveal_type(self.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' diff --git a/test/test-data/parse-settings.test b/test/test-data/parse-settings.test index 0e0b54b..8270149 100644 --- a/test/test-data/parse-settings.test +++ b/test/test-data/parse-settings.test @@ -19,11 +19,36 @@ OBJ = LazyObject() [CASE test_settings_could_be_defined_in_different_module_and_imported_with_star] from django.conf import settings -reveal_type(settings.BASE) # E: Revealed type is 'builtins.int' -reveal_type(settings.NESTED) # E: Revealed type is 'builtins.str' +reveal_type(settings.ROOT_DIR) # E: Revealed type is 'pathlib.Path' +reveal_type(settings.SETUP) # E: Revealed type is 'builtins.int' +reveal_type(settings.DATABASES) # E: Revealed type is 'builtins.dict[builtins.str, builtins.str]' [env DJANGO_SETTINGS_MODULE=mysettings] [file mysettings.py] +from local import * +DATABASES = {'default': 'mydb'} +[file local.py] from base import * -NESTED = '1122' +SETUP = 3 [file base.py] -BASE = 1 +from pathlib import Path + +ROOT_DIR = Path(__file__) + +[CASE test_circular_dependency_in_settings] +from django.conf import settings + +class Class: + pass + +reveal_type(settings.MYSETTING) # E: Revealed type is 'builtins.int' +reveal_type(settings.REGISTRY) # E: Revealed type is 'Any' + +[file mysettings.py] +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .main import Class + +MYSETTING = 1122 +REGISTRY: Optional['Class'] = None +