Files
django-stubs/mypy_django_plugin/main.py
2019-02-10 04:32:27 +03:00

227 lines
9.4 KiB
Python

import os
from configparser import ConfigParser
from typing import Callable, Dict, Optional, Set, cast
from dataclasses import dataclass
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo
from mypy.options import Options
from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import Instance, Type
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
from mypy_django_plugin.plugins.models import process_model_class
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
from mypy_django_plugin.plugins.settings import AddSettingValuesToDjangoConfObject
def transform_model_class(ctx: ClassDefContext) -> None:
    """Base-class hook for classes deriving from a known model base.

    If ``helpers.MODEL_CLASS_FULLNAME`` can be resolved to a TypeInfo, the
    fullname of the class being defined is added to the ``model_bases``
    registry stored in that TypeInfo's ``metadata['django']``.  In every case
    the class is then handed to ``process_model_class``.
    """
    model_sym = None
    try:
        model_sym = ctx.api.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME)
    except KeyError:
        # models.Model is not loaded, skip metadata model write
        model_sym = None

    if model_sym is not None and isinstance(model_sym.node, TypeInfo):
        registered_model_bases = model_sym.node.metadata['django']['model_bases']
        registered_model_bases[ctx.cls.fullname] = 1

    process_model_class(ctx)
def transform_manager_class(ctx: ClassDefContext) -> None:
    """Base-class hook that registers the class being defined as a manager base.

    The class fullname is recorded in the ``manager_bases`` registry kept in
    ``metadata['django']`` of the TypeInfo for
    ``helpers.MANAGER_CLASS_FULLNAME``.  Nothing happens if that symbol does
    not resolve to a TypeInfo.
    """
    manager_sym = ctx.api.lookup_fully_qualified_or_none(helpers.MANAGER_CLASS_FULLNAME)
    if manager_sym is None or not isinstance(manager_sym.node, TypeInfo):
        return
    registered_manager_bases = manager_sym.node.metadata['django']['manager_bases']
    registered_manager_bases[ctx.cls.fullname] = 1
def determine_proper_manager_type(ctx: FunctionContext) -> Type:
    """Function hook binding a manager instantiated in a model body to that model.

    The call's default return type is returned unchanged unless all of these hold:
    the call is inside a class, the outermost enclosing class (``tscope.classes[0]``)
    has ``helpers.MODEL_CLASS_FULLNAME`` as a base, and the return type is an
    ``Instance``.  In that case the first base of the manager class whose fullname is
    one of the MANAGER / RELATED_MANAGER / BASE_MANAGER constants in ``helpers`` is
    replaced by ``reparametrize_with(base, [Instance(model, [])])``.

    NOTE: the replacement mutates ``bases`` of the manager's TypeInfo in place.
    """
    checker = cast(TypeChecker, ctx.api)
    manager_type = ctx.default_return_type

    enclosing_classes = checker.tscope.classes
    if not enclosing_classes:
        # not in class
        return manager_type

    model_info = enclosing_classes[0]
    if not model_info.has_base(helpers.MODEL_CLASS_FULLNAME):
        return manager_type
    if not isinstance(manager_type, Instance):
        return manager_type

    generic_manager_fullnames = {
        helpers.MANAGER_CLASS_FULLNAME,
        helpers.RELATED_MANAGER_CLASS_FULLNAME,
        helpers.BASE_MANAGER_CLASS_FULLNAME,
    }
    manager_bases = manager_type.type.bases
    for position, base in enumerate(manager_bases):
        if base.type.fullname() in generic_manager_fullnames:
            manager_bases[position] = reparametrize_with(base, [Instance(model_info, [])])
            break
    return manager_type
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
    """Return the parent-pointer keyword names implied by ``model``'s direct bases.

    For each direct base that itself has ``helpers.MODEL_CLASS_FULLNAME`` as a
    base, the base's short name is lower-cased and both ``<name>_ptr`` and
    ``<name>_ptr_id`` are included in the result.
    """
    parent_names = (
        base.type.name().lower()
        for base in model.bases
        if base.type.has_base(helpers.MODEL_CLASS_FULLNAME)
    )
    return {
        pointer_name
        for parent_name in parent_names
        for pointer_name in (parent_name + '_ptr', parent_name + '_ptr_id')
    }
