remove -semanal suffix

This commit is contained in:
Maxim Kurnikov
2019-07-17 21:17:18 +03:00
parent dc6101b569
commit b81fbdeaa9
34 changed files with 21 additions and 20 deletions

View File

View File

View File

@@ -0,0 +1,200 @@
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type, Sequence
from django.core.exceptions import FieldError
from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey, RelatedField
from django.utils.functional import cached_property
from mypy.checker import TypeChecker
from mypy.types import Instance, Type as MypyType
from pytest_mypy.utils import temp_environ
from django.contrib.postgres.fields import ArrayField
from django.db.models.fields import CharField, Field
from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.sql.query import Query
from mypy_django_plugin.lib import helpers
if TYPE_CHECKING:
from django.apps.registry import Apps
from django.conf import LazySettings
@dataclass
class DjangoPluginConfig:
ignore_missing_settings: bool = False
ignore_missing_model_attributes: bool = False
def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']:
with temp_environ():
os.environ['DJANGO_SETTINGS_MODULE'] = settings_module
def noop_class_getitem(cls, key):
return cls
from django.db import models
models.QuerySet.__class_getitem__ = classmethod(noop_class_getitem)
models.Manager.__class_getitem__ = classmethod(noop_class_getitem)
from django.conf import settings
from django.apps import apps
apps.get_models.cache_clear()
apps.get_swappable_settings_name.cache_clear()
apps.populate(settings.INSTALLED_APPS)
assert apps.apps_ready
assert settings.configured
return apps, settings
class DjangoFieldsContext:
def __init__(self, django_context: 'DjangoContext') -> None:
self.django_context = django_context
def get_attname(self, field: Field) -> str:
attname = field.attname
return attname
def get_field_nullability(self, field: Field, method: Optional[str]) -> bool:
nullable = field.null
if not nullable and isinstance(field, CharField) and field.blank:
return True
if method == '__init__':
if field.primary_key or isinstance(field, ForeignKey):
return True
if field.has_default():
return True
return nullable
def get_field_set_type(self, api: TypeChecker, field: Field, method: str) -> MypyType:
target_field = field
if isinstance(field, ForeignKey):
target_field = field.target_field
field_info = helpers.lookup_class_typeinfo(api, target_field.__class__)
field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
is_nullable=self.get_field_nullability(field, method))
if isinstance(target_field, ArrayField):
argument_field_type = self.get_field_set_type(api, target_field.base_field, method)
field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type)
return field_set_type
def get_field_get_type(self, api: TypeChecker, field: Field, method: str) -> MypyType:
field_info = helpers.lookup_class_typeinfo(api, field.__class__)
is_nullable = self.get_field_nullability(field, method)
if isinstance(field, RelatedField):
if method == 'values':
primary_key_field = self.django_context.get_primary_key_field(field.related_model)
return self.get_field_get_type(api, primary_key_field, method)
model_info = helpers.lookup_class_typeinfo(api, field.related_model)
return Instance(model_info, [])
else:
return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
is_nullable=is_nullable)
class DjangoLookupsContext:
def resolve_lookup(self, model_cls: Type[Model], lookup: str) -> Any:
query = Query(model_cls)
lookup_parts, field_parts, is_expression = query.solve_lookup_type(lookup)
if lookup_parts:
raise FieldError('Lookups not supported yet')
currently_observed_model = model_cls
current_field = None
for field_name in field_parts:
current_field = currently_observed_model._meta.get_field(field_name)
if isinstance(current_field, RelatedField):
currently_observed_model = current_field.related_model
return current_field
class DjangoContext:
def __init__(self, plugin_toml_config: Optional[Dict[str, Any]]) -> None:
self.config = DjangoPluginConfig()
self.fields_context = DjangoFieldsContext(self)
self.lookups_context = DjangoLookupsContext()
self.django_settings_module = None
if plugin_toml_config:
self.config.ignore_missing_settings = plugin_toml_config.get('ignore_missing_settings', False)
self.config.ignore_missing_model_attributes = plugin_toml_config.get('ignore_missing_model_attributes', False)
self.django_settings_module = plugin_toml_config.get('django_settings_module', None)
self.apps_registry: Optional[Dict[str, str]] = None
self.settings: LazySettings = None
if self.django_settings_module:
apps, settings = initialize_django(self.django_settings_module)
self.apps_registry = apps
self.settings = settings
@cached_property
def model_modules(self) -> Dict[str, List[Type[Model]]]:
""" All modules that contain Django models. """
if self.apps_registry is None:
return {}
modules: Dict[str, List[Type[Model]]] = defaultdict(list)
for model_cls in self.apps_registry.get_models():
modules[model_cls.__module__].append(model_cls)
return modules
def get_model_class_by_fullname(self, fullname: str) -> Optional[Type[Model]]:
# Returns None if Model is abstract
module, _, model_cls_name = fullname.rpartition('.')
for model_cls in self.model_modules.get(module, []):
if model_cls.__name__ == model_cls_name:
return model_cls
def get_model_fields(self, model_cls: Type[Model]) -> Iterator[Field]:
for field in model_cls._meta.get_fields():
if isinstance(field, Field):
yield field
def get_model_relations(self, model_cls: Type[Model]) -> Iterator[ForeignObjectRel]:
for field in model_cls._meta.get_fields():
if isinstance(field, ForeignObjectRel):
yield field
def get_primary_key_field(self, model_cls: Type[Model]) -> Field:
for field in model_cls._meta.get_fields():
if isinstance(field, Field):
if field.primary_key:
return field
raise ValueError('No primary key defined')
def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], method: str) -> Dict[str, MypyType]:
expected_types = {}
if method == '__init__':
# add pk
primary_key_field = self.get_primary_key_field(model_cls)
field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method)
expected_types['pk'] = field_set_type
for field in self.get_model_fields(model_cls):
field_name = field.attname
field_set_type = self.fields_context.get_field_set_type(api, field, method)
expected_types[field_name] = field_set_type
if isinstance(field, ForeignKey):
field_name = field.name
foreign_key_info = helpers.lookup_class_typeinfo(api, field.__class__)
related_model_info = helpers.lookup_class_typeinfo(api, field.related_model)
is_nullable = self.fields_context.get_field_nullability(field, method)
foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info,
'_pyi_private_set_type',
is_nullable=is_nullable)
model_set_type = helpers.convert_any_to_type(foreign_key_set_type,
Instance(related_model_info, []))
expected_types[field_name] = model_set_type
return expected_types

