solve more use cases for related managers and settings

This commit is contained in:
Maxim Kurnikov
2018-12-03 01:57:46 +03:00
parent fcd659837e
commit 3676cb3ac0
9 changed files with 222 additions and 148 deletions

4
mypy.ini Normal file
View File

@@ -0,0 +1,4 @@
[mypy]
[mypy-mypy_django_plugin.monkeypatch.*]
ignore_errors = True

View File

@@ -1,9 +1,7 @@
import typing
from typing import Dict, Optional, NamedTuple
from typing import Dict, Optional
from mypy.nodes import Expression, StrExpr, MypyFile, TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import Type, UnionType, NoneTyp
from mypy.nodes import StrExpr, MypyFile, TypeInfo, ImportedName, SymbolNode
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
@@ -11,63 +9,14 @@ FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
Argument = NamedTuple('Argument', fields=[
('arg', Expression),
('arg_type', Type)
])
def get_call_signature_or_none(ctx: FunctionContext) -> Optional[Dict[str, Argument]]:
result: Dict[str, Argument] = {}
positional_args_only = []
positional_arg_types_only = []
for arg, arg_name, arg_type in zip(ctx.args, ctx.arg_names, ctx.arg_types):
if arg_name is None:
positional_args_only.append(arg)
positional_arg_types_only.append(arg_type)
continue
if len(arg) == 0 or len(arg_type) == 0:
continue
result[arg_name] = (arg[0], arg_type[0])
callee = ctx.context.callee
if '__init__' not in callee.node.names:
return None
init_type = callee.node.names['__init__'].type
arg_names = init_type.arg_names[1:]
for arg, arg_name, arg_type in zip(positional_args_only,
arg_names[:len(positional_args_only)],
positional_arg_types_only):
result[arg_name] = (arg[0], arg_type[0])
return result
def make_optional(typ: Type) -> Type:
return UnionType.make_simplified_union([typ, NoneTyp()])
def make_required(typ: Type) -> Type:
if not isinstance(typ, UnionType):
return typ
items = [item for item in typ.items if not isinstance(item, NoneTyp)]
return UnionType.make_union(items)
def get_obj_type_name(typ: typing.Type) -> str:
return typ.__module__ + '.' + typ.__qualname__
def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]:
models_module = '.'.join([app_name, 'models'])
return all_modules.get(models_module)
def get_model_type_from_string(expr: StrExpr,
all_modules: Dict[str, MypyFile]) -> Optional[TypeInfo]:
def get_model_fullname_from_string(expr: StrExpr,
all_modules: Dict[str, MypyFile]) -> Optional[str]:
app_name, model_name = expr.value.split('.')
models_file = get_models_file(app_name, all_modules)
@@ -75,7 +24,26 @@ def get_model_type_from_string(expr: StrExpr,
# not imported so far, not supported
return None
sym = models_file.names.get(model_name)
if not sym or not isinstance(sym.node, TypeInfo):
# no such model found in the app / node is not a class definition
if not sym:
return None
if isinstance(sym.node, TypeInfo):
return sym.node.fullname()
elif isinstance(sym.node, ImportedName):
return sym.node.target_fullname
else:
return None
def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]:
if '.' not in name:
return None
module, cls_name = name.rsplit('.', 1)
module_file = all_modules.get(module)
if module_file is None:
return None
sym = module_file.names.get(cls_name)
if sym is None:
return None
return sym.node

View File

@@ -42,9 +42,6 @@ class DjangoPlugin(Plugin):
helpers.ONETOONE_FIELD_FULLNAME}:
return extract_to_parameter_as_get_ret_type_for_related_field
# if fullname == helpers.ONETOONE_FIELD_FULLNAME:
# return OneToOneFieldHook(settings=self.django_settings)
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field
return None

View File

