solve more use cases for related managers and settings

This commit is contained in:
Maxim Kurnikov
2018-12-03 01:57:46 +03:00
parent fcd659837e
commit 3676cb3ac0
9 changed files with 222 additions and 148 deletions

4
mypy.ini Normal file
View File

@@ -0,0 +1,4 @@
[mypy]
[mypy-mypy_django_plugin.monkeypatch.*]
ignore_errors = True

View File

@@ -1,9 +1,7 @@
import typing import typing
from typing import Dict, Optional, NamedTuple from typing import Dict, Optional
from mypy.nodes import Expression, StrExpr, MypyFile, TypeInfo from mypy.nodes import StrExpr, MypyFile, TypeInfo, ImportedName, SymbolNode
from mypy.plugin import FunctionContext
from mypy.types import Type, UnionType, NoneTyp
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet' QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
@@ -11,63 +9,14 @@ FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField' ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject' DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
Argument = NamedTuple('Argument', fields=[
('arg', Expression),
('arg_type', Type)
])
def get_call_signature_or_none(ctx: FunctionContext) -> Optional[Dict[str, Argument]]:
result: Dict[str, Argument] = {}
positional_args_only = []
positional_arg_types_only = []
for arg, arg_name, arg_type in zip(ctx.args, ctx.arg_names, ctx.arg_types):
if arg_name is None:
positional_args_only.append(arg)
positional_arg_types_only.append(arg_type)
continue
if len(arg) == 0 or len(arg_type) == 0:
continue
result[arg_name] = (arg[0], arg_type[0])
callee = ctx.context.callee
if '__init__' not in callee.node.names:
return None
init_type = callee.node.names['__init__'].type
arg_names = init_type.arg_names[1:]
for arg, arg_name, arg_type in zip(positional_args_only,
arg_names[:len(positional_args_only)],
positional_arg_types_only):
result[arg_name] = (arg[0], arg_type[0])
return result
def make_optional(typ: Type) -> Type:
return UnionType.make_simplified_union([typ, NoneTyp()])
def make_required(typ: Type) -> Type:
if not isinstance(typ, UnionType):
return typ
items = [item for item in typ.items if not isinstance(item, NoneTyp)]
return UnionType.make_union(items)
def get_obj_type_name(typ: typing.Type) -> str:
return typ.__module__ + '.' + typ.__qualname__
def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]: def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]:
models_module = '.'.join([app_name, 'models']) models_module = '.'.join([app_name, 'models'])
return all_modules.get(models_module) return all_modules.get(models_module)
def get_model_type_from_string(expr: StrExpr, def get_model_fullname_from_string(expr: StrExpr,
all_modules: Dict[str, MypyFile]) -> Optional[TypeInfo]: all_modules: Dict[str, MypyFile]) -> Optional[str]:
app_name, model_name = expr.value.split('.') app_name, model_name = expr.value.split('.')
models_file = get_models_file(app_name, all_modules) models_file = get_models_file(app_name, all_modules)
@@ -75,7 +24,26 @@ def get_model_type_from_string(expr: StrExpr,
# not imported so far, not supported # not imported so far, not supported
return None return None
sym = models_file.names.get(model_name) sym = models_file.names.get(model_name)
if not sym or not isinstance(sym.node, TypeInfo): if not sym:
# no such model found in the app / node is not a class definition return None
if isinstance(sym.node, TypeInfo):
return sym.node.fullname()
elif isinstance(sym.node, ImportedName):
return sym.node.target_fullname
else:
return None
def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]:
if '.' not in name:
return None
module, cls_name = name.rsplit('.', 1)
module_file = all_modules.get(module)
if module_file is None:
return None
sym = module_file.names.get(cls_name)
if sym is None:
return None return None
return sym.node return sym.node

View File

@@ -42,9 +42,6 @@ class DjangoPlugin(Plugin):
helpers.ONETOONE_FIELD_FULLNAME}: helpers.ONETOONE_FIELD_FULLNAME}:
return extract_to_parameter_as_get_ret_type_for_related_field return extract_to_parameter_as_get_ret_type_for_related_field
# if fullname == helpers.ONETOONE_FIELD_FULLNAME:
# return OneToOneFieldHook(settings=self.django_settings)
if fullname == 'django.contrib.postgres.fields.array.ArrayField': if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field return determine_type_of_array_field
return None return None