View File

View File

@@ -0,0 +1,35 @@
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
FIELD_FULLNAME = 'django.db.models.fields.Field'
CHAR_FIELD_FULLNAME = 'django.db.models.fields.CharField'
ARRAY_FIELD_FULLNAME = 'django.contrib.postgres.fields.array.ArrayField'
AUTO_FIELD_FULLNAME = 'django.db.models.fields.AutoField'
GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey'
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField'
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
BASE_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.BaseManager'
MANAGER_CLASS_FULLNAME = 'django.db.models.manager.Manager'
RELATED_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.RelatedManager'
BASEFORM_CLASS_FULLNAME = 'django.forms.forms.BaseForm'
FORM_CLASS_FULLNAME = 'django.forms.forms.Form'
MODELFORM_CLASS_FULLNAME = 'django.forms.models.ModelForm'
FORM_MIXIN_CLASS_FULLNAME = 'django.views.generic.edit.FormMixin'
MANAGER_CLASSES = {
MANAGER_CLASS_FULLNAME,
RELATED_MANAGER_CLASS_FULLNAME,
BASE_MANAGER_CLASS_FULLNAME,
# QUERYSET_CLASS_FULLNAME
}
RELATED_FIELDS_CLASSES = {
FOREIGN_KEY_FULLNAME,
ONETOONE_FIELD_FULLNAME,
MANYTOMANY_FIELD_FULLNAME
}

View File

@@ -0,0 +1,204 @@
from collections import OrderedDict
from typing import Dict, List, Optional, Set, Union, Any
from mypy import checker
from mypy.checker import TypeChecker
from mypy.mro import calculate_mro
from mypy.nodes import Block, ClassDef, Expression, GDEF, MDEF, MypyFile, NameExpr, SymbolNode, SymbolTable, SymbolTableNode, \
TypeInfo, Var
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext
from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
return model_info.metadata.setdefault('django', {})
class IncompleteDefnException(Exception):
pass
def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]:
if '.' not in fullname:
return None
module, cls_name = fullname.rsplit('.', 1)
module_file = all_modules.get(module)
if module_file is None:
return None
sym = module_file.names.get(cls_name)
if sym is None:
return None
return sym
def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]:
sym = lookup_fully_qualified_sym(name, all_modules)
if sym is None:
return None
return sym.node
def lookup_fully_qualified_typeinfo(api: TypeChecker, fullname: str) -> Optional[TypeInfo]:
node = lookup_fully_qualified_generic(fullname, api.modules)
if not isinstance(node, TypeInfo):
return None
return node
def lookup_class_typeinfo(api: TypeChecker, klass: type) -> TypeInfo:
fullname = get_class_fullname(klass)
field_info = lookup_fully_qualified_typeinfo(api, fullname)
return field_info
def reparametrize_instance(instance: Instance, new_args: List[MypyType]) -> Instance:
return Instance(instance.type, args=new_args,
line=instance.line, column=instance.column)
def get_class_fullname(klass: type) -> str:
return klass.__module__ + '.' + klass.__qualname__
def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]:
"""
Return the expression for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
args = ctx.args[idx]
if len(args) != 1:
# Either an error or no value passed.
return None
return args[0]
def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]:
"""Return the type for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
return None
idx = ctx.callee_arg_names.index(name)
arg_types = ctx.arg_types[idx]
if len(arg_types) != 1:
# Either an error or no value passed.
return None
return arg_types[0]
def make_optional(typ: MypyType) -> MypyType:
return UnionType.make_union([typ, NoneTyp()])
def parse_bool(expr: Expression) -> Optional[bool]:
if isinstance(expr, NameExpr):
if expr.fullname == 'builtins.True':
return True
if expr.fullname == 'builtins.False':
return False
return None
def has_any_of_bases(info: TypeInfo, bases: Set[str]) -> bool:
for base_fullname in bases:
if info.has_base(base_fullname):
return True
return False
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType:
node = type_info.get(private_field_name).node
if isinstance(node, Var):
descriptor_type = node.type
if is_nullable:
descriptor_type = make_optional(descriptor_type)
return descriptor_type
return AnyType(TypeOfAny.unannotated)
def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]:
metaclass_sym = info.names.get('Meta')
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
return metaclass_sym.node
return None
def add_new_class_for_current_module(api: TypeChecker, name: str, bases: List[Instance],
fields: 'OrderedDict[str, MypyType]') -> TypeInfo:
current_module = api.scope.stack[0]
new_class_unique_name = checker.gen_unique_name(name, current_module.names)
# make new class expression
classdef = ClassDef(new_class_unique_name, Block([]))
classdef.fullname = current_module.fullname() + '.' + new_class_unique_name
# make new TypeInfo
new_typeinfo = TypeInfo(SymbolTable(), classdef, current_module.fullname())
new_typeinfo.bases = bases
calculate_mro(new_typeinfo)
new_typeinfo.calculate_metaclass_type()
def add_field_to_new_typeinfo(var: Var, is_initialized_in_class: bool = False,
is_property: bool = False) -> None:
var.info = new_typeinfo
var.is_initialized_in_class = is_initialized_in_class
var.is_property = is_property
var._fullname = new_typeinfo.fullname() + '.' + var.name()
new_typeinfo.names[var.name()] = SymbolTableNode(MDEF, var)
# add fields
var_items = [Var(item, typ) for item, typ in fields.items()]
for var_item in var_items:
add_field_to_new_typeinfo(var_item, is_property=True)
classdef.info = new_typeinfo
current_module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True)
return new_typeinfo
def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]') -> TupleType:
namedtuple_info = add_new_class_for_current_module(api, name,
bases=[api.named_generic_type('typing.NamedTuple', [])],
fields=fields)
return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, []))
def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType:
# fallback for tuples is any builtins.tuple instance
fallback = api.named_generic_type('builtins.tuple',
[AnyType(TypeOfAny.special_form)])
return TupleType(fields, fallback=fallback)
def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
if isinstance(typ, UnionType):
converted_items = []
for item in typ.items:
converted_items.append(convert_any_to_type(item, referred_to_type))
return UnionType.make_union(converted_items,
line=typ.line, column=typ.column)
if isinstance(typ, Instance):
args = []
for default_arg in typ.args:
if isinstance(default_arg, AnyType):
args.append(referred_to_type)
else:
args.append(default_arg)
return reparametrize_instance(typ, args)
if isinstance(typ, AnyType):
return referred_to_type
return typ
def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyType]',
required_keys: Set[str]) -> TypedDictType:
object_type = api.named_generic_type('mypy_extensions._TypedDict', [])
typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type)
return typed_dict_type