@@ -1,7 +1,7 @@
from typing import cast, Iterator, Tuple, Optional
from typing import cast, Iterator, Tuple, Optional, Dict
from mypy.nodes import ClassDef, AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr, MDEF, TypeInfo, Var, SymbolTableNode, \
Lvalue, Expression, Statement
Lvalue, Expression, MypyFile
from mypy.plugin import ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Instance
@@ -30,12 +30,11 @@ def iter_over_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, Expression]
def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]:
for lvalue, rvalue in iter_over_assignments(klass):
if not isinstance(rvalue, CallExpr):
continue
if isinstance(rvalue, CallExpr):
yield lvalue, rvalue
def iter_over_one_to_n_related_fields(klass: ClassDef, api: SemanticAnalyzerPass2) -> Iterator[Tuple[NameExpr, CallExpr]]:
def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExpr, CallExpr]]:
for lvalue, rvalue in iter_call_assignments(klass):
if (isinstance(lvalue, NameExpr)
and isinstance(rvalue.callee, MemberExpr)):
@@ -67,7 +66,7 @@ def is_abstract_model(ctx: ClassDefContext) -> bool:
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
api = ctx.api
for lvalue, rvalue in iter_over_one_to_n_related_fields(ctx.cls, api):
for lvalue, rvalue in iter_over_one_to_n_related_fields(ctx.cls):
property_name = lvalue.name + '_id'
add_new_var_node_to_class(ctx.cls.info, property_name,
typ=api.named_type('__builtins__.int'))
@@ -112,68 +111,67 @@ def inject_any_as_base_for_nested_class_meta(ctx: ClassDefContext) -> None:
meta_node.fallback_to_any = True
def is_model_defn(defn: Statement, api: SemanticAnalyzerPass2) -> bool:
if not isinstance(defn, ClassDef):
return False
for base_type_expr in defn.base_type_exprs:
# api.accept(base_type_expr)
fullname = getattr(base_type_expr, 'fullname', None)
if fullname == helpers.MODEL_CLASS_FULLNAME:
return True
return False
def iter_over_models(ctx: ClassDefContext) -> Iterator[ClassDef]:
for module_name, module_file in ctx.api.modules.items():
def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]:
for defn in module_file.defs:
if is_model_defn(defn, api=cast(SemanticAnalyzerPass2, ctx.api)):
if isinstance(defn, ClassDef):
yield defn
def extract_to_value_or_none(field_expr: CallExpr, ctx: ClassDefContext) -> Optional[TypeInfo]:
if 'to' in field_expr.arg_names:
ref_expr = field_expr.args[field_expr.arg_names.index('to')]
else:
# first positional argument
ref_expr = field_expr.args[0]
if isinstance(ref_expr, StrExpr):
model_typeinfo = helpers.get_model_type_from_string(ref_expr,
all_modules=ctx.api.modules)
return model_typeinfo
elif isinstance(ref_expr, NameExpr):
return ref_expr.node
def get_related_field_type(rvalue: CallExpr, api: SemanticAnalyzerPass2,
related_model_typ: TypeInfo) -> Optional[Instance]:
if rvalue.callee.fullname == helpers.FOREIGN_KEY_FULLNAME:
if rvalue.callee.name == 'ForeignKey':
return api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME,
args=[Instance(related_model_typ, [])])
else:
return Instance(related_model_typ, [])
def add_related_managers(ctx: ClassDefContext) -> None:
for model_defn in iter_over_models(ctx):
for _, rvalue in iter_over_one_to_n_related_fields(model_defn, ctx.api):
if 'related_name' not in rvalue.arg_names:
# positional related_name is not supported yet
return
related_name = rvalue.args[rvalue.arg_names.index('related_name')].value
ref_to_typ = extract_to_value_or_none(rvalue, ctx)
if ref_to_typ is not None:
if ref_to_typ.fullname() == ctx.cls.info.fullname():
typ = get_related_field_type(rvalue, ctx.api,
related_model_typ=model_defn.info)
def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool:
if isinstance(expr.callee, MemberExpr) and isinstance(expr.callee.expr, NameExpr):
module = module_file.names[expr.callee.expr.name]
if module.fullname == 'django.db.models' and expr.callee.name in {'ForeignKey', 'OneToOneField'}:
return True
return False
def extract_ref_to_fullname(rvalue_expr: CallExpr,
module_file: MypyFile, all_modules: Dict[str, MypyFile]) -> Optional[str]:
if 'to' in rvalue_expr.arg_names:
to_expr = rvalue_expr.args[rvalue_expr.arg_names.index('to')]
else:
to_expr = rvalue_expr.args[0]
if isinstance(to_expr, NameExpr):
return module_file.names[to_expr.name].fullname
elif isinstance(to_expr, StrExpr):
typ_fullname = helpers.get_model_fullname_from_string(to_expr, all_modules)
if typ_fullname is None:
return None
return typ_fullname
return None
def add_related_managers(ctx: ClassDefContext):
api = cast(SemanticAnalyzerPass2, ctx.api)
for module_name, module_file in ctx.api.modules.items():
for defn in iter_over_classdefs(module_file):
for lvalue, rvalue in iter_call_assignments(defn):
if is_related_field(rvalue, module_file):
ref_to_fullname = extract_ref_to_fullname(rvalue, module_file=module_file,
all_modules=api.modules)
if ctx.cls.fullname == ref_to_fullname:
if 'related_name' in rvalue.arg_names:
related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')]
if not isinstance(related_name_expr, StrExpr):
return None
related_name = related_name_expr.value
typ = get_related_field_type(rvalue, api, defn.info)
if typ is None:
return
return None
add_new_var_node_to_class(ctx.cls.info, related_name, typ)
def process_model_class(ctx: ClassDefContext) -> None:
# add_related_managers(ctx)
add_related_managers(ctx)
inject_any_as_base_for_nested_class_meta(ctx)
set_fieldname_attrs_for_related_fields(ctx)
add_int_id_attribute_if_primary_key_true_is_not_present(ctx)