View File

@@ -1,7 +1,7 @@
from typing import cast, Iterator, Tuple, Optional from typing import cast, Iterator, Tuple, Optional, Dict
from mypy.nodes import ClassDef, AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr, MDEF, TypeInfo, Var, SymbolTableNode, \ from mypy.nodes import ClassDef, AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr, MDEF, TypeInfo, Var, SymbolTableNode, \
Lvalue, Expression, Statement Lvalue, Expression, MypyFile
from mypy.plugin import ClassDefContext from mypy.plugin import ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2 from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Instance from mypy.types import Instance
@@ -30,12 +30,11 @@ def iter_over_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, Expression]
def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]: def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]:
for lvalue, rvalue in iter_over_assignments(klass): for lvalue, rvalue in iter_over_assignments(klass):
if not isinstance(rvalue, CallExpr): if isinstance(rvalue, CallExpr):
continue yield lvalue, rvalue
yield lvalue, rvalue
def iter_over_one_to_n_related_fields(klass: ClassDef, api: SemanticAnalyzerPass2) -> Iterator[Tuple[NameExpr, CallExpr]]: def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExpr, CallExpr]]:
for lvalue, rvalue in iter_call_assignments(klass): for lvalue, rvalue in iter_call_assignments(klass):
if (isinstance(lvalue, NameExpr) if (isinstance(lvalue, NameExpr)
and isinstance(rvalue.callee, MemberExpr)): and isinstance(rvalue.callee, MemberExpr)):
@@ -67,7 +66,7 @@ def is_abstract_model(ctx: ClassDefContext) -> bool:
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None: def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
api = ctx.api api = ctx.api
for lvalue, rvalue in iter_over_one_to_n_related_fields(ctx.cls, api): for lvalue, rvalue in iter_over_one_to_n_related_fields(ctx.cls):
property_name = lvalue.name + '_id' property_name = lvalue.name + '_id'
add_new_var_node_to_class(ctx.cls.info, property_name, add_new_var_node_to_class(ctx.cls.info, property_name,
typ=api.named_type('__builtins__.int')) typ=api.named_type('__builtins__.int'))
@@ -112,68 +111,67 @@ def inject_any_as_base_for_nested_class_meta(ctx: ClassDefContext) -> None:
meta_node.fallback_to_any = True meta_node.fallback_to_any = True
def is_model_defn(defn: Statement, api: SemanticAnalyzerPass2) -> bool: def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]:
if not isinstance(defn, ClassDef): for defn in module_file.defs:
return False if isinstance(defn, ClassDef):
yield defn
for base_type_expr in defn.base_type_exprs:
# api.accept(base_type_expr)
fullname = getattr(base_type_expr, 'fullname', None)
if fullname == helpers.MODEL_CLASS_FULLNAME:
return True
return False
def iter_over_models(ctx: ClassDefContext) -> Iterator[ClassDef]:
for module_name, module_file in ctx.api.modules.items():
for defn in module_file.defs:
if is_model_defn(defn, api=cast(SemanticAnalyzerPass2, ctx.api)):
yield defn
def extract_to_value_or_none(field_expr: CallExpr, ctx: ClassDefContext) -> Optional[TypeInfo]:
if 'to' in field_expr.arg_names:
ref_expr = field_expr.args[field_expr.arg_names.index('to')]
else:
# first positional argument
ref_expr = field_expr.args[0]
if isinstance(ref_expr, StrExpr):
model_typeinfo = helpers.get_model_type_from_string(ref_expr,
all_modules=ctx.api.modules)
return model_typeinfo
elif isinstance(ref_expr, NameExpr):
return ref_expr.node
def get_related_field_type(rvalue: CallExpr, api: SemanticAnalyzerPass2, def get_related_field_type(rvalue: CallExpr, api: SemanticAnalyzerPass2,
related_model_typ: TypeInfo) -> Optional[Instance]: related_model_typ: TypeInfo) -> Optional[Instance]:
if rvalue.callee.fullname == helpers.FOREIGN_KEY_FULLNAME: if rvalue.callee.name == 'ForeignKey':
return api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, return api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME,
args=[Instance(related_model_typ, [])]) args=[Instance(related_model_typ, [])])
else: else:
return Instance(related_model_typ, []) return Instance(related_model_typ, [])
def add_related_managers(ctx: ClassDefContext) -> None: def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool:
for model_defn in iter_over_models(ctx): if isinstance(expr.callee, MemberExpr) and isinstance(expr.callee.expr, NameExpr):
for _, rvalue in iter_over_one_to_n_related_fields(model_defn, ctx.api): module = module_file.names[expr.callee.expr.name]
if 'related_name' not in rvalue.arg_names: if module.fullname == 'django.db.models' and expr.callee.name in {'ForeignKey', 'OneToOneField'}:
# positional related_name is not supported yet return True
return return False
related_name = rvalue.args[rvalue.arg_names.index('related_name')].value
ref_to_typ = extract_to_value_or_none(rvalue, ctx)
if ref_to_typ is not None: def extract_ref_to_fullname(rvalue_expr: CallExpr,
if ref_to_typ.fullname() == ctx.cls.info.fullname(): module_file: MypyFile, all_modules: Dict[str, MypyFile]) -> Optional[str]:
typ = get_related_field_type(rvalue, ctx.api, if 'to' in rvalue_expr.arg_names:
related_model_typ=model_defn.info) to_expr = rvalue_expr.args[rvalue_expr.arg_names.index('to')]
if typ is None: else:
return to_expr = rvalue_expr.args[0]
add_new_var_node_to_class(ctx.cls.info, related_name, typ) if isinstance(to_expr, NameExpr):
return module_file.names[to_expr.name].fullname
elif isinstance(to_expr, StrExpr):
typ_fullname = helpers.get_model_fullname_from_string(to_expr, all_modules)
if typ_fullname is None:
return None
return typ_fullname
return None
def add_related_managers(ctx: ClassDefContext):
api = cast(SemanticAnalyzerPass2, ctx.api)
for module_name, module_file in ctx.api.modules.items():
for defn in iter_over_classdefs(module_file):
for lvalue, rvalue in iter_call_assignments(defn):
if is_related_field(rvalue, module_file):
ref_to_fullname = extract_ref_to_fullname(rvalue, module_file=module_file,
all_modules=api.modules)
if ctx.cls.fullname == ref_to_fullname:
if 'related_name' in rvalue.arg_names:
related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')]
if not isinstance(related_name_expr, StrExpr):
return None
related_name = related_name_expr.value
typ = get_related_field_type(rvalue, api, defn.info)
if typ is None:
return None
add_new_var_node_to_class(ctx.cls.info, related_name, typ)
def process_model_class(ctx: ClassDefContext) -> None: def process_model_class(ctx: ClassDefContext) -> None:
# add_related_managers(ctx) add_related_managers(ctx)
inject_any_as_base_for_nested_class_meta(ctx) inject_any_as_base_for_nested_class_meta(ctx)
set_fieldname_attrs_for_related_fields(ctx) set_fieldname_attrs_for_related_fields(ctx)
add_int_id_attribute_if_primary_key_true_is_not_present(ctx) add_int_id_attribute_if_primary_key_true_is_not_present(ctx)