204
mypy_django_plugin/main.py Normal file
View File

@@ -0,0 +1,204 @@
import os
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple
import toml
from django.db.models.fields.related import RelatedField
from mypy.nodes import MypyFile, TypeInfo
from mypy.options import Options
from mypy.plugin import AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import Type as MypyType
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import fields, forms, init_create, querysets, settings
from mypy_django_plugin.transformers.models import process_model_class
def transform_model_class(ctx: ClassDefContext,
django_context: DjangoContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MODEL_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
helpers.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1
else:
if not ctx.api.final_iteration:
ctx.api.defer()
return
process_model_class(ctx, django_context)
def transform_form_class(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.BASEFORM_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
helpers.get_django_metadata(sym.node)['baseform_bases'][ctx.cls.fullname] = 1
forms.make_meta_nested_class_inherit_from_any(ctx)
def add_new_manager_base(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
helpers.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1
class NewSemanalDjangoPlugin(Plugin):
def __init__(self, options: Options) -> None:
super().__init__(options)
plugin_toml_config = None
if os.path.exists('pyproject.toml'):
with open('pyproject.toml', 'r') as f:
pyproject_toml = toml.load(f)
plugin_toml_config = pyproject_toml.get('tool', {}).get('django-stubs')
self.django_context = DjangoContext(plugin_toml_config)
def _get_current_queryset_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(fullnames.QUERYSET_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (helpers.get_django_metadata(model_sym.node)
.setdefault('queryset_bases', {fullnames.QUERYSET_CLASS_FULLNAME: 1}))
else:
return {}
def _get_current_manager_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(fullnames.MANAGER_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (helpers.get_django_metadata(model_sym.node)
.setdefault('manager_bases', {fullnames.MANAGER_CLASS_FULLNAME: 1}))
else:
return {}
def _get_current_model_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(fullnames.MODEL_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return helpers.get_django_metadata(model_sym.node).setdefault('model_bases',
{fullnames.MODEL_CLASS_FULLNAME: 1})
else:
return {}
def _get_current_form_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(fullnames.BASEFORM_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (helpers.get_django_metadata(model_sym.node)
.setdefault('baseform_bases', {fullnames.BASEFORM_CLASS_FULLNAME: 1,
fullnames.FORM_CLASS_FULLNAME: 1,
fullnames.MODELFORM_CLASS_FULLNAME: 1}))
else:
return {}
def _get_typeinfo_or_none(self, class_name: str) -> Optional[TypeInfo]:
sym = self.lookup_fully_qualified(class_name)
if sym is not None and isinstance(sym.node, TypeInfo):
return sym.node
return None
def _new_dependency(self, module: str) -> Tuple[int, str, int]:
return 10, module, -1
def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]:
# for settings
if file.fullname() == 'django.conf' and self.django_context.django_settings_module:
return [self._new_dependency(self.django_context.django_settings_module)]
# for values / values_list
if file.fullname() == 'django.db.models':
return [self._new_dependency('mypy_extensions'), self._new_dependency('typing')]
# for `get_user_model()`
if self.django_context.settings:
if file.fullname() == 'django.contrib.auth':
auth_user_model_name = self.django_context.settings.AUTH_USER_MODEL
try:
auth_user_module = self.django_context.apps_registry.get_model(auth_user_model_name).__module__
except LookupError:
# get_user_model() model app is not installed
return []
return [self._new_dependency(auth_user_module)]
# ensure that all mentioned to='someapp.SomeModel' are loaded with corresponding related Fields
defined_model_classes = self.django_context.model_modules.get(file.fullname())
if not defined_model_classes:
return []
deps = set()
for model_class in defined_model_classes:
# forward relations
for field in self.django_context.get_model_fields(model_class):
if isinstance(field, RelatedField):
related_model_module = field.related_model.__module__
if related_model_module != file.fullname():
deps.add(self._new_dependency(related_model_module))
# reverse relations
for relation in model_class._meta.related_objects:
related_model_module = relation.related_model.__module__
if related_model_module != file.fullname():
deps.add(self._new_dependency(related_model_module))
return list(deps)
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], MypyType]]:
if fullname == 'django.contrib.auth.get_user_model':
return partial(settings.get_user_model_hook, django_context=self.django_context)
manager_bases = self._get_current_manager_bases()
if fullname in manager_bases:
return querysets.determine_proper_manager_type
info = self._get_typeinfo_or_none(fullname)
if info:
if info.has_base(fullnames.FIELD_FULLNAME):
return partial(fields.transform_into_proper_return_type, django_context=self.django_context)
if info.has_base(fullnames.MODEL_CLASS_FULLNAME):
return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)
def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], MypyType]]:
class_fullname, _, method_name = fullname.rpartition('.')
if method_name == 'get_form_class':
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return forms.extract_proper_type_for_get_form_class
if method_name == 'get_form':
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return forms.extract_proper_type_for_get_form
if method_name == 'values':
model_info = self._get_typeinfo_or_none(class_fullname)
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
if method_name == 'values_list':
model_info = self._get_typeinfo_or_none(class_fullname)
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
manager_classes = self._get_current_manager_bases()
if class_fullname in manager_classes and method_name == 'create':
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in self._get_current_model_bases():
return partial(transform_model_class, django_context=self.django_context)
if fullname in self._get_current_manager_bases():
return add_new_manager_base
if fullname in self._get_current_form_bases():
return transform_form_class
def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], MypyType]]:
class_name, _, attr_name = fullname.rpartition('.')
if class_name == fullnames.DUMMY_SETTINGS_BASE_CLASS:
return partial(settings.get_type_of_settings_attribute,
django_context=self.django_context)
def plugin(version):
return NewSemanalDjangoPlugin