def redefine_model_init(ctx: FunctionContext) -> Type:
    """Function hook that type-checks the arguments of a model constructor call.

    Expected argument types come from ``helpers.extract_expected_types``.
    Positional arguments are matched, in order, to its keys (excluding ``pk``).
    Keyword arguments are matched by name. The following are reported as errors:

    * incompatible argument types;
    * keyword names that are not expected attributes;
    * more positional arguments than there are positional names.

    Some keyword arguments are skipped:

    * ``<parent>_ptr`` / ``<parent>_ptr_id`` names;
    * names already supplied positionally;
    * ``**kwargs`` unpacking.

    The call's default return type is returned unchanged.
    """
    assert isinstance(ctx.default_return_type, Instance)

    api = cast(TypeChecker, ctx.api)
    model: TypeInfo = ctx.default_return_type.type

    expected_types = helpers.extract_expected_types(ctx, model)
    # Order is preserved, so the keys (minus 'pk') can be used for positionals.
    # Filtering (rather than list.remove('pk')) does not raise if 'pk' is absent.
    positional_names = [name for name in expected_types if name != 'pk']

    visited_positionals: Set[str] = set()
    # check positionals
    for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
        if i >= len(positional_names):
            # Surplus positional arguments have no field to map to; without this
            # guard positional_names[i] raises IndexError and crashes the plugin.
            api.fail('Too many positional arguments for model "{}"'.format(model.name()),
                     ctx.context)
            break
        actual_pos_name = positional_names[i]
        api.check_subtype(actual_pos_type, expected_types[actual_pos_name],
                          ctx.context,
                          'Incompatible type for "{}" of "{}"'.format(actual_pos_name,
                                                                      model.name()),
                          'got', 'expected')
        visited_positionals.add(actual_pos_name)

    # extract name of base models for _ptr
    base_pointer_args = extract_base_pointer_args(model)

    # check kwargs
    for actual_name, actual_type in zip(ctx.arg_names[1], ctx.arg_types[1]):
        if actual_name in base_pointer_args:
            # parent_ptr args are not supported
            continue
        if actual_name in visited_positionals:
            # already checked as a positional argument
            continue
        if actual_name is None:
            # unpacked dict as kwargs is not supported
            continue
        if actual_name not in expected_types:
            api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
                                                                         model.name()),
                     ctx.context)
            continue
        api.check_subtype(actual_type, expected_types[actual_name],
                          ctx.context,
                          'Incompatible type for "{}" of "{}"'.format(actual_name,
                                                                      model.name()),
                          'got', 'expected')

    return ctx.default_return_type
@dataclass
class Config:
    """Plugin settings, optionally loaded from an ini file's ``[mypy_django_plugin]`` section."""

    # Dotted module path read from the ``django_settings`` option; None if unset.
    django_settings_module: Optional[str] = None
    # Read from the ``ignore_missing_settings`` option; defaults to False.
    ignore_missing_settings: bool = False

    @classmethod
    def from_config_file(cls, fpath: str) -> 'Config':
        """Build a Config from the ini file at ``fpath``.

        Options read from the ``[mypy_django_plugin]`` section:

        * ``django_settings`` -> ``django_settings_module`` (default ``None``)
        * ``ignore_missing_settings`` -> boolean (default ``False``). Accepts
          ConfigParser's boolean spellings (yes/no, true/false, on/off, 1/0).

        Raises:
            ValueError: if the file cannot be read, has no
                ``[mypy_django_plugin]`` section, or ``ignore_missing_settings``
                is not a valid boolean.
        """
        ini_config = ConfigParser()
        # ConfigParser.read() silently skips files it cannot open and returns the
        # list of files actually parsed; surface a missing file explicitly.
        if not ini_config.read(fpath):
            raise ValueError(f'Invalid config file: {fpath!r} could not be read')
        if not ini_config.has_section('mypy_django_plugin'):
            raise ValueError('Invalid config file: no [mypy_django_plugin] section')
        return cls(
            django_settings_module=ini_config.get('mypy_django_plugin', 'django_settings',
                                                  fallback=None),
            # getboolean, not get: get() returns the raw string, so a literal
            # "False" in the file would otherwise be truthy.
            ignore_missing_settings=ini_config.getboolean('mypy_django_plugin',
                                                          'ignore_missing_settings',
                                                          fallback=False),
        )
