diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index 7454b11..00cba52 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -50,9 +50,15 @@ class InvalidModelString(ValueError): self.model_string = model_string -def get_model_fullname_from_string(expr: StrExpr, +class SelfReference(ValueError): + pass + + +def get_model_fullname_from_string(model_string: str, all_modules: Dict[str, MypyFile]) -> Optional[str]: - model_string = expr.value + if model_string == 'self': + raise SelfReference() + if '.' not in model_string: raise InvalidModelString(model_string) diff --git a/mypy_django_plugin/plugins/models.py b/mypy_django_plugin/plugins/models.py index 438e006..547d58e 100644 --- a/mypy_django_plugin/plugins/models.py +++ b/mypy_django_plugin/plugins/models.py @@ -132,6 +132,8 @@ class AddRelatedManagers(ModelClassInitializer): ref_to_fullname = extract_ref_to_fullname(rvalue, module_file=module_file, all_modules=self.api.modules) + except helpers.SelfReference: + ref_to_fullname = defn.fullname except helpers.InvalidModelString as exc: self.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', Context(line=rvalue.line)) @@ -183,7 +185,7 @@ def extract_ref_to_fullname(rvalue_expr: CallExpr, if isinstance(to_expr, NameExpr): return module_file.names[to_expr.name].fullname elif isinstance(to_expr, StrExpr): - typ_fullname = helpers.get_model_fullname_from_string(to_expr, all_modules) + typ_fullname = helpers.get_model_fullname_from_string(to_expr.value, all_modules) if typ_fullname is None: return None return typ_fullname diff --git a/mypy_django_plugin/plugins/related_fields.py b/mypy_django_plugin/plugins/related_fields.py index afdcf85..5709b41 100644 --- a/mypy_django_plugin/plugins/related_fields.py +++ b/mypy_django_plugin/plugins/related_fields.py @@ -2,7 +2,7 @@ import typing from typing import Optional, cast from mypy.checker import TypeChecker -from mypy.nodes import StrExpr, TypeInfo, Context +from mypy.nodes import StrExpr, TypeInfo from mypy.plugin import FunctionContext from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny @@ -31,8 +31,12 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: if not isinstance(to_arg_expr, StrExpr): # not string, not supported return None - model_fullname = helpers.get_model_fullname_from_string(to_arg_expr, - all_modules=api.modules) + try: + model_fullname = helpers.get_model_fullname_from_string(to_arg_expr.value, + all_modules=api.modules) + except helpers.SelfReference: + model_fullname = api.tscope.classes[-1].fullname() + if model_fullname is None: return None model_info = helpers.lookup_fully_qualified_generic(model_fullname, diff --git a/test-data/typecheck/related_fields.test b/test-data/typecheck/related_fields.test index 40b4112..237415b 100644 --- a/test-data/typecheck/related_fields.test +++ b/test-data/typecheck/related_fields.test @@ -213,4 +213,18 @@ reveal_type(Member().apps) # E: Revealed type is 'django.db.models.manager.Rela from django.db import models class App(models.Model): pass -[out] \ No newline at end of file +[out] + +[CASE foreign_key_with_self] +from django.db import models +class User(models.Model): + parent = models.ForeignKey('self', on_delete=models.CASCADE) +reveal_type(User().parent) # E: Revealed type is 'main.User*' +[out] + +[CASE many_to_many_with_self] +from django.db import models +class User(models.Model): + friends = models.ManyToManyField('self', on_delete=models.CASCADE) +reveal_type(User().friends) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.User*]' +[out]