View File

@@ -0,0 +1,258 @@
from typing import Optional, Tuple, cast
from mypy.checker import TypeChecker
from mypy.nodes import StrExpr, TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, Type as MypyType, UnionType
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
def extract_referred_to_type(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Instance]:
api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.callee_arg_names:
api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname!r}',
context=ctx.context)
return None
arg_type = ctx.arg_types[ctx.callee_arg_names.index('to')][0]
if not isinstance(arg_type, CallableType):
to_arg_expr = ctx.args[ctx.callee_arg_names.index('to')][0]
if not isinstance(to_arg_expr, StrExpr):
# not string, not supported
return None
model_string = to_arg_expr.value
if model_string == 'self':
model_fullname = api.tscope.classes[-1].fullname()
elif '.' not in model_string:
model_fullname = api.tscope.classes[-1].module_name + '.' + model_string
else:
if django_context.app_models is not None and model_string in django_context.app_models:
model_fullname = django_context.app_models[model_string]
else:
ctx.api.fail(f'Cannot find referenced model for {model_string!r}', context=ctx.context)
return None
model_info = helpers.lookup_fully_qualified_generic(model_fullname, all_modules=api.modules)
if model_info is None or not isinstance(model_info, TypeInfo):
raise helpers.IncompleteDefnException(model_fullname)
return Instance(model_info, [])
referred_to_type = arg_type.ret_type
assert isinstance(referred_to_type, Instance)
if not referred_to_type.type.has_base(fullnames.MODEL_CLASS_FULLNAME):
ctx.api.msg.fail(f'to= parameter value must be a subclass of {fullnames.MODEL_CLASS_FULLNAME!r}',
context=ctx.context)
return None
return referred_to_type
def convert_any_to_type(typ: MypyType, replacement_type: MypyType) -> MypyType:
"""
Converts any encountered Any (in typ itself, or in generic parameters) into referred_to_type
"""
if isinstance(typ, UnionType):
converted_items = []
for item in typ.items:
converted_items.append(convert_any_to_type(item, replacement_type))
return UnionType.make_union(converted_items,
line=typ.line, column=typ.column)
if isinstance(typ, Instance):
args = []
for default_arg in typ.args:
if isinstance(default_arg, AnyType):
args.append(replacement_type)
else:
args.append(default_arg)
return helpers.reparametrize_instance(typ, args)
if isinstance(typ, AnyType):
return replacement_type
return typ
def get_referred_to_model_fullname(ctx: FunctionContext, django_context: DjangoContext) -> str:
to_arg_type = helpers.get_call_argument_type_by_name(ctx, 'to')
if isinstance(to_arg_type, CallableType):
assert isinstance(to_arg_type.ret_type, Instance)
return to_arg_type.ret_type.type.fullname()
to_arg_expr = helpers.get_call_argument_by_name(ctx, 'to')
if not isinstance(to_arg_expr, StrExpr):
raise helpers.IncompleteDefnException(f'Not a string: {to_arg_expr}')
outer_model_info = ctx.api.tscope.classes[-1]
assert isinstance(outer_model_info, TypeInfo)
model_string = to_arg_expr.value
if model_string == 'self':
return outer_model_info.fullname()
if '.' not in model_string:
# same file class
return outer_model_info.module_name + '.' + model_string
model_cls = django_context.apps_registry.get_model(model_string)
model_fullname = helpers.get_class_fullname(model_cls)
return model_fullname
def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
referred_to_fullname = get_referred_to_model_fullname(ctx, django_context)
referred_to_typeinfo = helpers.lookup_fully_qualified_generic(referred_to_fullname, ctx.api.modules)
assert isinstance(referred_to_typeinfo, TypeInfo), f'Cannot resolve {referred_to_fullname!r}'
referred_to_type = Instance(referred_to_typeinfo, [])
default_related_field_type = set_descriptor_types_for_field(ctx)
# replace Any with referred_to_type
args = []
for default_arg in default_related_field_type.args:
args.append(convert_any_to_type(default_arg, referred_to_type))
return helpers.reparametrize_instance(default_related_field_type, new_args=args)
def get_field_descriptor_types(field_info: TypeInfo, is_nullable: bool) -> Tuple[MypyType, MypyType]:
set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
is_nullable=is_nullable)
get_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
is_nullable=is_nullable)
return set_type, get_type
def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
default_return_type = cast(Instance, ctx.default_return_type)
is_nullable = helpers.parse_bool(helpers.get_call_argument_by_name(ctx, 'null'))
set_type, get_type = get_field_descriptor_types(default_return_type.type, is_nullable)
return helpers.reparametrize_instance(default_return_type, [set_type, get_type])
def transform_into_proper_return_type(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
default_return_type = ctx.default_return_type
assert isinstance(default_return_type, Instance)
if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES):
return fill_descriptor_types_for_related_field(ctx, django_context)
if default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME):
return determine_type_of_array_field(ctx, django_context)
return set_descriptor_types_for_field(ctx)
def determine_type_of_array_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
default_return_type = set_descriptor_types_for_field(ctx)
base_field_arg_type = helpers.get_call_argument_type_by_name(ctx, 'base_field')
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
return default_return_type
base_type = base_field_arg_type.args[1] # extract __get__ type
args = []
for default_arg in default_return_type.args:
args.append(convert_any_to_type(default_arg, base_type))
return helpers.reparametrize_instance(default_return_type, args)
# def _parse_choices_type(ctx: FunctionContext, choices_arg: Expression) -> Optional[str]:
# if isinstance(choices_arg, (TupleExpr, ListExpr)):
# # iterable of 2 element tuples of two kinds
# _, analyzed_choices = ctx.api.analyze_iterable_item_type(choices_arg)
# if isinstance(analyzed_choices, TupleType):
# first_element_type = analyzed_choices.items[0]
# if isinstance(first_element_type, Instance):
# return first_element_type.type.fullname()
# def _parse_referenced_model(ctx: FunctionContext, to_arg: Expression) -> Optional[TypeInfo]:
# if isinstance(to_arg, NameExpr) and isinstance(to_arg.node, TypeInfo):
# # reference to the model class
# return to_arg.node
#
# elif isinstance(to_arg, StrExpr):
# referenced_model_info = helpers.get_model_info(to_arg.value, ctx.api.modules)
# if referenced_model_info is not None:
# return referenced_model_info
# def parse_field_init_arguments_into_model_metadata(ctx: FunctionContext) -> None:
# outer_model = ctx.api.scope.active_class()
# if outer_model is None or not outer_model.has_base(fullnames.MODEL_CLASS_FULLNAME):
# # outside models.Model class, undetermined
# return
#
# # Determine name of the current field
# for attr_name, stmt in helpers.iter_over_class_level_assignments(outer_model.defn):
# if stmt == ctx.context:
# field_name = attr_name
# break
# else:
# return
#
# model_fields_metadata = metadata.get_fields_metadata(outer_model)
#
# # primary key
# is_primary_key = False
# primary_key_arg = helpers.get_call_argument_by_name(ctx, 'primary_key')
# if primary_key_arg:
# is_primary_key = helpers.parse_bool(primary_key_arg)
# model_fields_metadata[field_name] = {'primary_key': is_primary_key}
#
# # choices
# choices_arg = helpers.get_call_argument_by_name(ctx, 'choices')
# if choices_arg:
# choices_type_fullname = _parse_choices_type(ctx.api, choices_arg)
# if choices_type_fullname:
# model_fields_metadata[field_name]['choices_type'] = choices_type_fullname
#
# # nullability
# null_arg = helpers.get_call_argument_by_name(ctx, 'null')
# is_nullable = False
# if null_arg:
# is_nullable = helpers.parse_bool(null_arg)
# model_fields_metadata[field_name]['null'] = is_nullable
#
# # is_blankable
# blank_arg = helpers.get_call_argument_by_name(ctx, 'blank')
# is_blankable = False
# if blank_arg:
# is_blankable = helpers.parse_bool(blank_arg)
# model_fields_metadata[field_name]['blank'] = is_blankable
#
# # default
# default_arg = helpers.get_call_argument_by_name(ctx, 'default')
# if default_arg and not helpers.is_none_expr(default_arg):
# model_fields_metadata[field_name]['default_specified'] = True
#
# if helpers.has_any_of_bases(ctx.default_return_type.type, fullnames.RELATED_FIELDS_CLASSES):
# # to
# to_arg = helpers.get_call_argument_by_name(ctx, 'to')
# if to_arg:
# referenced_model = _parse_referenced_model(ctx, to_arg)
# if referenced_model is not None:
# model_fields_metadata[field_name]['to'] = referenced_model.fullname()
# else:
# model_fields_metadata[field_name]['to'] = to_arg.value
# # referenced_model = to_arg.value
# # raise helpers.IncompleteDefnException()
#
# # model_fields_metadata[field_name]['to'] = referenced_model.fullname()
# # if referenced_model is not None:
# # model_fields_metadata[field_name]['to'] = referenced_model.fullname()
# # else:
# # assert isinstance(to_arg, StrExpr)
# # model_fields_metadata[field_name]['to'] = to_arg.value
#
# # related_name
# related_name_arg = helpers.get_call_argument_by_name(ctx, 'related_name')
# if related_name_arg:
# if isinstance(related_name_arg, StrExpr):
# model_fields_metadata[field_name]['related_name'] = related_name_arg.value
# else:
# model_fields_metadata[field_name]['related_name'] = outer_model.name().lower() + '_set'