View File

@@ -2,7 +2,7 @@ import typing
from typing import Optional, cast
from mypy.checker import TypeChecker
from mypy.nodes import StrExpr
from mypy.nodes import StrExpr, TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny
@@ -18,9 +18,10 @@ def fill_typevars_with_any(instance: Instance) -> Type:
def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.arg_names:
# shouldn't happen, invalid code
ctx.api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}',
context=ctx.context)
return None
@@ -30,13 +31,19 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]:
if not isinstance(to_arg_expr, StrExpr):
# not string, not supported
return None
model_info = helpers.get_model_type_from_string(to_arg_expr,
all_modules=cast(TypeChecker, ctx.api).modules)
if model_info is None:
model_fullname = helpers.get_model_fullname_from_string(to_arg_expr,
all_modules=api.modules)
if model_fullname is None:
return None
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
all_modules=api.modules)
if model_info is None or not isinstance(model_info, TypeInfo):
return None
return Instance(model_info, [])
referred_to_type = arg_type.ret_type
if not isinstance(referred_to_type, Instance):
return None
for base in referred_to_type.type.bases:
if base.type.fullname() == helpers.MODEL_CLASS_FULLNAME:
break

View File

@@ -1,12 +1,10 @@
from typing import cast, List
from typing import cast, List, Optional
from mypy.nodes import Var, Context, SymbolNode, SymbolTableNode
from mypy.plugin import ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Instance, UnionType, NoneTyp, Type
from mypy_django_plugin import helpers
def get_error_context(node: SymbolNode) -> Context:
context = Context()
@@ -18,10 +16,24 @@ def filter_out_nones(typ: UnionType) -> List[Type]:
return [item for item in typ.items if not isinstance(item, NoneTyp)]
def copy_sym_of_instance(sym: SymbolTableNode) -> SymbolTableNode:
def make_sym_copy_of_setting(sym: SymbolTableNode) -> Optional[SymbolTableNode]:
if isinstance(sym.type, Instance):
copied = sym.copy()
copied.node.info = sym.type.type
return copied
elif isinstance(sym.type, UnionType):
instances = filter_out_nones(sym.type)
if len(instances) > 1:
# plain unions not supported yet
return None
typ = instances[0]
if isinstance(typ, Instance):
copied = sym.copy()
copied.node.info = typ.type
return copied
return None
else:
return None
def add_settings_to_django_conf_object(ctx: ClassDefContext,
@@ -33,25 +45,24 @@ def add_settings_to_django_conf_object(ctx: ClassDefContext,
settings_file = api.modules[settings_module]
for name, sym in settings_file.names.items():
if name.isupper() and isinstance(sym.node, Var):
if isinstance(sym.type, Instance):
copied = sym.copy()
copied.node.info = sym.type.type
ctx.cls.info.names[name] = copied
elif isinstance(sym.type, UnionType):
instances = filter_out_nones(sym.type)
if len(instances) > 1:
# plain unions not supported yet
if sym.type is not None:
copied = make_sym_copy_of_setting(sym)
if copied is None:
continue
typ = instances[0]
if isinstance(typ, Instance):
copied = sym.copy()
copied.node.info = typ.type
ctx.cls.info.names[name] = copied
else:
context = Context()
module, node_name = sym.node.fullname().rsplit('.', 1)
module_file = api.modules.get(module)
if module_file is None:
return None
context.set_line(sym.node)
api.msg.report(f"Need type annotation for '{sym.node.name()}'", context,
severity='error', file=module_file.path)
class DjangoConfSettingsInitializerHook(object):
def __init__(self, settings_module: str):
def __init__(self, settings_module: Optional[str]):
self.settings_module = settings_module
def __call__(self, ctx: ClassDefContext) -> None:

View File

@@ -5,5 +5,6 @@ python_files = test*.py
addopts =
--tb=native
--mypy-ini-file=./test-data/plugins.ini
--mypy-no-cache
-s
-v

View File

@@ -119,7 +119,7 @@ class App(models.Model):
def method(self) -> None:
reveal_type(self.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]'
[case models_related_managers_work_with_direct_model_inheritance_and_with_inheritance_from_other_model]
[CASE models_related_managers_work_with_direct_model_inheritance_and_with_inheritance_from_other_model]
from django.db.models import Model
from django.db import models
@@ -135,3 +135,61 @@ class View2(View):
reveal_type(App().views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]'
reveal_type(App().views2) # E: Revealed type is 'django.db.models.query.QuerySet[main.View2]'
[out]
[CASE models_imported_inside_init_file]
from django.db import models
from myapp.models import App
class View(models.Model):
app = models.ForeignKey(to='myapp.App', related_name='views', on_delete=models.CASCADE)
reveal_type(View().app.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]'
[file myapp/__init__.py]
[file myapp/models/__init__.py]
from .app import App
[file myapp/models/app.py]
from django.db import models
class App(models.Model):
pass
[CASE models_imported_inside_init_file_one_to_one_field]
from django.db import models
from myapp.models import User
class Profile(models.Model):
user = models.OneToOneField(to='myapp.User', related_name='profile', on_delete=models.CASCADE)
reveal_type(Profile().user.profile) # E: Revealed type is 'main.Profile'
[file myapp/__init__.py]
[file myapp/models/__init__.py]
from .user import User
[file myapp/models/user.py]
from django.db import models
class User(models.Model):
pass
[CASE models_triple_circular_reference]
from myapp.models import App
reveal_type(App().owner.profile) # E: Revealed type is 'myapp.models.profile.Profile'
[file myapp/__init__.py]
[file myapp/models/__init__.py]
from .user import User
from .profile import Profile
from .app import App
[file myapp/models/user.py]
from django.db import models
class User(models.Model):
pass
[file myapp/models/profile.py]
from django.db import models
from myapp.models import User
class Profile(models.Model):
user = models.OneToOneField(to='myapp.User', related_name='profile', on_delete=models.CASCADE)
[file myapp/models/app.py]
from django.db import models
class App(models.Model):
owner = models.ForeignKey(to='myapp.User', on_delete=models.CASCADE, related_name='apps')

View File

@@ -36,17 +36,47 @@ ROOT_DIR = Path(__file__)
from django.conf import settings
class Class:
pass
reveal_type(settings.MYSETTING) # E: Revealed type is 'builtins.int'
reveal_type(settings.REGISTRY) # E: Revealed type is 'Union[main.Class, None]'
reveal_type(settings.LIST) # E: Revealed type is 'Any'
reveal_type(settings.BASE_LIST) # E: Revealed type is 'Any'
[out]
main:5: error: "LazySettings" has no attribute "LIST"
main:6: error: "LazySettings" has no attribute "BASE_LIST"
mysettings:4: error: Need type annotation for 'LIST'
base:6: error: Need type annotation for 'BASE_LIST'
[env DJANGO_SETTINGS_MODULE=mysettings]
[file mysettings.py]
from typing import TYPE_CHECKING
from typing import Optional
from base import *
LIST = ['1', '2']
[file base.py]
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from main import Class
REGISTRY: Optional['Class'] = None
BASE_LIST = ['3', '4']
[CASE test_circular_dependency_in_settings_works_if_settings_have_annotations]
from django.conf import settings
class Class:
pass
reveal_type(settings.MYSETTING) # E: Revealed type is 'builtins.int'
reveal_type(settings.REGISTRY) # E: Revealed type is 'Union[main.Class, None]'
reveal_type(settings.LIST) # E: Revealed type is 'builtins.list[builtins.str]'
[out]
[env DJANGO_SETTINGS_MODULE=mysettings]
[file mysettings.py]
from typing import TYPE_CHECKING, Optional, List
if TYPE_CHECKING:
from main import Class
MYSETTING = 1122
REGISTRY: Optional['Class'] = None
LIST: List[str] = ['1', '2']