add support for models defined in the same module be specified as name of class in related fields

This commit is contained in:
Maxim Kurnikov
2019-02-19 00:43:27 +03:00
parent d24be4b35f
commit ab73d53ae5
4 changed files with 24 additions and 22 deletions

View File

@@ -54,9 +54,9 @@ def get_model_fullname(app_name: str, model_name: str,
return None return None
class InvalidModelString(ValueError): class SameFileModel(Exception):
def __init__(self, model_string: str): def __init__(self, model_cls_name: str):
self.model_string = model_string self.model_cls_name = model_cls_name
class SelfReference(ValueError): class SelfReference(ValueError):
@@ -69,7 +69,7 @@ def get_model_fullname_from_string(model_string: str,
raise SelfReference() raise SelfReference()
if '.' not in model_string: if '.' not in model_string:
raise InvalidModelString(model_string) raise SameFileModel(model_string)
app_name, model_name = model_string.split('.') app_name, model_name = model_string.split('.')
return get_model_fullname(app_name, model_name, all_modules) return get_model_fullname(app_name, model_name, all_modules)

View File

@@ -8,7 +8,7 @@ from mypy_django_plugin import helpers
from mypy_django_plugin.transformers.models import iter_over_assignments from mypy_django_plugin.transformers.models import iter_over_assignments
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: def extract_referred_to_type(ctx: FunctionContext) -> Optional[Instance]:
api = cast(TypeChecker, ctx.api) api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.callee_arg_names: if 'to' not in ctx.callee_arg_names:
api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}', api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
@@ -27,6 +27,9 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
except helpers.SelfReference: except helpers.SelfReference:
model_fullname = api.tscope.classes[-1].fullname() model_fullname = api.tscope.classes[-1].fullname()
except helpers.SameFileModel as exc:
model_fullname = api.tscope.classes[-1].module_name + '.' + exc.model_cls_name
if model_fullname is None: if model_fullname is None:
return None return None
model_info = helpers.lookup_fully_qualified_generic(model_fullname, model_info = helpers.lookup_fully_qualified_generic(model_fullname,
@@ -69,19 +72,9 @@ def convert_any_to_type(typ: Type, referred_to_type: Type) -> Type:
return typ return typ
def _extract_referred_to_type(ctx: FunctionContext) -> Optional[Type]:
try:
referred_to_type = get_valid_to_value_or_none(ctx)
except helpers.InvalidModelString as exc:
ctx.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', ctx.context)
return None
return referred_to_type
def fill_descriptor_types_for_related_field(ctx: FunctionContext) -> Type: def fill_descriptor_types_for_related_field(ctx: FunctionContext) -> Type:
default_return_type = set_descriptor_types_for_field(ctx) default_return_type = set_descriptor_types_for_field(ctx)
referred_to_type = _extract_referred_to_type(ctx) referred_to_type = extract_referred_to_type(ctx)
if referred_to_type is None: if referred_to_type is None:
return default_return_type return default_return_type

View File

@@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod
from typing import Dict, Iterator, List, Optional, Tuple, cast from typing import Dict, Iterator, List, Optional, Tuple, cast
import dataclasses import dataclasses
from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, CallExpr, ClassDef, Context, Expression, IndexExpr, \ from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, CallExpr, ClassDef, Expression, IndexExpr, \
Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext from mypy.plugin import ClassDefContext
from mypy.plugins.common import add_method from mypy.plugins.common import add_method
@@ -173,10 +173,9 @@ class AddRelatedManagers(ModelClassInitializer):
all_modules=self.api.modules) all_modules=self.api.modules)
except helpers.SelfReference: except helpers.SelfReference:
ref_to_fullname = defn.fullname ref_to_fullname = defn.fullname
except helpers.InvalidModelString as exc:
self.api.fail(f'Invalid value for a to= parameter: {exc.model_string!r}', except helpers.SameFileModel as exc:
Context(line=rvalue.line)) ref_to_fullname = module_name + '.' + exc.model_cls_name
return None
if self.model_classdef.fullname == ref_to_fullname: if self.model_classdef.fullname == ref_to_fullname:
related_manager_name = defn.name.lower() + '_set' related_manager_name = defn.name.lower() + '_set'

View File

@@ -274,4 +274,14 @@ Book2(publisher_id=1)
Book2(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal[''], None]") Book2(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal[''], None]")
Book2.objects.create(publisher_id=1) Book2.objects.create(publisher_id=1)
Book2.objects.create(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]") Book2.objects.create(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]")
[out] [out]
[CASE if_model_is_defined_as_name_of_the_class_look_for_it_in_the_same_file]
from django.db import models
class Book(models.Model):
publisher = models.ForeignKey(to='Publisher', on_delete=models.CASCADE)
class Publisher(models.Model):
pass
reveal_type(Book().publisher) # E: Revealed type is 'main.Publisher*'
[out]