latest changes

This commit is contained in:
Maxim Kurnikov
2018-11-26 23:58:34 +03:00
parent 348efcd371
commit f59cfe6371
34 changed files with 1558 additions and 132 deletions

View File

@@ -1,15 +1,15 @@
from typing import Dict, Optional, NamedTuple
import typing
from typing import Dict, Optional, NamedTuple, Any
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Type
from mypy.nodes import SymbolTableNode, Var, Expression
from mypy.plugin import FunctionContext
from mypy.types import Instance, UnionType, NoneTyp
from mypy.types import Type, Instance, UnionType, NoneTyp
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
def create_new_symtable_node(name: str, kind: int, instance: Instance) -> SymbolTableNode:
@@ -26,12 +26,10 @@ Argument = NamedTuple('Argument', fields=[
def get_call_signature_or_none(ctx: FunctionContext) -> Optional[Dict[str, Argument]]:
arg_names = ctx.context.arg_names
result: Dict[str, Argument] = {}
positional_args_only = []
positional_arg_types_only = []
for arg, arg_name, arg_type in zip(ctx.args, arg_names, ctx.arg_types):
for arg, arg_name, arg_type in zip(ctx.args, ctx.arg_names, ctx.arg_types):
if arg_name is None:
positional_args_only.append(arg)
positional_arg_types_only.append(arg_type)
@@ -64,4 +62,8 @@ def make_required(typ: Type) -> Type:
if not isinstance(typ, UnionType):
return typ
items = [item for item in typ.items if not isinstance(item, NoneTyp)]
return UnionType.make_union(items)
return UnionType.make_union(items)
def get_obj_type_name(typ: typing.Type) -> str:
return typ.__module__ + '.' + typ.__qualname__

View File

@@ -1,27 +1,42 @@
import os
from typing import Callable, Optional
from typing import Callable, Optional, List
from django.apps.registry import Apps
from django.conf import Settings
from mypy import build
from mypy.build import BuildManager
from mypy.options import Options
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
from mypy.types import Type
from mypy_django_plugin import helpers
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.plugins.objects_queryset import set_objects_queryset_to_model_class
from mypy_django_plugin.plugins.postgres_fields import determine_type_of_array_field
from mypy_django_plugin.plugins.related_fields import set_related_name_instance_for_onetoonefield, \
set_related_name_manager_for_foreign_key, set_fieldname_attrs_for_related_fields
from mypy_django_plugin.plugins.related_fields import OneToOneFieldHook, \
ForeignKeyHook, set_fieldname_attrs_for_related_fields
from mypy_django_plugin.plugins.setup_settings import DjangoConfSettingsInitializerHook
base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
def transform_model_class(ctx: ClassDefContext) -> None:
base_model_classes.add(ctx.cls.fullname)
class TransformModelClassHook(object):
def __init__(self, settings: Settings, apps: Apps):
self.settings = settings
self.apps = apps
set_fieldname_attrs_for_related_fields(ctx)
set_objects_queryset_to_model_class(ctx)
def __call__(self, ctx: ClassDefContext) -> None:
base_model_classes.add(ctx.cls.fullname)
set_fieldname_attrs_for_related_fields(ctx)
set_objects_queryset_to_model_class(ctx)
def always_return_none(manager: BuildManager):
return None
build.read_plugins_snapshot = always_return_none
class DjangoPlugin(Plugin):
@@ -29,18 +44,36 @@ class DjangoPlugin(Plugin):
options: Options) -> None:
super().__init__(options)
self.django_settings = None
self.apps = None
monkeypatch.replace_apply_function_plugin_method()
django_settings_module = os.environ.get('DJANGO_SETTINGS_MODULE')
if django_settings_module:
self.django_settings = Settings(django_settings_module)
# import django
# django.setup()
#
# from django.apps import apps
# self.apps = apps
#
# models_modules = []
# for app_config in self.apps.app_configs.values():
# models_modules.append(app_config.module.__name__ + '.' + 'models')
#
# monkeypatch.state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(django_settings_module,
# models_modules)
monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(django_settings_module)
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == helpers.FOREIGN_KEY_FULLNAME:
return set_related_name_manager_for_foreign_key
return ForeignKeyHook(settings=self.django_settings,
apps=self.apps)
if fullname == helpers.ONETOONE_FIELD_FULLNAME:
return set_related_name_instance_for_onetoonefield
return OneToOneFieldHook(settings=self.django_settings,
apps=self.apps)
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field
@@ -49,9 +82,11 @@ class DjangoPlugin(Plugin):
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in base_model_classes:
return transform_model_class
if fullname == 'django.conf._DjangoConfLazyObject':
return TransformModelClassHook(self.django_settings, self.apps)
if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS:
return DjangoConfSettingsInitializerHook(settings=self.django_settings)
return None

View File

@@ -0,0 +1,112 @@
from typing import Optional, List, Sequence
from mypy.build import BuildManager, Graph, State
from mypy.modulefinder import BuildSource
from mypy.nodes import Expression, Context
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import Type, CallableType, Instance
def state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(settings_module: str,
models_py_modules: List[str]):
from mypy.build import State
old_compute_dependencies = State.compute_dependencies
def patched_compute_dependencies(self: State):
old_compute_dependencies(self)
if self.id == settings_module:
self.dependencies.extend(models_py_modules)
State.compute_dependencies = patched_compute_dependencies
def load_graph_to_add_settings_file_as_a_source_seed(settings_module: str):
from mypy import build
old_load_graph = build.load_graph
def patched_load_graph(sources: List[BuildSource], manager: BuildManager,
old_graph: Optional[Graph] = None,
new_modules: Optional[List[State]] = None):
if all([source.module != settings_module for source in sources]):
sources.append(BuildSource(None, settings_module, None))
return old_load_graph(sources=sources, manager=manager,
old_graph=old_graph,
new_modules=new_modules)
build.load_graph = patched_load_graph
def replace_apply_function_plugin_method():
def apply_function_plugin(self,
arg_types: List[Type],
inferred_ret_type: Type,
arg_names: Optional[Sequence[Optional[str]]],
formal_to_actual: List[List[int]],
args: List[Expression],
num_formals: int,
fullname: str,
object_type: Optional[Type],
context: Context) -> Type:
"""Use special case logic to infer the return type of a specific named function/method.
Caller must ensure that a plugin hook exists. There are two different cases:
- If object_type is None, the caller must ensure that a function hook exists
for fullname.
- If object_type is not None, the caller must ensure that a method hook exists
for fullname.
Return the inferred return type.
"""
formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]]
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
formal_arg_names = [None for _ in range(num_formals)] # type: List[Optional[str]]
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_types[formal].append(arg_types[actual])
formal_arg_exprs[formal].append(args[actual])
if arg_names:
formal_arg_names[formal] = arg_names[actual]
num_passed_positionals = sum([1 if name is None else 0
for name in formal_arg_names])
if arg_names and num_passed_positionals > 0:
object_type_info = None
if object_type is not None:
if isinstance(object_type, CallableType):
# class object, convert to corresponding Instance
object_type = object_type.ret_type
if isinstance(object_type, Instance):
# skip TypedDictType and others
object_type_info = object_type.type
defn_arg_names = self._get_defn_arg_names(fullname, object_type=object_type_info)
if defn_arg_names:
if num_formals < len(defn_arg_names):
# self/cls argument has been passed implicitly
defn_arg_names = defn_arg_names[1:]
formal_arg_names[:num_passed_positionals] = defn_arg_names[:num_passed_positionals]
if object_type is None:
# Apply function plugin
callback = self.plugin.get_function_hook(fullname)
assert callback is not None # Assume that caller ensures this
return callback(
FunctionContext(formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
else:
# Apply method plugin
method_callback = self.plugin.get_method_hook(fullname)
assert method_callback is not None # Assume that caller ensures this
return method_callback(
MethodContext(object_type, formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
from mypy.checkexpr import ExpressionChecker
ExpressionChecker.apply_function_plugin = apply_function_plugin

View File

@@ -9,19 +9,28 @@ from mypy_django_plugin import helpers
def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None:
if 'objects' in ctx.cls.info.names:
return
api = cast(SemanticAnalyzerPass2, ctx.api)
# search over mro
objects_sym = ctx.cls.info.get('objects')
if objects_sym is not None:
return None
metaclass_node = ctx.cls.info.names.get('Meta')
if metaclass_node is not None:
for stmt in metaclass_node.node.defn.defs.body:
# only direct Meta class
metaclass_sym = ctx.cls.info.names.get('Meta')
# skip if abstract
if metaclass_sym is not None:
for stmt in metaclass_sym.node.defn.defs.body:
if (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1
and stmt.lvalues[0].name == 'abstract'):
is_abstract = api.parse_bool(stmt.rvalue)
is_abstract = ctx.api.parse_bool(stmt.rvalue)
if is_abstract:
return
return None
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, args=[Instance(ctx.cls.info, [])])
new_objects_node = helpers.create_new_symtable_node('objects', MDEF, instance=typ)
ctx.cls.info.names['objects'] = new_objects_node
api = cast(SemanticAnalyzerPass2, ctx.api)
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME,
args=[Instance(ctx.cls.info, [])])
if not typ:
return None
ctx.cls.info.names['objects'] = helpers.create_new_symtable_node('objects',
kind=MDEF,
instance=typ)

View File

@@ -1,14 +1,11 @@
from mypy.plugin import FunctionContext
from mypy.types import Type
from mypy_django_plugin import helpers
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
signature = helpers.get_call_signature_or_none(ctx)
if signature is None:
if 'base_field' not in ctx.arg_names:
return ctx.default_return_type
_, base_field_arg_type = signature['base_field']
base_field_arg_type = ctx.arg_types[ctx.arg_names.index('base_field')][0]
return ctx.api.named_generic_type(ctx.context.callee.fullname,
args=[base_field_arg_type.type.names['__get__'].type.ret_type])

View File

@@ -1,23 +1,63 @@
from typing import Optional, cast
import typing
from typing import Optional, cast, Tuple, Any
from django.apps.registry import Apps
from django.conf import Settings
from django.db import models
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, MemberExpr
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, StrExpr
from mypy.plugin import FunctionContext, ClassDefContext
from mypy.types import Type, CallableType, Instance, AnyType
from mypy_django_plugin import helpers
def extract_to_value_type(ctx: FunctionContext) -> Optional[Instance]:
signature = helpers.get_call_signature_or_none(ctx)
if signature is None or 'to' not in signature:
return None
def get_instance_type_for_class(klass: typing.Type[models.Model],
api: TypeChecker) -> Optional[Instance]:
model_qualname = helpers.get_obj_type_name(klass)
module_name, _, class_name = model_qualname.rpartition('.')
module = api.modules.get(module_name)
if not module or class_name not in module.names:
return
arg, arg_type = signature['to']
if not isinstance(arg_type, CallableType):
return None
sym = module.names[class_name]
return Instance(sym.node, [])
return arg_type.ret_type
def extract_to_value_type(ctx: FunctionContext,
apps: Optional[Apps]) -> Tuple[Optional[Instance], bool]:
api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.arg_names:
return None, False
arg = ctx.args[ctx.arg_names.index('to')][0]
arg_type = ctx.arg_types[ctx.arg_names.index('to')][0]
if isinstance(arg_type, CallableType):
return arg_type.ret_type, False
if apps:
if isinstance(arg, StrExpr):
arg_value = arg.value
if '.' not in arg_value:
return None, False
app_label, modelname = arg_value.lower().split('.')
try:
model_cls = apps.get_model(app_label, modelname)
except LookupError:
# no model class found
return None, False
try:
instance = get_instance_type_for_class(model_cls, api=api)
if not instance:
return None, False
return instance, True
except AssertionError:
pass
return None, False
def extract_related_name_value(ctx: FunctionContext) -> str:
@@ -30,45 +70,58 @@ def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instanc
instance=new_member_instance)
def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
outer_class_info = api.tscope.classes[-1]
class ForeignKeyHook(object):
def __init__(self, settings: Settings, apps: Apps):
self.settings = settings
self.apps = apps
def __call__(self, ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
outer_class_info = api.tscope.classes[-1]
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps)
if not referred_to:
return ctx.default_return_type
if 'related_name' in ctx.context.arg_names:
related_name = extract_related_name_value(ctx)
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
args=[Instance(outer_class_info, [])])
if isinstance(referred_to, AnyType):
return ctx.default_return_type
add_new_class_member(referred_to.type,
related_name, queryset_type)
if is_string_based:
return referred_to
if 'related_name' not in ctx.context.arg_names:
return ctx.default_return_type
referred_to = extract_to_value_type(ctx)
if not referred_to:
class OneToOneFieldHook(object):
def __init__(self, settings: Optional[Settings], apps: Optional[Apps]):
self.settings = settings
self.apps = apps
def __call__(self, ctx: FunctionContext) -> Type:
if 'related_name' not in ctx.context.arg_names:
return ctx.default_return_type
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps)
if referred_to is None:
return ctx.default_return_type
if 'related_name' in ctx.context.arg_names:
related_name = extract_related_name_value(ctx)
outer_class_info = ctx.api.tscope.classes[-1]
add_new_class_member(referred_to.type, related_name,
new_member_instance=Instance(outer_class_info, []))
if is_string_based:
return referred_to
return ctx.default_return_type
related_name = extract_related_name_value(ctx)
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
args=[Instance(outer_class_info, [])])
if isinstance(referred_to, AnyType):
# referred_to defined as string, which is unsupported for now
return ctx.default_return_type
add_new_class_member(referred_to.type,
related_name, queryset_type)
return ctx.default_return_type
def set_related_name_instance_for_onetoonefield(ctx: FunctionContext) -> Type:
if 'related_name' not in ctx.context.arg_names:
return ctx.default_return_type
referred_to = extract_to_value_type(ctx)
if referred_to is None:
return ctx.default_return_type
related_name = extract_related_name_value(ctx)
outer_class_info = ctx.api.tscope.classes[-1]
api = cast(TypeChecker, ctx.api)
add_new_class_member(referred_to.type, related_name,
new_member_instance=api.named_type(outer_class_info.fullname()))
return ctx.default_return_type
def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
api = ctx.api