View File

@@ -2,7 +2,7 @@ import typing
from typing import Optional, cast from typing import Optional, cast
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.nodes import StrExpr from mypy.nodes import StrExpr, TypeInfo
from mypy.plugin import FunctionContext from mypy.plugin import FunctionContext
from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny
@@ -18,10 +18,11 @@ def fill_typevars_with_any(instance: Instance) -> Type:
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.arg_names: if 'to' not in ctx.arg_names:
# shouldn't happen, invalid code # shouldn't happen, invalid code
ctx.api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}', api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
context=ctx.context) context=ctx.context)
return None return None
arg_type = ctx.arg_types[ctx.arg_names.index('to')][0] arg_type = ctx.arg_types[ctx.arg_names.index('to')][0]
@@ -30,13 +31,19 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
if not isinstance(to_arg_expr, StrExpr): if not isinstance(to_arg_expr, StrExpr):
# not string, not supported # not string, not supported
return None return None
model_info = helpers.get_model_type_from_string(to_arg_expr, model_fullname = helpers.get_model_fullname_from_string(to_arg_expr,
all_modules=cast(TypeChecker, ctx.api).modules) all_modules=api.modules)
if model_info is None: if model_fullname is None:
return None
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
all_modules=api.modules)
if model_info is None or not isinstance(model_info, TypeInfo):
return None return None
return Instance(model_info, []) return Instance(model_info, [])
referred_to_type = arg_type.ret_type referred_to_type = arg_type.ret_type
if not isinstance(referred_to_type, Instance):
return None
for base in referred_to_type.type.bases: for base in referred_to_type.type.bases:
if base.type.fullname() == helpers.MODEL_CLASS_FULLNAME: if base.type.fullname() == helpers.MODEL_CLASS_FULLNAME:
break break