View File

@@ -0,0 +1,50 @@
from typing import Optional
from mypy.plugin import ClassDefContext, MethodContext
from mypy.types import CallableType, Instance, NoneTyp, Type as MypyType, TypeType
from mypy_django_plugin.lib import helpers
def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None:
meta_node = helpers.get_nested_meta_node_for_current_class(ctx.cls.info)
if meta_node is None:
if not ctx.api.final_iteration:
ctx.api.defer()
else:
meta_node.fallback_to_any = True
def get_specified_form_class(object_type: Instance) -> Optional[TypeType]:
form_class_sym = object_type.type.get('form_class')
if form_class_sym and isinstance(form_class_sym.type, CallableType):
return TypeType(form_class_sym.type.ret_type)
return None
def extract_proper_type_for_get_form(ctx: MethodContext) -> MypyType:
object_type = ctx.type
assert isinstance(object_type, Instance)
form_class_type = helpers.get_call_argument_type_by_name(ctx, 'form_class')
if form_class_type is None or isinstance(form_class_type, NoneTyp):
form_class_type = get_specified_form_class(object_type)
if isinstance(form_class_type, TypeType) and isinstance(form_class_type.item, Instance):
return form_class_type.item
if isinstance(form_class_type, CallableType) and isinstance(form_class_type.ret_type, Instance):
return form_class_type.ret_type
return ctx.default_return_type
def extract_proper_type_for_get_form_class(ctx: MethodContext) -> MypyType:
object_type = ctx.type
assert isinstance(object_type, Instance)
form_class_type = get_specified_form_class(object_type)
if form_class_type is None:
return ctx.default_return_type
return form_class_type

