Files
django-stubs/mypy_django_plugin/main.py
Seth Yastrov b1a04d2f7d Instead of using Literal types, overload QuerySet.values_list in the plugin. Fixes #43. (#44)
- Add a couple of extra type checks that Django makes:
  1) 'flat' and 'named' can't be used together.
  2) 'flat' is not valid when values_list is called with more than one field.
2019-03-13 22:10:37 +03:00

417 lines
17 KiB
Python

from functools import partial
import os
from typing import Callable, Dict, Optional, Union, cast
from mypy.checker import TypeChecker
from mypy.nodes import MemberExpr, NameExpr, TypeInfo
from mypy.options import Options
from mypy.plugin import (
AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin,
AnalyzeTypeContext)
from mypy.types import (
AnyType, CallableType, Instance, NoneTyp, Type, TypeOfAny, TypeType, UnionType,
)
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.config import Config
from mypy_django_plugin.transformers import fields, init_create
from mypy_django_plugin.transformers.forms import (
make_meta_nested_class_inherit_from_any,
)
from mypy_django_plugin.transformers.migrations import (
determine_model_cls_from_string_for_migrations, get_string_value_from_expr,
)
from mypy_django_plugin.transformers.models import process_model_class
from mypy_django_plugin.transformers.settings import (
AddSettingValuesToDjangoConfObject, get_settings_metadata,
)
def transform_model_class(ctx: ClassDefContext) -> None:
try:
sym = ctx.api.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME)
except KeyError:
# models.Model is not loaded, skip metadata model write
pass
else:
if sym is not None and isinstance(sym.node, TypeInfo):
sym.node.metadata['django']['model_bases'][ctx.cls.fullname] = 1
process_model_class(ctx)
def transform_manager_class(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(helpers.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
sym.node.metadata['django']['manager_bases'][ctx.cls.fullname] = 1
def transform_form_class(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(helpers.BASEFORM_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
sym.node.metadata['django']['baseform_bases'][ctx.cls.fullname] = 1
make_meta_nested_class_inherit_from_any(ctx)
def determine_proper_manager_type(ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
ret = ctx.default_return_type
if not api.tscope.classes:
# not in class
return ret
outer_model_info = api.tscope.classes[0]
if not outer_model_info.has_base(helpers.MODEL_CLASS_FULLNAME):
return ret
if not isinstance(ret, Instance):
return ret
has_manager_base = False
for i, base in enumerate(ret.type.bases):
if base.type.fullname() in {helpers.MANAGER_CLASS_FULLNAME,
helpers.RELATED_MANAGER_CLASS_FULLNAME,
helpers.BASE_MANAGER_CLASS_FULLNAME}:
has_manager_base = True
break
if has_manager_base:
# Fill in the manager's type argument from the outer model
new_type_args = [Instance(outer_model_info, [])]
return helpers.reparametrize_instance(ret, new_type_args)
else:
return ret
def set_first_generic_param_as_default_for_second(fullname: str, ctx: AnalyzeTypeContext) -> Type:
if not ctx.type.args:
return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit),
AnyType(TypeOfAny.explicit)])
args = ctx.type.args
if len(args) == 1:
args = [args[0], args[0]]
analyzed_args = [ctx.api.analyze_type(arg) for arg in args]
return ctx.api.named_type(fullname, analyzed_args)
def return_user_model_hook(ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
setting_expr = helpers.get_setting_expr(api, 'AUTH_USER_MODEL')
if setting_expr is None:
return ctx.default_return_type
model_path = get_string_value_from_expr(setting_expr)
if model_path is None:
return ctx.default_return_type
app_label, _, model_class_name = model_path.rpartition('.')
if app_label is None:
return ctx.default_return_type
model_fullname = helpers.get_model_fullname(app_label, model_class_name,
all_modules=api.modules)
if model_fullname is None:
api.fail(f'"{app_label}.{model_class_name}" model class is not imported so far. Try to import it '
f'(under if TYPE_CHECKING) at the beginning of the current file',
context=ctx.context)
return ctx.default_return_type
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
all_modules=api.modules)
if model_info is None or not isinstance(model_info, TypeInfo):
return ctx.default_return_type
return TypeType(Instance(model_info, []))
def _extract_referred_to_type_info(typ: Union[UnionType, Instance]) -> Optional[TypeInfo]:
if isinstance(typ, Instance):
return typ.type
else:
# should be Union[TYPE, None]
typ = helpers.make_required(typ)
if isinstance(typ, Instance):
return typ.type
return None
def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: AttributeContext) -> Type:
if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'):
return ctx.default_attr_type
if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(helpers.MODEL_CLASS_FULLNAME):
return ctx.default_attr_type
field_name = ctx.context.name.split('_')[0]
sym = ctx.type.type.get(field_name)
if sym and isinstance(sym.type, Instance) and len(sym.type.args) > 0:
referred_to = sym.type.args[1]
if isinstance(referred_to, AnyType):
return AnyType(TypeOfAny.implementation_artifact)
model_type = _extract_referred_to_type_info(referred_to)
if model_type is None:
return AnyType(TypeOfAny.implementation_artifact)
primary_key_type = helpers.extract_primary_key_type_for_get(model_type)
if primary_key_type:
return primary_key_type
is_nullable = helpers.get_fields_metadata(ctx.type.type).get(field_name, {}).get('null', False)
if is_nullable:
return helpers.make_optional(ctx.default_attr_type)
return ctx.default_attr_type
def return_integer_type_for_id_for_non_defined_primary_key_in_models(ctx: AttributeContext) -> Type:
if isinstance(ctx.type, Instance) and ctx.type.type.has_base(helpers.MODEL_CLASS_FULLNAME):
return ctx.api.named_generic_type('builtins.int', [])
return ctx.default_attr_type
class ExtractSettingType:
def __init__(self, module_fullname: str):
self.module_fullname = module_fullname
def __call__(self, ctx: AttributeContext) -> Type:
api = cast(TypeChecker, ctx.api)
original_module = api.modules.get(self.module_fullname)
if original_module is None:
return ctx.default_attr_type
definition = ctx.context
if isinstance(definition, MemberExpr):
sym = original_module.names.get(definition.name)
if sym and sym.type:
return sym.type
return ctx.default_attr_type
def transform_form_view(ctx: ClassDefContext) -> None:
form_class_value = helpers.get_assigned_value_for_class(ctx.cls.info, 'form_class')
if isinstance(form_class_value, NameExpr):
helpers.get_django_metadata(ctx.cls.info)['form_class'] = form_class_value.fullname
def extract_proper_type_for_get_form_class(ctx: MethodContext) -> Type:
object_type = ctx.type
if not isinstance(object_type, Instance):
return ctx.default_return_type
form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None)
if not form_class_fullname:
return ctx.default_return_type
return TypeType(ctx.api.named_generic_type(form_class_fullname, []))
def extract_proper_type_for_get_form(ctx: MethodContext) -> Type:
object_type = ctx.type
if not isinstance(object_type, Instance):
return ctx.default_return_type
form_class_type = helpers.get_argument_type_by_name(ctx, 'form_class')
if form_class_type is None or isinstance(form_class_type, NoneTyp):
# extract from specified form_class in metadata
form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None)
if not form_class_fullname:
return ctx.default_return_type
return ctx.api.named_generic_type(form_class_fullname, [])
if isinstance(form_class_type, TypeType) and isinstance(form_class_type.item, Instance):
return form_class_type.item
if isinstance(form_class_type, CallableType) and isinstance(form_class_type.ret_type, Instance):
return form_class_type.ret_type
return ctx.default_return_type
def extract_proper_type_for_values_list(ctx: MethodContext) -> Type:
object_type = ctx.type
if not isinstance(object_type, Instance):
return ctx.default_return_type
flat = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'flat'))
named = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'named'))
ret = ctx.default_return_type
any_type = AnyType(TypeOfAny.implementation_artifact)
if named and flat:
ctx.api.fail("'flat' and 'named' can't be used together.", ctx.context)
return ret
elif named:
# TODO: Fill in namedtuple fields/types
row_arg = ctx.api.named_generic_type('typing.NamedTuple', [])
elif flat:
# TODO: Figure out row_arg type dependent on the argument passed in
if len(ctx.args[0]) > 1:
ctx.api.fail("'flat' is not valid when values_list is called with more than one field.", ctx.context)
return ret
row_arg = any_type
else:
# TODO: Figure out tuple argument types dependent on the arguments passed in
row_arg = ctx.api.named_generic_type('builtins.tuple', [any_type])
first_arg = ret.args[0] if len(ret.args) > 0 else any_type
new_type_args = [first_arg, row_arg]
return helpers.reparametrize_instance(ret, new_type_args)
class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None:
super().__init__(options)
monkeypatch.restore_original_load_graph()
monkeypatch.restore_original_dependencies_handling()
config_fpath = os.environ.get('MYPY_DJANGO_CONFIG', 'mypy_django.ini')
if config_fpath and os.path.exists(config_fpath):
self.config = Config.from_config_file(config_fpath)
self.django_settings_module = self.config.django_settings_module
else:
self.config = Config()
self.django_settings_module = None
if 'DJANGO_SETTINGS_MODULE' in os.environ:
self.django_settings_module = os.environ['DJANGO_SETTINGS_MODULE']
settings_modules = ['django.conf.global_settings']
if self.django_settings_module:
settings_modules.append(self.django_settings_module)
monkeypatch.add_modules_as_a_source_seed_files(settings_modules)
monkeypatch.inject_modules_as_dependencies_for_django_conf_settings(settings_modules)
def _get_current_model_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (helpers.get_django_metadata(model_sym.node)
.setdefault('model_bases', {helpers.MODEL_CLASS_FULLNAME: 1}))
else:
return {}
def _get_current_manager_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.MANAGER_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (helpers.get_django_metadata(model_sym.node)
.setdefault('manager_bases', {helpers.MANAGER_CLASS_FULLNAME: 1}))
else:
return {}
def _get_current_form_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.BASEFORM_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (helpers.get_django_metadata(model_sym.node)
.setdefault('baseform_bases', {helpers.BASEFORM_CLASS_FULLNAME: 1,
helpers.FORM_CLASS_FULLNAME: 1,
helpers.MODELFORM_CLASS_FULLNAME: 1}))
else:
return {}
def _get_current_queryset_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.QUERYSET_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (helpers.get_django_metadata(model_sym.node)
.setdefault('queryset_bases', {helpers.QUERYSET_CLASS_FULLNAME: 1}))
else:
return {}
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == 'django.contrib.auth.get_user_model':
return return_user_model_hook
manager_bases = self._get_current_manager_bases()
if fullname in manager_bases:
return determine_proper_manager_type
sym = self.lookup_fully_qualified(fullname)
if sym is not None and isinstance(sym.node, TypeInfo):
if sym.node.has_base(helpers.FIELD_FULLNAME):
return fields.adjust_return_type_of_field_instantiation
if sym.node.metadata.get('django', {}).get('generated_init'):
return init_create.redefine_and_typecheck_model_init
def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], Type]]:
class_name, _, method_name = fullname.rpartition('.')
if method_name == 'get_form_class':
sym = self.lookup_fully_qualified(class_name)
if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME):
return extract_proper_type_for_get_form_class
if method_name == 'get_form':
sym = self.lookup_fully_qualified(class_name)
if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME):
return extract_proper_type_for_get_form
if method_name == 'values_list':
sym = self.lookup_fully_qualified(class_name)
if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.QUERYSET_CLASS_FULLNAME):
return extract_proper_type_for_values_list
if fullname in {'django.apps.registry.Apps.get_model',
'django.db.migrations.state.StateApps.get_model'}:
return determine_model_cls_from_string_for_migrations
manager_classes = self._get_current_manager_bases()
class_fullname, _, method_name = fullname.rpartition('.')
if class_fullname in manager_classes and method_name == 'create':
return init_create.redefine_and_typecheck_model_create
return None
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS:
settings_modules = ['django.conf.global_settings']
if self.django_settings_module:
settings_modules.append(self.django_settings_module)
return AddSettingValuesToDjangoConfObject(settings_modules,
self.config.ignore_missing_settings)
if fullname in self._get_current_model_bases():
return transform_model_class
if fullname in self._get_current_manager_bases():
return transform_manager_class
if fullname in self._get_current_form_bases():
return transform_form_class
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME):
return transform_form_view
return None
def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
if fullname == 'builtins.object.id':
return return_integer_type_for_id_for_non_defined_primary_key_in_models
module, _, name = fullname.rpartition('.')
sym = self.lookup_fully_qualified('django.conf.LazySettings')
if sym and isinstance(sym.node, TypeInfo):
metadata = get_settings_metadata(sym.node)
if module == 'builtins.object' and name in metadata:
return ExtractSettingType(module_fullname=metadata[name])
return extract_and_return_primary_key_of_bound_related_field_parameter
def get_type_analyze_hook(self, fullname: str
) -> Optional[Callable[[AnalyzeTypeContext], Type]]:
queryset_bases = self._get_current_queryset_bases()
if fullname in queryset_bases:
return partial(set_first_generic_param_as_default_for_second, fullname)
return None
def plugin(version):
return DjangoPlugin