View File

@@ -1,12 +1,10 @@
from typing import cast, List from typing import cast, List, Optional
from mypy.nodes import Var, Context, SymbolNode, SymbolTableNode from mypy.nodes import Var, Context, SymbolNode, SymbolTableNode
from mypy.plugin import ClassDefContext from mypy.plugin import ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2 from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Instance, UnionType, NoneTyp, Type from mypy.types import Instance, UnionType, NoneTyp, Type
from mypy_django_plugin import helpers
def get_error_context(node: SymbolNode) -> Context: def get_error_context(node: SymbolNode) -> Context:
context = Context() context = Context()
@@ -18,10 +16,24 @@ def filter_out_nones(typ: UnionType) -> List[Type]:
return [item for item in typ.items if not isinstance(item, NoneTyp)] return [item for item in typ.items if not isinstance(item, NoneTyp)]
def copy_sym_of_instance(sym: SymbolTableNode) -> SymbolTableNode: def make_sym_copy_of_setting(sym: SymbolTableNode) -> Optional[SymbolTableNode]:
copied = sym.copy() if isinstance(sym.type, Instance):
copied.node.info = sym.type.type copied = sym.copy()
return copied copied.node.info = sym.type.type
return copied
elif isinstance(sym.type, UnionType):
instances = filter_out_nones(sym.type)
if len(instances) > 1:
# plain unions not supported yet
return None
typ = instances[0]
if isinstance(typ, Instance):
copied = sym.copy()
copied.node.info = typ.type
return copied
return None
else:
return None
def add_settings_to_django_conf_object(ctx: ClassDefContext, def add_settings_to_django_conf_object(ctx: ClassDefContext,
@@ -33,25 +45,24 @@ def add_settings_to_django_conf_object(ctx: ClassDefContext,
settings_file = api.modules[settings_module] settings_file = api.modules[settings_module]
for name, sym in settings_file.names.items(): for name, sym in settings_file.names.items():
if name.isupper() and isinstance(sym.node, Var): if name.isupper() and isinstance(sym.node, Var):
if isinstance(sym.type, Instance): if sym.type is not None:
copied = sym.copy() copied = make_sym_copy_of_setting(sym)
copied.node.info = sym.type.type if copied is None:
ctx.cls.info.names[name] = copied
elif isinstance(sym.type, UnionType):
instances = filter_out_nones(sym.type)
if len(instances) > 1:
# plain unions not supported yet
continue continue
typ = instances[0] ctx.cls.info.names[name] = copied
if isinstance(typ, Instance): else:
copied = sym.copy() context = Context()
copied.node.info = typ.type module, node_name = sym.node.fullname().rsplit('.', 1)
ctx.cls.info.names[name] = copied module_file = api.modules.get(module)
if module_file is None:
return None
context.set_line(sym.node)
api.msg.report(f"Need type annotation for '{sym.node.name()}'", context,
severity='error', file=module_file.path)
class DjangoConfSettingsInitializerHook(object): class DjangoConfSettingsInitializerHook(object):
def __init__(self, settings_module: str): def __init__(self, settings_module: Optional[str]):
self.settings_module = settings_module self.settings_module = settings_module
def __call__(self, ctx: ClassDefContext) -> None: def __call__(self, ctx: ClassDefContext) -> None:

View File

@@ -5,5 +5,6 @@ python_files = test*.py
addopts = addopts =
--tb=native --tb=native
--mypy-ini-file=./test-data/plugins.ini --mypy-ini-file=./test-data/plugins.ini
--mypy-no-cache
-s -s
-v -v

View File

@@ -119,7 +119,7 @@ class App(models.Model):
def method(self) -> None: def method(self) -> None:
reveal_type(self.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' reveal_type(self.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]'
[case models_related_managers_work_with_direct_model_inheritance_and_with_inheritance_from_other_model] [CASE models_related_managers_work_with_direct_model_inheritance_and_with_inheritance_from_other_model]
from django.db.models import Model from django.db.models import Model
from django.db import models from django.db import models
@@ -134,4 +134,62 @@ class View2(View):
reveal_type(App().views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' reveal_type(App().views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]'
reveal_type(App().views2) # E: Revealed type is 'django.db.models.query.QuerySet[main.View2]' reveal_type(App().views2) # E: Revealed type is 'django.db.models.query.QuerySet[main.View2]'
[out] [out]
[CASE models_imported_inside_init_file]
from django.db import models
from myapp.models import App
class View(models.Model):
app = models.ForeignKey(to='myapp.App', related_name='views', on_delete=models.CASCADE)
reveal_type(View().app.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]'
[file myapp/__init__.py]
[file myapp/models/__init__.py]
from .app import App
[file myapp/models/app.py]
from django.db import models
class App(models.Model):
pass
[CASE models_imported_inside_init_file_one_to_one_field]
from django.db import models
from myapp.models import User
class Profile(models.Model):
user = models.OneToOneField(to='myapp.User', related_name='profile', on_delete=models.CASCADE)
reveal_type(Profile().user.profile) # E: Revealed type is 'main.Profile'
[file myapp/__init__.py]
[file myapp/models/__init__.py]
from .user import User
[file myapp/models/user.py]
from django.db import models
class User(models.Model):
pass
[CASE models_triple_circular_reference]
from myapp.models import App
reveal_type(App().owner.profile) # E: Revealed type is 'myapp.models.profile.Profile'
[file myapp/__init__.py]
[file myapp/models/__init__.py]
from .user import User
from .profile import Profile
from .app import App
[file myapp/models/user.py]
from django.db import models
class User(models.Model):
pass
[file myapp/models/profile.py]
from django.db import models
from myapp.models import User
class Profile(models.Model):
user = models.OneToOneField(to='myapp.User', related_name='profile', on_delete=models.CASCADE)
[file myapp/models/app.py]
from django.db import models
class App(models.Model):
owner = models.ForeignKey(to='myapp.User', on_delete=models.CASCADE, related_name='apps')

View File

@@ -36,17 +36,47 @@ ROOT_DIR = Path(__file__)
from django.conf import settings from django.conf import settings
class Class: class Class:
pass pass
reveal_type(settings.MYSETTING) # E: Revealed type is 'builtins.int'
reveal_type(settings.REGISTRY) # E: Revealed type is 'Union[main.Class, None]' reveal_type(settings.REGISTRY) # E: Revealed type is 'Union[main.Class, None]'
reveal_type(settings.LIST) # E: Revealed type is 'Any'
reveal_type(settings.BASE_LIST) # E: Revealed type is 'Any'
[out]
main:5: error: "LazySettings" has no attribute "LIST"
main:6: error: "LazySettings" has no attribute "BASE_LIST"
mysettings:4: error: Need type annotation for 'LIST'
base:6: error: Need type annotation for 'BASE_LIST'
[env DJANGO_SETTINGS_MODULE=mysettings] [env DJANGO_SETTINGS_MODULE=mysettings]
[file mysettings.py] [file mysettings.py]
from typing import TYPE_CHECKING
from typing import Optional from typing import Optional
from base import *
LIST = ['1', '2']
[file base.py]
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from main import Class
REGISTRY: Optional['Class'] = None
BASE_LIST = ['3', '4']
[CASE test_circular_dependency_in_settings_works_if_settings_have_annotations]
from django.conf import settings
class Class:
pass
reveal_type(settings.MYSETTING) # E: Revealed type is 'builtins.int'
reveal_type(settings.REGISTRY) # E: Revealed type is 'Union[main.Class, None]'
reveal_type(settings.LIST) # E: Revealed type is 'builtins.list[builtins.str]'
[out]
[env DJANGO_SETTINGS_MODULE=mysettings]
[file mysettings.py]
from typing import TYPE_CHECKING, Optional, List
if TYPE_CHECKING: if TYPE_CHECKING:
from main import Class from main import Class
MYSETTING = 1122 MYSETTING = 1122
REGISTRY: Optional['Class'] = None REGISTRY: Optional['Class'] = None
LIST: List[str] = ['1', '2']