View File

@@ -1,7 +1,7 @@
from typing import cast, Any
from typing import cast
from django.conf import Settings
from mypy.nodes import MDEF, TypeInfo, SymbolTable
from mypy.nodes import MDEF
from mypy.plugin import ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Instance, AnyType, TypeOfAny
@@ -9,26 +9,24 @@ from mypy.types import Instance, AnyType, TypeOfAny
from mypy_django_plugin import helpers
def get_obj_type_name(value: Any) -> str:
return type(value).__module__ + '.' + type(value).__qualname__
class DjangoConfSettingsInitializerHook(object):
def __init__(self, settings: Settings):
self.settings = settings
def __call__(self, ctx: ClassDefContext) -> None:
api = cast(SemanticAnalyzerPass2, ctx.api)
for name, value in self.settings.__dict__.items():
if name.isupper():
if value is None:
ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF,
instance=api.builtin_type('builtins.object'))
continue
if self.settings:
for name, value in self.settings.__dict__.items():
if name.isupper():
if value is None:
# TODO: change to Optional[Any] later
ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF,
instance=api.builtin_type('builtins.object'))
continue
type_fullname = get_obj_type_name(value)
sym = api.lookup_fully_qualified_or_none(type_fullname)
if sym is not None:
args = len(sym.node.type_vars) * [AnyType(TypeOfAny.from_omitted_generics)]
ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF,
instance=Instance(sym.node, args))
type_fullname = helpers.get_obj_type_name(type(value))
sym = api.lookup_fully_qualified_or_none(type_fullname)
if sym is not None:
args = len(sym.node.type_vars) * [AnyType(TypeOfAny.from_omitted_generics)]
ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF,
instance=Instance(sym.node, args))