View File

@@ -0,0 +1,69 @@
from typing import List, Tuple, Type, Union
from django.db.models.base import Model
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import Instance, Type as MypyType
from mypy_django_plugin.django.context import DjangoContext
def get_actual_types(ctx: Union[MethodContext, FunctionContext], expected_keys: List[str]) -> List[Tuple[str, MypyType]]:
actual_types = []
# positionals
for pos, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
if actual_name is None:
if ctx.callee_arg_names[0] == 'kwargs':
# unpacked dict as kwargs is not supported
continue
actual_name = expected_keys[pos]
actual_types.append((actual_name, actual_type))
# kwargs
if len(ctx.callee_arg_names) > 1:
for actual_name, actual_type in zip(ctx.arg_names[1], ctx.arg_types[1]):
if actual_name is None:
# unpacked dict as kwargs is not supported
continue
actual_types.append((actual_name, actual_type))
return actual_types
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
model_cls: Type[Model], method: str) -> MypyType:
expected_types = django_context.get_expected_types(ctx.api, model_cls, method)
expected_keys = [key for key in expected_types.keys() if key != 'pk']
for actual_name, actual_type in get_actual_types(ctx, expected_keys):
if actual_name not in expected_types:
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
model_cls.__name__),
ctx.context)
continue
ctx.api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model_cls.__name__),
'got', 'expected')
return ctx.default_return_type
def redefine_and_typecheck_model_init(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
assert isinstance(ctx.default_return_type, Instance)
model_fullname = ctx.default_return_type.type.fullname()
model_cls = django_context.get_model_class_by_fullname(model_fullname)
if model_cls is None:
return ctx.default_return_type
return typecheck_model_method(ctx, django_context, model_cls, '__init__')
def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
isinstance(ctx.default_return_type, Instance)
model_fullname = ctx.default_return_type.type.fullname()
model_cls = django_context.get_model_class_by_fullname(model_fullname)
if model_cls is None:
return ctx.default_return_type
return typecheck_model_method(ctx, django_context, model_cls, 'create')

View File

@@ -0,0 +1,162 @@
import dataclasses
from abc import ABCMeta, abstractmethod
from typing import cast
from django.db.models.fields.related import ForeignKey
from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel
from mypy.newsemanal.semanal import NewSemanticAnalyzer
from mypy.nodes import ClassDef, MDEF, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext
from mypy.types import Instance
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import fields
from mypy_django_plugin.transformers.fields import get_field_descriptor_types
@dataclasses.dataclass
class ModelClassInitializer(metaclass=ABCMeta):
api: NewSemanticAnalyzer
model_classdef: ClassDef
django_context: DjangoContext
ctx: ClassDefContext
@classmethod
def from_ctx(cls, ctx: ClassDefContext, django_context: DjangoContext):
return cls(api=cast(NewSemanticAnalyzer, ctx.api),
model_classdef=ctx.cls,
django_context=django_context,
ctx=ctx)
def lookup_typeinfo_or_incomplete_defn_error(self, fullname: str) -> TypeInfo:
sym = self.api.lookup_fully_qualified_or_none(fullname)
if sym is None or not isinstance(sym.node, TypeInfo):
raise helpers.IncompleteDefnException(f'No {fullname!r} found')
return sym.node
def lookup_class_typeinfo_or_incomplete_defn_error(self, klass: type) -> TypeInfo:
fullname = helpers.get_class_fullname(klass)
field_info = self.lookup_typeinfo_or_incomplete_defn_error(fullname)
return field_info
def add_new_node_to_model_class(self, name: str, typ: Instance) -> None:
# type=: type of the variable itself
var = Var(name=name, type=typ)
# var.info: type of the object variable is bound to
var.info = self.model_classdef.info
var._fullname = self.model_classdef.info.fullname() + '.' + name
var.is_initialized_in_class = True
var.is_inferred = True
self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True)
@abstractmethod
def run(self) -> None:
raise NotImplementedError()
class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):
"""
Replaces
class MyModel(models.Model):
class Meta:
pass
with
class MyModel(models.Model):
class Meta(Any):
pass
to get around incompatible Meta inner classes for different models.
"""
def run(self) -> None:
meta_node = helpers.get_nested_meta_node_for_current_class(self.model_classdef.info)
if meta_node is None:
return None
meta_node.fallback_to_any = True
class AddDefaultPrimaryKey(ModelClassInitializer):
def run(self) -> None:
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname)
if model_cls is None:
return
auto_field = model_cls._meta.auto_field
if auto_field and not self.model_classdef.info.has_readable_member(auto_field.attname):
# autogenerated field
auto_field_fullname = helpers.get_class_fullname(auto_field.__class__)
auto_field_info = self.lookup_typeinfo_or_incomplete_defn_error(auto_field_fullname)
set_type, get_type = fields.get_field_descriptor_types(auto_field_info, is_nullable=False)
self.add_new_node_to_model_class(auto_field.attname, Instance(auto_field_info,
[set_type, get_type]))
class AddRelatedModelsId(ModelClassInitializer):
def run(self) -> None:
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname)
if model_cls is None:
return
for field in model_cls._meta.get_fields():
if isinstance(field, ForeignKey):
rel_primary_key_field = self.django_context.get_primary_key_field(field.related_model)
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__)
is_nullable = self.django_context.fields_context.get_field_nullability(field, None)
set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
self.add_new_node_to_model_class(field.attname,
Instance(field_info, [set_type, get_type]))
class AddManagers(ModelClassInitializer):
def run(self):
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname)
if model_cls is None:
return
for manager_name, manager in model_cls._meta.managers_map.items():
if manager_name not in self.model_classdef.info.names:
manager_fullname = helpers.get_class_fullname(manager.__class__)
manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname)
manager = Instance(manager_info, [Instance(self.model_classdef.info, [])])
self.add_new_node_to_model_class(manager_name, manager)
# add _default_manager
if '_default_manager' not in self.model_classdef.info.names:
default_manager_fullname = helpers.get_class_fullname(model_cls._meta.default_manager.__class__)
default_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(default_manager_fullname)
default_manager = Instance(default_manager_info, [Instance(self.model_classdef.info, [])])
self.add_new_node_to_model_class('_default_manager', default_manager)
# add related managers
for relation in self.django_context.get_model_relations(model_cls):
attname = relation.related_name
if attname is None:
attname = relation.name + '_set'
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(relation.related_model)
if isinstance(relation, OneToOneRel):
self.add_new_node_to_model_class(attname, Instance(related_model_info, []))
continue
if isinstance(relation, (ManyToOneRel, ManyToManyRel)):
manager_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.RELATED_MANAGER_CLASS_FULLNAME)
self.add_new_node_to_model_class(attname,
Instance(manager_info, [Instance(related_model_info, [])]))
continue
def process_model_class(ctx: ClassDefContext,
django_context: DjangoContext) -> None:
initializers = [
InjectAnyAsBaseForNestedMeta,
AddDefaultPrimaryKey,
AddRelatedModelsId,
AddManagers,
]
for initializer_cls in initializers:
try:
initializer_cls.from_ctx(ctx, django_context).run()
except helpers.IncompleteDefnException:
if not ctx.api.final_iteration:
ctx.api.defer()