class DjangoPlugin(Plugin):
    """mypy plugin that installs Django-specific function, method and base-class hooks.

    Configuration comes from an optional ini file named by the
    ``MYPY_DJANGO_CONFIG`` environment variable.  ``DJANGO_SETTINGS_MODULE``,
    when set, overrides the settings module taken from that file.
    """

    def __init__(self, options: Options) -> None:
        super().__init__(options)

        monkeypatch.restore_original_load_graph()
        monkeypatch.restore_original_dependencies_handling()

        config_fpath = os.environ.get('MYPY_DJANGO_CONFIG')
        if config_fpath:
            self.config = Config.from_config_file(config_fpath)
            self.django_settings = self.config.django_settings_module
        else:
            self.config = Config()
            self.django_settings = None

        # An explicit DJANGO_SETTINGS_MODULE takes precedence over the config file.
        if 'DJANGO_SETTINGS_MODULE' in os.environ:
            self.django_settings = os.environ['DJANGO_SETTINGS_MODULE']

        settings_modules = ['django.conf.global_settings']
        if self.django_settings:
            settings_modules.append(self.django_settings)

        monkeypatch.add_modules_as_a_source_seed_files(settings_modules)
        monkeypatch.inject_modules_as_dependencies_for_django_conf_settings(settings_modules)

    def _get_current_bases(self, base_fullname: str, metadata_key: str) -> Dict[str, int]:
        """Return the registry stored at ``metadata['django'][metadata_key]`` of ``base_fullname``.

        The registry is created on first access and seeded with ``base_fullname``
        itself.  An empty (unattached) dict is returned when ``base_fullname``
        does not resolve to a TypeInfo.
        """
        sym = self.lookup_fully_qualified(base_fullname)
        if sym is None or not isinstance(sym.node, TypeInfo):
            return {}
        django_metadata = sym.node.metadata.setdefault('django', {})
        # setdefault (rather than indexing) also copes with a 'django' entry that
        # exists but was created without this particular key.
        return django_metadata.setdefault(metadata_key, {base_fullname: 1})

    def _get_current_model_bases(self) -> Dict[str, int]:
        """Return the fullnames registered as model bases (see ``_get_current_bases``)."""
        return self._get_current_bases(helpers.MODEL_CLASS_FULLNAME, 'model_bases')

    def _get_current_manager_bases(self) -> Dict[str, int]:
        """Return the fullnames registered as manager bases (see ``_get_current_bases``)."""
        return self._get_current_bases(helpers.MANAGER_CLASS_FULLNAME, 'manager_bases')

    def get_function_hook(self, fullname: str
                          ) -> Optional[Callable[[FunctionContext], Type]]:
        """Pick a hook for calls to related fields, ArrayField, managers and model classes."""
        if fullname in {helpers.FOREIGN_KEY_FULLNAME,
                        helpers.ONETOONE_FIELD_FULLNAME,
                        helpers.MANYTOMANY_FIELD_FULLNAME}:
            return extract_to_parameter_as_get_ret_type_for_related_field

        if fullname == 'django.contrib.postgres.fields.array.ArrayField':
            return determine_type_of_array_field

        manager_bases = self._get_current_manager_bases()
        if fullname in manager_bases:
            return determine_proper_manager_type

        sym = self.lookup_fully_qualified(fullname)
        if sym and isinstance(sym.node, TypeInfo):
            # Model classes processed by the plugin are flagged with 'generated_init'.
            if sym.node.metadata.get('django', {}).get('generated_init'):
                return redefine_model_init

        # Explicit: the return type is Optional, and mypy flags an implicit fall-through.
        return None

    def get_method_hook(self, fullname: str
                        ) -> Optional[Callable[[MethodContext], Type]]:
        """Resolve ``get_model`` calls on app registries (used in migrations)."""
        if fullname in {'django.apps.registry.Apps.get_model',
                        'django.db.migrations.state.StateApps.get_model'}:
            return determine_model_cls_from_string_for_migrations
        return None

    def get_base_class_hook(self, fullname: str
                            ) -> Optional[Callable[[ClassDefContext], None]]:
        """Pick a hook for subclasses of models, managers and the dummy settings base."""
        if fullname in self._get_current_model_bases():
            return transform_model_class

        if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS:
            settings_modules = ['django.conf.global_settings']
            if self.django_settings:
                settings_modules.append(self.django_settings)
            return AddSettingValuesToDjangoConfObject(settings_modules,
                                                      self.config.ignore_missing_settings)

        if fullname in self._get_current_manager_bases():
            return transform_manager_class

        return None
def plugin(version: str) -> type:
    """Plugin entry point called by mypy; returns the plugin class to instantiate.

    ``version`` is the running mypy version string; it is not used here.
    """
    return DjangoPlugin