From ab73d53ae52214d1dd09d41fc9b3d3de20b0ad2b Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Tue, 19 Feb 2019 00:43:27 +0300 Subject: [PATCH] add support for models defined in the same module be specified as name of class in related fields --- mypy_django_plugin/helpers.py | 8 ++++---- mypy_django_plugin/transformers/fields.py | 17 +++++------------ mypy_django_plugin/transformers/models.py | 9 ++++----- test-data/typecheck/related_fields.test | 12 +++++++++++- 4 files changed, 24 insertions(+), 22 deletions(-) diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index 3c4f4b8..035ad5d 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -54,9 +54,9 @@ def get_model_fullname(app_name: str, model_name: str, return None -class InvalidModelString(ValueError): - def __init__(self, model_string: str): - self.model_string = model_string +class SameFileModel(Exception): + def __init__(self, model_cls_name: str): + self.model_cls_name = model_cls_name class SelfReference(ValueError): @@ -69,7 +69,7 @@ def get_model_fullname_from_string(model_string: str, raise SelfReference() if '.' not in model_string: - raise InvalidModelString(model_string) + raise SameFileModel(model_string) app_name, model_name = model_string.split('.') return get_model_fullname(app_name, model_name, all_modules) diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index 83ae17d..cefeff8 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -8,7 +8,7 @@ from mypy_django_plugin import helpers from mypy_django_plugin.transformers.models import iter_over_assignments -def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: +def extract_referred_to_type(ctx: FunctionContext) -> Optional[Instance]: api = cast(TypeChecker, ctx.api) if 'to' not in ctx.callee_arg_names: api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}', @@ -27,6 +27,9 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: except helpers.SelfReference: model_fullname = api.tscope.classes[-1].fullname() + except helpers.SameFileModel as exc: + model_fullname = api.tscope.classes[-1].module_name + '.' + exc.model_cls_name + if model_fullname is None: return None model_info = helpers.lookup_fully_qualified_generic(model_fullname, @@ -69,19 +72,9 @@ def convert_any_to_type(typ: Type, referred_to_type: Type) -> Type: return typ -def _extract_referred_to_type(ctx: FunctionContext) -> Optional[Type]: - try: - referred_to_type = get_valid_to_value_or_none(ctx) - except helpers.InvalidModelString as exc: - ctx.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', ctx.context) - return None - - return referred_to_type - - def fill_descriptor_types_for_related_field(ctx: FunctionContext) -> Type: default_return_type = set_descriptor_types_for_field(ctx) - referred_to_type = _extract_referred_to_type(ctx) + referred_to_type = extract_referred_to_type(ctx) if referred_to_type is None: return default_return_type diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 8f5ce72..b56e3f8 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod from typing import Dict, Iterator, List, Optional, Tuple, cast import dataclasses -from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, CallExpr, ClassDef, Context, Expression, IndexExpr, \ +from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, CallExpr, ClassDef, Expression, IndexExpr, \ Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var from mypy.plugin import ClassDefContext from mypy.plugins.common import add_method @@ -173,10 +173,9 @@ class AddRelatedManagers(ModelClassInitializer): all_modules=self.api.modules) except helpers.SelfReference: ref_to_fullname = defn.fullname - except helpers.InvalidModelString as exc: - self.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', - Context(line=rvalue.line)) - return None + + except helpers.SameFileModel as exc: + ref_to_fullname = module_name + '.' + exc.model_cls_name if self.model_classdef.fullname == ref_to_fullname: related_manager_name = defn.name.lower() + '_set' diff --git a/test-data/typecheck/related_fields.test b/test-data/typecheck/related_fields.test index 0de483a..7eb21ca 100644 --- a/test-data/typecheck/related_fields.test +++ b/test-data/typecheck/related_fields.test @@ -274,4 +274,14 @@ Book2(publisher_id=1) Book2(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal[''], None]") Book2.objects.create(publisher_id=1) Book2.objects.create(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]") -[out] \ No newline at end of file +[out] + +[CASE if_model_is_defined_as_name_of_the_class_look_for_it_in_the_same_file] +from django.db import models + +class Book(models.Model): + publisher = models.ForeignKey(to='Publisher', on_delete=models.CASCADE) +class Publisher(models.Model): + pass +reveal_type(Book().publisher) # E: Revealed type is 'main.Publisher*' +[out]