View File

@@ -0,0 +1,166 @@
from collections import OrderedDict
from typing import Optional, Tuple, Type
from django.core.exceptions import FieldError
from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey
from mypy.nodes import NameExpr
from mypy.plugin import AnalyzeTypeContext, FunctionContext, MethodContext
from mypy.types import AnyType, Instance, Type as MypyType, TypeOfAny
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
def set_first_generic_param_as_default_for_second(ctx: AnalyzeTypeContext, fullname: str) -> MypyType:
if not ctx.type.args:
try:
return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit),
AnyType(TypeOfAny.explicit)])
except KeyError:
# really should never happen
return AnyType(TypeOfAny.explicit)
args = ctx.type.args
if len(args) == 1:
args = [args[0], args[0]]
analyzed_args = [ctx.api.analyze_type(arg) for arg in args]
ctx.api.analyze_type(ctx.type)
try:
return ctx.api.named_type(fullname, analyzed_args)
except KeyError:
return AnyType(TypeOfAny.explicit)
def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
ret = ctx.default_return_type
assert isinstance(ret, Instance)
if not ctx.api.tscope.classes:
# not in class
return ret
outer_model_info = ctx.api.tscope.classes[0]
if not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
return ret
return helpers.reparametrize_instance(ret, [Instance(outer_model_info, [])])
def get_lookup_field_get_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
lookup: str, method: str) -> Optional[Tuple[str, MypyType]]:
try:
lookup_field = django_context.lookups_context.resolve_lookup(model_cls, lookup)
except FieldError as exc:
ctx.api.fail(exc.args[0], ctx.context)
return None
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, method)
return lookup_field.attname, field_get_type
def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
flat: bool, named: bool) -> MypyType:
field_lookups = [expr.value for expr in ctx.args[0]]
if len(field_lookups) == 0:
if flat:
primary_key_field = django_context.get_primary_key_field(model_cls)
_, field_get_type = get_lookup_field_get_type(ctx, django_context, model_cls,
primary_key_field.attname, 'values_list')
return field_get_type
elif named:
column_types = OrderedDict()
for field in django_context.get_model_fields(model_cls):
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, field, 'values_list')
column_types[field.attname] = field_get_type
return helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types)
else:
# flat=False, named=False, all fields
field_lookups = []
for field in model_cls._meta.get_fields():
field_lookups.append(field.attname)
if len(field_lookups) > 1 and flat:
ctx.api.fail("'flat' is not valid when 'values_list' is called with more than one field", ctx.context)
return AnyType(TypeOfAny.from_error)
column_types = OrderedDict()
for field_lookup in field_lookups:
result = get_lookup_field_get_type(ctx, django_context, model_cls, field_lookup, 'values_list')
if result is None:
return AnyType(TypeOfAny.from_error)
field_name, field_get_type = result
column_types[field_name] = field_get_type
if flat:
assert len(column_types) == 1
row_type = next(iter(column_types.values()))
elif named:
row_type = helpers.make_oneoff_named_tuple(ctx.api, 'Row', column_types)
else:
row_type = helpers.make_tuple(ctx.api, list(column_types.values()))
return row_type
def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
assert isinstance(ctx.type, Instance)
assert isinstance(ctx.type.args[0], Instance)
model_type = ctx.type.args[0]
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
if model_cls is None:
return ctx.default_return_type
flat_expr = helpers.get_call_argument_by_name(ctx, 'flat')
if flat_expr is not None and isinstance(flat_expr, NameExpr):
flat = helpers.parse_bool(flat_expr)
else:
flat = False
named_expr = helpers.get_call_argument_by_name(ctx, 'named')
if named_expr is not None and isinstance(named_expr, NameExpr):
named = helpers.parse_bool(named_expr)
else:
named = False
if flat and named:
ctx.api.fail("'flat' and 'named' can't be used together", ctx.context)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
row_type = get_values_list_row_type(ctx, django_context, model_cls,
flat=flat, named=named)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
assert isinstance(ctx.type, Instance)
assert isinstance(ctx.type.args[0], Instance)
model_type = ctx.type.args[0]
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname())
if model_cls is None:
return ctx.default_return_type
field_lookups = [expr.value for expr in ctx.args[0]]
if len(field_lookups) == 0:
for field in model_cls._meta.get_fields():
field_lookups.append(field.attname)
column_types = OrderedDict()
for field_lookup in field_lookups:
try:
lookup_field = django_context.lookups_context.resolve_lookup(model_cls, field_lookup)
except FieldError as exc:
ctx.api.fail(exc.args[0], ctx.context)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
field_get_type = django_context.fields_context.get_field_get_type(ctx.api, lookup_field, 'values')
field_name = lookup_field.attname
if isinstance(lookup_field, ForeignKey) and field_lookup == lookup_field.name:
field_name = lookup_field.name
column_types[field_name] = field_get_type
row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys()))
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])

View File

@@ -0,0 +1,45 @@
from mypy.nodes import TypeInfo, MemberExpr
from mypy.plugin import FunctionContext, AttributeContext
from mypy.types import Type as MypyType, TypeType, Instance
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import helpers
def get_user_model_hook(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
auth_user_model = django_context.settings.AUTH_USER_MODEL
model_cls = django_context.apps_registry.get_model(auth_user_model)
model_cls_fullname = helpers.get_class_fullname(model_cls)
model_info = helpers.lookup_fully_qualified_generic(model_cls_fullname, ctx.api.modules)
assert isinstance(model_info, TypeInfo)
return TypeType(Instance(model_info, []))
def get_type_of_settings_attribute(ctx: AttributeContext, django_context: DjangoContext) -> MypyType:
assert isinstance(ctx.context, MemberExpr)
setting_name = ctx.context.name
if not hasattr(django_context.settings, setting_name):
ctx.api.fail(f"'Settings' object has no attribute {setting_name!r}", ctx.context)
return ctx.default_attr_type
# first look for the setting in the project settings file, then global settings
settings_module = ctx.api.modules.get(django_context.django_settings_module)
global_settings_module = ctx.api.modules.get('django.conf.global_settings')
for module in [settings_module, global_settings_module]:
if module is not None:
sym = module.names.get(setting_name)
if sym is not None and sym.type is not None:
return sym.type
# if by any reason it isn't present there, get type from django settings
value = getattr(django_context.settings, setting_name)
value_fullname = helpers.get_class_fullname(value.__class__)
value_info = helpers.lookup_fully_qualified_typeinfo(ctx.api, value_fullname)
if value_info is None:
return ctx.default_attr_type
return Instance(value_info, [])