new semanal wip 1

This commit is contained in:
Maxim Kurnikov
2019-07-16 01:22:20 +03:00
parent 9c5a6be9a7
commit b11a9a85f9
96 changed files with 4441 additions and 2370 deletions

View File

@@ -1,12 +1,17 @@
import os
from configparser import ConfigParser
from typing import Optional
from typing import Dict, List, Optional
import dataclasses
from dataclasses import dataclass
from pytest_mypy.utils import temp_environ
@dataclass
class Config:
django_settings_module: Optional[str] = None
installed_apps: List[str] = dataclasses.field(default_factory=list)
ignore_missing_settings: bool = False
ignore_missing_model_attributes: bool = False
@@ -29,3 +34,21 @@ class Config:
ignore_missing_model_attributes=bool(ini_config.get('mypy_django_plugin',
'ignore_missing_model_attributes',
fallback=False)))
def extract_app_model_aliases(settings_module: str) -> Dict[str, str]:
with temp_environ():
os.environ['DJANGO_SETTINGS_MODULE'] = settings_module
import django
django.setup()
app_model_mapping: Dict[str, str] = {}
from django.apps import apps
for name, app_config in apps.app_configs.items():
app_label = app_config.label
for model_name, model_class in app_config.models.items():
app_model_mapping[app_label + '.' + model_class.__name__] = model_class.__module__ + '.' + model_class.__name__
return app_model_mapping

View File

@@ -26,4 +26,10 @@ MANAGER_CLASSES = {
RELATED_MANAGER_CLASS_FULLNAME,
BASE_MANAGER_CLASS_FULLNAME,
QUERYSET_CLASS_FULLNAME
}
RELATED_FIELDS_CLASSES = {
FOREIGN_KEY_FULLNAME,
ONETOONE_FIELD_FULLNAME,
MANYTOMANY_FIELD_FULLNAME
}

View File

@@ -1,24 +1,20 @@
import typing
from collections import OrderedDict
from typing import Dict, Optional, cast
from typing import Dict, Iterator, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast
from mypy.mro import calculate_mro
from mypy.nodes import (
GDEF, MDEF, AssignmentStmt, Block, CallExpr, ClassDef, Expression, ImportedName, Lvalue, MypyFile, NameExpr,
SymbolNode, SymbolTable, SymbolTableNode, TypeInfo, Var,
)
from mypy.nodes import (AssignmentStmt, Block, CallExpr, ClassDef, Expression, FakeInfo, GDEF, ImportedName, Lvalue, MDEF,
MemberExpr, MypyFile, NameExpr, SymbolNode, SymbolTable, SymbolTableNode, TypeInfo, Var)
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext
from mypy.types import (
AnyType, Instance, NoneTyp, TupleType, Type, TypedDictType, TypeOfAny, TypeVarType, UnionType,
)
from mypy.types import (AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypeVarType, TypedDictType,
UnionType)
from mypy_django_plugin.lib import metadata, fullnames
from mypy_django_plugin.lib import fullnames, metadata
if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from mypy.checker import TypeChecker
def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]:
def get_models_file(app_name: str, all_modules: Dict[str, MypyFile]) -> Optional[MypyFile]:
models_module = '.'.join([app_name, 'models'])
return all_modules.get(models_module)
@@ -85,7 +81,7 @@ def parse_bool(expr: Expression) -> Optional[bool]:
return None
def reparametrize_instance(instance: Instance, new_args: typing.List[Type]) -> Instance:
def reparametrize_instance(instance: Instance, new_args: List[MypyType]) -> Instance:
return Instance(instance.type, args=new_args,
line=instance.line, column=instance.column)
@@ -94,7 +90,7 @@ def fill_typevars_with_any(instance: Instance) -> Instance:
return reparametrize_instance(instance, [AnyType(TypeOfAny.unannotated)])
def extract_typevar_value(tp: Instance, typevar_name: str) -> Type:
def extract_typevar_value(tp: Instance, typevar_name: str) -> MypyType:
if typevar_name in {'_T', '_T_co'}:
if '_T' in tp.type.type_vars:
return tp.args[tp.type.type_vars.index('_T')]
@@ -104,16 +100,16 @@ def extract_typevar_value(tp: Instance, typevar_name: str) -> Type:
def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance:
typevar_values: typing.List[Type] = []
typevar_values: List[MypyType] = []
for typevar_arg in type_to_fill.args:
if isinstance(typevar_arg, TypeVarType):
typevar_values.append(extract_typevar_value(tp, typevar_arg.name))
return Instance(type_to_fill.type, typevar_values)
def get_argument_by_name(ctx: typing.Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]:
"""Return the expression for the specific argument.
def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]:
"""
Return the expression for the specific argument.
This helper should only be used with non-star arguments.
"""
if name not in ctx.callee_arg_names:
@@ -126,7 +122,7 @@ def get_argument_by_name(ctx: typing.Union[FunctionContext, MethodContext], name
return args[0]
def get_argument_type_by_name(ctx: typing.Union[FunctionContext, MethodContext], name: str) -> Optional[Type]:
def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]:
"""Return the type for the specific argument.
This helper should only be used with non-star arguments.
@@ -157,14 +153,38 @@ def get_setting_expr(api: 'TypeChecker', setting_name: str) -> Optional[Expressi
return None
module_file = api.modules.get(module)
for name_expr, value_expr in iter_over_assignments(module_file):
for name_expr, value_expr in iter_over_module_level_assignments(module_file):
if isinstance(name_expr, NameExpr) and name_expr.name == setting_name:
return value_expr
return None
def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]
) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
def iter_over_class_level_assignments(klass: ClassDef) -> Iterator[Tuple[str, Expression]]:
for stmt in klass.defs.body:
if not isinstance(stmt, AssignmentStmt):
continue
if len(stmt.lvalues) > 1:
# skip multiple assignments
continue
lvalue = stmt.lvalues[0]
if isinstance(lvalue, NameExpr):
yield lvalue.name, stmt.rvalue
def iter_over_module_level_assignments(module: MypyFile) -> Iterator[Tuple[str, Expression]]:
for stmt in module.defs:
if not isinstance(stmt, AssignmentStmt):
continue
if len(stmt.lvalues) > 1:
# skip multiple assignments
continue
lvalue = stmt.lvalues[0]
if isinstance(lvalue, NameExpr):
yield lvalue.name, stmt.rvalue
def iter_over_assignments_in_class(class_or_module: Union[ClassDef, MypyFile]
) -> Iterator[Tuple[str, Expression]]:
if isinstance(class_or_module, ClassDef):
statements = class_or_module.defs.body
else:
@@ -176,10 +196,12 @@ def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]
if len(stmt.lvalues) > 1:
# not supported yet
continue
yield stmt.lvalues[0], stmt.rvalue
lvalue = stmt.lvalues[0]
if isinstance(lvalue, NameExpr):
yield lvalue.name, stmt.rvalue
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
def extract_field_setter_type(tp: Instance) -> Optional[MypyType]:
""" Extract __set__ value of a field. """
if tp.type.has_base(fullnames.FIELD_FULLNAME):
return tp.args[0]
@@ -189,7 +211,7 @@ def extract_field_setter_type(tp: Instance) -> Optional[Type]:
return None
def extract_field_getter_type(tp: Type) -> Optional[Type]:
def extract_field_getter_type(tp: MypyType) -> Optional[MypyType]:
""" Extract return type of __get__ of subclass of Field"""
if not isinstance(tp, Instance):
return None
@@ -201,7 +223,7 @@ def extract_field_getter_type(tp: Type) -> Optional[Type]:
return None
def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[Type]:
def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[MypyType]:
"""
If field with primary_key=True is set on the model, extract its __set__ type.
"""
@@ -212,7 +234,7 @@ def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[
return None
def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]:
def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[MypyType]:
for field_name, props in metadata.get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
@@ -220,11 +242,11 @@ def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]:
return None
def make_optional(typ: Type):
def make_optional(typ: MypyType) -> MypyType:
return UnionType.make_union([typ, NoneTyp()])
def make_required(typ: Type) -> Type:
def make_required(typ: MypyType) -> MypyType:
if not isinstance(typ, UnionType):
return typ
items = [item for item in typ.items if not isinstance(item, NoneTyp)]
@@ -232,14 +254,14 @@ def make_required(typ: Type) -> Type:
return UnionType.make_union(items)
def is_optional(typ: Type) -> bool:
def is_optional(typ: MypyType) -> bool:
if not isinstance(typ, UnionType):
return False
return any([isinstance(item, NoneTyp) for item in typ.items])
def has_any_of_bases(info: TypeInfo, bases: typing.Sequence[str]) -> bool:
def has_any_of_bases(info: TypeInfo, bases: Set[str]) -> bool:
for base_fullname in bases:
if info.has_base(base_fullname):
return True
@@ -257,10 +279,10 @@ def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]
return None
def get_assigned_value_for_class(type_info: TypeInfo, name: str) -> Optional[Expression]:
for lvalue, rvalue in iter_over_assignments(type_info.defn):
if isinstance(lvalue, NameExpr) and lvalue.name == name:
return rvalue
def get_assignment_stmt_by_name(type_info: TypeInfo, name: str) -> Optional[Expression]:
for assignment_name, call_expr in iter_over_class_level_assignments(type_info.defn):
if assignment_name == name:
return call_expr
return None
@@ -268,13 +290,13 @@ def is_field_nullable(model: TypeInfo, field_name: str) -> bool:
return metadata.get_fields_metadata(model).get(field_name, {}).get('null', False)
def is_foreign_key_like(t: Type) -> bool:
def is_foreign_key_like(t: MypyType) -> bool:
if not isinstance(t, Instance):
return False
return has_any_of_bases(t.type, (fullnames.FOREIGN_KEY_FULLNAME, fullnames.ONETOONE_FIELD_FULLNAME))
return has_any_of_bases(t.type, {fullnames.FOREIGN_KEY_FULLNAME, fullnames.ONETOONE_FIELD_FULLNAME})
def build_class_with_annotated_fields(api: 'TypeChecker', base: Type, fields: 'OrderedDict[str, Type]',
def build_class_with_annotated_fields(api: 'TypeChecker', base: MypyType, fields: 'OrderedDict[str, MypyType]',
name: str) -> Instance:
"""Build an Instance with `name` that contains the specified `fields` as attributes and extends `base`."""
# Credit: This code is largely copied/modified from TypeChecker.intersect_instance_callable and
@@ -309,7 +331,7 @@ def build_class_with_annotated_fields(api: 'TypeChecker', base: Type, fields: 'O
return Instance(info, [])
def make_named_tuple(api: 'TypeChecker', fields: 'OrderedDict[str, Type]', name: str) -> Type:
def make_named_tuple(api: 'TypeChecker', fields: 'OrderedDict[str, MypyType]', name: str) -> MypyType:
if not fields:
# No fields specified, so fallback to a subclass of NamedTuple that allows
# __getattr__ / __setattr__ for any attribute name.
@@ -317,27 +339,27 @@ def make_named_tuple(api: 'TypeChecker', fields: 'OrderedDict[str, Type]', name:
else:
fallback = build_class_with_annotated_fields(
api=api,
base=api.named_generic_type('typing.NamedTuple', []),
base=api.named_generic_type('NamedTuple', []),
fields=fields,
name=name
)
return TupleType(list(fields.values()), fallback=fallback)
def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, Type]',
required_keys: typing.Set[str]) -> TypedDictType:
def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyType]',
required_keys: Set[str]) -> TypedDictType:
object_type = api.named_generic_type('mypy_extensions._TypedDict', [])
typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type)
return typed_dict_type
def make_tuple(api: 'TypeChecker', fields: typing.List[Type]) -> TupleType:
def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType:
implicit_any = AnyType(TypeOfAny.special_form)
fallback = api.named_generic_type('builtins.tuple', [implicit_any])
return TupleType(fields, fallback=fallback)
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> Type:
def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> MypyType:
node = type_info.get(private_field_name).node
if isinstance(node, Var):
descriptor_type = node.type
@@ -347,16 +369,33 @@ def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is
return AnyType(TypeOfAny.unannotated)
def iter_over_classdefs(module_file: MypyFile) -> typing.Iterator[ClassDef]:
class IncompleteDefnException(Exception):
pass
def iter_over_toplevel_classes(module_file: MypyFile) -> Iterator[ClassDef]:
for defn in module_file.defs:
if isinstance(defn, ClassDef):
yield defn
def iter_call_assignments(klass: ClassDef) -> typing.Iterator[typing.Tuple[Lvalue, CallExpr]]:
for lvalue, rvalue in iter_over_assignments(klass):
if isinstance(rvalue, CallExpr):
yield lvalue, rvalue
def iter_call_assignments_in_class(klass: ClassDef) -> Iterator[Tuple[str, CallExpr]]:
for name, expression in iter_over_assignments_in_class(klass):
if isinstance(expression, CallExpr):
yield name, expression
def iter_over_field_inits_in_class(klass: ClassDef) -> Iterator[Tuple[str, CallExpr]]:
for lvalue, rvalue in iter_over_assignments_in_class(klass):
if isinstance(lvalue, NameExpr) and isinstance(rvalue, CallExpr):
field_name = lvalue.name
if isinstance(rvalue.callee, MemberExpr) and isinstance(rvalue.callee.node, TypeInfo):
if isinstance(rvalue.callee.node, FakeInfo):
raise IncompleteDefnException()
field_info = rvalue.callee.node
if field_info.has_base(fullnames.FIELD_FULLNAME):
yield field_name, rvalue
def get_related_manager_type_from_metadata(model_info: TypeInfo, related_manager_name: str,
@@ -394,3 +433,20 @@ def get_primary_key_field_name(model_info: TypeInfo) -> Optional[str]:
if is_primary_key:
return field_name
return None
def _get_app_models_file(app_name: str, all_modules: Dict[str, MypyFile]) -> Optional[MypyFile]:
models_module = '.'.join([app_name, 'models'])
return all_modules.get(models_module)
def get_model_info(app_name_dot_model_name: str, all_modules: Dict[str, MypyFile]) -> Optional[TypeInfo]:
""" Resolve app_name.ModelName into model fullname """
app_name, model_name = app_name_dot_model_name.split('.')
models_file = _get_app_models_file(app_name, all_modules)
if models_file is None:
return None
sym = models_file.names.get(model_name)
if sym and isinstance(sym.node, TypeInfo):
return sym.node

View File

View File

@@ -0,0 +1,21 @@
#!/usr/bin/env python
"""Django's command-line utility for administrative tasks."""
import os
import sys
def main():
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'sample_django_project.settings')
try:
from django.core.management import execute_from_command_line
except ImportError as exc:
raise ImportError(
"Couldn't import Django. Are you sure it's installed and "
"available on your PYTHONPATH environment variable? Did you "
"forget to activate a virtual environment?"
) from exc
execute_from_command_line(sys.argv)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,3 @@
from django.contrib import admin
# Register your models here.

View File

@@ -0,0 +1,5 @@
from django.apps import AppConfig
class MyappConfig(AppConfig):
label = 'myapp22'

View File

@@ -0,0 +1,6 @@
from django.db import models
# Create your models here.
class MyModel(models.Model):
pass

View File

@@ -0,0 +1,3 @@
from django.test import TestCase
# Create your tests here.

View File

@@ -0,0 +1,3 @@
from django.shortcuts import render
# Create your views here.

View File

@@ -0,0 +1,121 @@
"""
Django settings for sample_django_project project.
Generated by 'django-admin startproject' using Django 2.2.3.
For more information on this file, see
https://docs.djangoproject.com/en/2.2/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/2.2/ref/settings/
"""
import os
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/2.2/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'e6gj!2x(*odqwmjafrn7#35%)&rnn&^*0x-f&j0prgr--&xf+%'
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True
ALLOWED_HOSTS = []
# Application definition
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'mypy_django_plugin.lib.tests.sample_django_project.myapp'
]
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
ROOT_URLCONF = 'sample_django_project.urls'
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
]
WSGI_APPLICATION = 'sample_django_project.wsgi.application'
# Database
# https://docs.djangoproject.com/en/2.2/ref/settings/#databases
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
}
}
# Password validation
# https://docs.djangoproject.com/en/2.2/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
]
# Internationalization
# https://docs.djangoproject.com/en/2.2/topics/i18n/
LANGUAGE_CODE = 'en-us'
TIME_ZONE = 'UTC'
USE_I18N = True
USE_L10N = True
USE_TZ = True
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/2.2/howto/static-files/
STATIC_URL = '/static/'

View File

@@ -0,0 +1,21 @@
"""sample_django_project URL Configuration
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/2.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.contrib import admin
from django.urls import path
urlpatterns = [
path('admin/', admin.site.urls),
]

View File

@@ -0,0 +1,16 @@
"""
WSGI config for sample_django_project project.
It exposes the WSGI callable as a module-level variable named ``application``.
For more information on this file, see
https://docs.djangoproject.com/en/2.2/howto/deployment/wsgi/
"""
import os
from django.core.wsgi import get_wsgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'sample_django_project.settings')
application = get_wsgi_application()

View File

@@ -0,0 +1,14 @@
from mypy.options import Options
from mypy_django_plugin.lib.config import extract_app_model_aliases
from mypy_django_plugin.main import DjangoPlugin
def test_parse_django_settings():
app_model_mapping = extract_app_model_aliases('mypy_django_plugin.lib.tests.sample_django_project.root.settings')
assert app_model_mapping['myapp.MyModel'] == 'mypy_django_plugin.lib.tests.sample_django_project.myapp.models.MyModel'
def test_instantiate_plugin_with_config():
plugin = DjangoPlugin(Options())

View File

@@ -2,6 +2,7 @@ import os
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, cast
import toml
from mypy.nodes import MypyFile, NameExpr, TypeInfo
from mypy.options import Options
from mypy.plugin import (
@@ -10,7 +11,7 @@ from mypy.plugin import (
from mypy.types import AnyType, Instance, Type, TypeOfAny
from mypy_django_plugin.lib import metadata, fullnames, helpers
from mypy_django_plugin.config import Config
from mypy_django_plugin.lib.config import Config, extract_app_model_aliases
from mypy_django_plugin.transformers import fields, init_create
from mypy_django_plugin.transformers.forms import (
extract_proper_type_for_get_form, extract_proper_type_for_get_form_class, make_meta_nested_class_inherit_from_any,
@@ -31,7 +32,9 @@ from mypy_django_plugin.transformers.settings import (
)
def transform_model_class(ctx: ClassDefContext, ignore_missing_model_attributes: bool) -> None:
def transform_model_class(ctx: ClassDefContext,
ignore_missing_model_attributes: bool,
app_models_mapping: Optional[Dict[str, str]]) -> None:
try:
sym = ctx.api.lookup_fully_qualified(fullnames.MODEL_CLASS_FULLNAME)
except KeyError:
@@ -41,7 +44,7 @@ def transform_model_class(ctx: ClassDefContext, ignore_missing_model_attributes:
if sym is not None and isinstance(sym.node, TypeInfo):
metadata.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1
process_model_class(ctx, ignore_missing_model_attributes)
process_model_class(ctx, ignore_missing_model_attributes, app_models_mapping)
def transform_manager_class(ctx: ClassDefContext) -> None:
@@ -116,7 +119,7 @@ def return_type_for_id_field(ctx: AttributeContext) -> Type:
def transform_form_view(ctx: ClassDefContext) -> None:
form_class_value = helpers.get_assigned_value_for_class(ctx.cls.info, 'form_class')
form_class_value = helpers.get_assignment_stmt_by_name(ctx.cls.info, 'form_class')
if isinstance(form_class_value, NameExpr):
metadata.get_django_metadata(ctx.cls.info)['form_class'] = form_class_value.fullname
@@ -125,6 +128,17 @@ class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None:
super().__init__(options)
django_plugin_config = None
if os.path.exists('pyproject.toml'):
with open('pyproject.toml', 'r') as f:
pyproject_toml = toml.load(f)
django_plugin_config = pyproject_toml.get('tool', {}).get('django-stubs')
if django_plugin_config and 'django_settings_module' in django_plugin_config:
self.app_models_mapping = extract_app_model_aliases(django_plugin_config['django_settings_module'])
else:
self.app_models_mapping = None
config_fpath = os.environ.get('MYPY_DJANGO_CONFIG', 'mypy_django.ini')
if config_fpath and os.path.exists(config_fpath):
self.config = Config.from_config_file(config_fpath)
@@ -195,10 +209,10 @@ class DjangoPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == 'django.contrib.auth.get_user_model':
return partial(return_user_model_hook,
settings_modules=self._get_settings_modules_in_order_of_priority())
# if fullname == 'django.contrib.auth.get_user_model':
# return partial(return_user_model_hook,
# settings_modules=self._get_settings_modules_in_order_of_priority())
#
manager_bases = self._get_current_manager_bases()
if fullname in manager_bases:
return determine_proper_manager_type
@@ -208,48 +222,48 @@ class DjangoPlugin(Plugin):
if info.has_base(fullnames.FIELD_FULLNAME):
return fields.process_field_instantiation
if metadata.get_django_metadata(info).get('generated_init'):
return init_create.redefine_and_typecheck_model_init
# if metadata.get_django_metadata(info).get('generated_init'):
# return init_create.redefine_and_typecheck_model_init
def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], Type]]:
class_name, _, method_name = fullname.rpartition('.')
if method_name == 'get_form_class':
info = self._get_typeinfo_or_none(class_name)
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return extract_proper_type_for_get_form_class
if method_name == 'get_form':
info = self._get_typeinfo_or_none(class_name)
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return extract_proper_type_for_get_form
if method_name == 'values':
model_info = self._get_typeinfo_or_none(class_name)
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return extract_proper_type_for_queryset_values
if method_name == 'values_list':
model_info = self._get_typeinfo_or_none(class_name)
if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return extract_proper_type_queryset_values_list
if fullname in {'django.apps.registry.Apps.get_model',
'django.db.migrations.state.StateApps.get_model'}:
return determine_model_cls_from_string_for_migrations
manager_classes = self._get_current_manager_bases()
class_fullname, _, method_name = fullname.rpartition('.')
if class_fullname in manager_classes and method_name == 'create':
return init_create.redefine_and_typecheck_model_create
return None
# def get_method_hook(self, fullname: str
# ) -> Optional[Callable[[MethodContext], Type]]:
# class_name, _, method_name = fullname.rpartition('.')
#
# if method_name == 'get_form_class':
# info = self._get_typeinfo_or_none(class_name)
# if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
# return extract_proper_type_for_get_form_class
#
# if method_name == 'get_form':
# info = self._get_typeinfo_or_none(class_name)
# if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
# return extract_proper_type_for_get_form
#
# if method_name == 'values':
# model_info = self._get_typeinfo_or_none(class_name)
# if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
# return extract_proper_type_for_queryset_values
#
# if method_name == 'values_list':
# model_info = self._get_typeinfo_or_none(class_name)
# if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
# return extract_proper_type_queryset_values_list
#
# if fullname in {'django.apps.registry.Apps.get_model',
# 'django.db.migrations.state.StateApps.get_model'}:
# return determine_model_cls_from_string_for_migrations
#
# manager_classes = self._get_current_manager_bases()
# class_fullname, _, method_name = fullname.rpartition('.')
# if class_fullname in manager_classes and method_name == 'create':
# return init_create.redefine_and_typecheck_model_create
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in self._get_current_model_bases():
return partial(transform_model_class,
ignore_missing_model_attributes=self.config.ignore_missing_model_attributes)
ignore_missing_model_attributes=self.config.ignore_missing_model_attributes,
app_models_mapping=self.app_models_mapping)
if fullname in self._get_current_manager_bases():
return transform_manager_class
@@ -257,20 +271,20 @@ class DjangoPlugin(Plugin):
# if fullname in self._get_current_form_bases():
# return transform_form_class
info = self._get_typeinfo_or_none(fullname)
if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return transform_form_view
# info = self._get_typeinfo_or_none(fullname)
# if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
# return transform_form_view
return None
def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
class_name, _, attr_name = fullname.rpartition('.')
if class_name == fullnames.DUMMY_SETTINGS_BASE_CLASS:
return partial(get_type_of_setting,
setting_name=attr_name,
settings_modules=self._get_settings_modules_in_order_of_priority(),
ignore_missing_settings=self.config.ignore_missing_settings)
# if class_name == fullnames.DUMMY_SETTINGS_BASE_CLASS:
# return partial(get_type_of_setting,
# setting_name=attr_name,
# settings_modules=self._get_settings_modules_in_order_of_priority(),
# ignore_missing_settings=self.config.ignore_missing_settings)
if class_name in self._get_current_model_bases():
if attr_name == 'id':

View File

@@ -1,13 +1,13 @@
from typing import Optional, cast
from mypy.checker import TypeChecker
from mypy.nodes import ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo
from mypy.nodes import ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo, Expression
from mypy.plugin import FunctionContext
from mypy.types import (
AnyType, CallableType, Instance, TupleType, Type, UnionType,
)
from mypy_django_plugin.lib import metadata, fullnames, helpers
from mypy_django_plugin.lib import fullnames, helpers, metadata
def extract_referred_to_type(ctx: FunctionContext) -> Optional[Instance]:
@@ -90,7 +90,7 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext) -> Type:
def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
default_return_type = cast(Instance, ctx.default_return_type)
is_nullable = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'null'))
is_nullable = helpers.parse_bool(helpers.get_call_argument_by_name(ctx, 'null'))
set_type = helpers.get_private_descriptor_type(default_return_type.type, '_pyi_private_set_type',
is_nullable=is_nullable)
get_type = helpers.get_private_descriptor_type(default_return_type.type, '_pyi_private_get_type',
@@ -101,7 +101,7 @@ def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance:
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
default_return_type = set_descriptor_types_for_field(ctx)
base_field_arg_type = helpers.get_argument_type_by_name(ctx, 'base_field')
base_field_arg_type = helpers.get_call_argument_type_by_name(ctx, 'base_field')
if not base_field_arg_type or not isinstance(base_field_arg_type, Instance):
return default_return_type
@@ -118,9 +118,7 @@ def transform_into_proper_return_type(ctx: FunctionContext) -> Type:
if not isinstance(default_return_type, Instance):
return default_return_type
if helpers.has_any_of_bases(default_return_type.type, (fullnames.FOREIGN_KEY_FULLNAME,
fullnames.ONETOONE_FIELD_FULLNAME,
fullnames.MANYTOMANY_FIELD_FULLNAME)):
if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES):
return fill_descriptor_types_for_related_field(ctx)
if default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME):
@@ -135,55 +133,99 @@ def process_field_instantiation(ctx: FunctionContext) -> Type:
return transform_into_proper_return_type(ctx)
def _parse_choices_type(ctx: FunctionContext, choices_arg: Expression) -> Optional[str]:
if isinstance(choices_arg, (TupleExpr, ListExpr)):
# iterable of 2 element tuples of two kinds
_, analyzed_choices = ctx.api.analyze_iterable_item_type(choices_arg)
if isinstance(analyzed_choices, TupleType):
first_element_type = analyzed_choices.items[0]
if isinstance(first_element_type, Instance):
return first_element_type.type.fullname()
def _parse_referenced_model(ctx: FunctionContext, to_arg: Expression) -> Optional[TypeInfo]:
if isinstance(to_arg, NameExpr) and isinstance(to_arg.node, TypeInfo):
# reference to the model class
return to_arg.node
elif isinstance(to_arg, StrExpr):
referenced_model_info = helpers.get_model_info(to_arg.value, ctx.api.modules)
if referenced_model_info is not None:
return referenced_model_info
def parse_field_init_arguments_into_model_metadata(ctx: FunctionContext) -> None:
api = cast(TypeChecker, ctx.api)
outer_model = api.scope.active_class()
outer_model = ctx.api.scope.active_class()
if outer_model is None or not outer_model.has_base(fullnames.MODEL_CLASS_FULLNAME):
# outside models.Model class, undetermined
return
field_name = None
for name_expr, stmt in helpers.iter_over_assignments(outer_model.defn):
if stmt == ctx.context and isinstance(name_expr, NameExpr):
field_name = name_expr.name
# Determine name of the current field
for attr_name, stmt in helpers.iter_over_class_level_assignments(outer_model.defn):
if stmt == ctx.context:
field_name = attr_name
break
if field_name is None:
else:
return
fields_metadata = metadata.get_fields_metadata(outer_model)
model_fields_metadata = metadata.get_fields_metadata(outer_model)
# primary key
is_primary_key = False
primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key')
primary_key_arg = helpers.get_call_argument_by_name(ctx, 'primary_key')
if primary_key_arg:
is_primary_key = helpers.parse_bool(primary_key_arg)
fields_metadata[field_name] = {'primary_key': is_primary_key}
model_fields_metadata[field_name] = {'primary_key': is_primary_key}
# choices
choices_arg = helpers.get_argument_by_name(ctx, 'choices')
if choices_arg and isinstance(choices_arg, (TupleExpr, ListExpr)):
# iterable of 2 element tuples of two kinds
_, analyzed_choices = api.analyze_iterable_item_type(choices_arg)
if isinstance(analyzed_choices, TupleType):
first_element_type = analyzed_choices.items[0]
if isinstance(first_element_type, Instance):
fields_metadata[field_name]['choices'] = first_element_type.type.fullname()
choices_arg = helpers.get_call_argument_by_name(ctx, 'choices')
if choices_arg:
choices_type_fullname = _parse_choices_type(ctx.api, choices_arg)
if choices_type_fullname:
model_fields_metadata[field_name]['choices_type'] = choices_type_fullname
# nullability
null_arg = helpers.get_argument_by_name(ctx, 'null')
null_arg = helpers.get_call_argument_by_name(ctx, 'null')
is_nullable = False
if null_arg:
is_nullable = helpers.parse_bool(null_arg)
fields_metadata[field_name]['null'] = is_nullable
model_fields_metadata[field_name]['null'] = is_nullable
# is_blankable
blank_arg = helpers.get_argument_by_name(ctx, 'blank')
blank_arg = helpers.get_call_argument_by_name(ctx, 'blank')
is_blankable = False
if blank_arg:
is_blankable = helpers.parse_bool(blank_arg)
fields_metadata[field_name]['blank'] = is_blankable
model_fields_metadata[field_name]['blank'] = is_blankable
# default
default_arg = helpers.get_argument_by_name(ctx, 'default')
default_arg = helpers.get_call_argument_by_name(ctx, 'default')
if default_arg and not helpers.is_none_expr(default_arg):
fields_metadata[field_name]['default_specified'] = True
model_fields_metadata[field_name]['default_specified'] = True
if helpers.has_any_of_bases(ctx.default_return_type.type, fullnames.RELATED_FIELDS_CLASSES):
# to
to_arg = helpers.get_call_argument_by_name(ctx, 'to')
if to_arg:
referenced_model = _parse_referenced_model(ctx, to_arg)
if referenced_model is not None:
model_fields_metadata[field_name]['to'] = referenced_model.fullname()
else:
model_fields_metadata[field_name]['to'] = to_arg.value
# referenced_model = to_arg.value
# raise helpers.IncompleteDefnException()
# model_fields_metadata[field_name]['to'] = referenced_model.fullname()
# if referenced_model is not None:
# model_fields_metadata[field_name]['to'] = referenced_model.fullname()
# else:
# assert isinstance(to_arg, StrExpr)
# model_fields_metadata[field_name]['to'] = to_arg.value
# related_name
related_name_arg = helpers.get_call_argument_by_name(ctx, 'related_name')
if related_name_arg:
if isinstance(related_name_arg, StrExpr):
model_fields_metadata[field_name]['related_name'] = related_name_arg.value
else:
model_fields_metadata[field_name]['related_name'] = outer_model.name().lower() + '_set'

View File

@@ -16,7 +16,7 @@ def extract_proper_type_for_get_form(ctx: MethodContext) -> Type:
if not isinstance(object_type, Instance):
return ctx.default_return_type
form_class_type = helpers.get_argument_type_by_name(ctx, 'form_class')
form_class_type = helpers.get_call_argument_type_by_name(ctx, 'form_class')
if form_class_type is None or isinstance(form_class_type, NoneTyp):
# extract from specified form_class in metadata
form_class_fullname = metadata.get_django_metadata(object_type.type).get('form_class', None)

View File

@@ -18,17 +18,20 @@ from mypy_django_plugin.lib import metadata, fullnames, helpers
class ModelClassInitializer(metaclass=ABCMeta):
api: SemanticAnalyzerPass2
model_classdef: ClassDef
app_models_mapping: Optional[Dict[str, str]] = None
@classmethod
def from_ctx(cls, ctx: ClassDefContext):
return cls(api=cast(SemanticAnalyzerPass2, ctx.api), model_classdef=ctx.cls)
def from_ctx(cls, ctx: ClassDefContext, app_models_mapping: Optional[Dict[str, str]]):
return cls(api=cast(SemanticAnalyzerPass2, ctx.api),
model_classdef=ctx.cls,
app_models_mapping=app_models_mapping)
def get_meta_attribute(self, name: str) -> Optional[Expression]:
meta_node = helpers.get_nested_meta_node_for_current_class(self.model_classdef.info)
if meta_node is None:
return None
return helpers.get_assigned_value_for_class(meta_node, name)
return helpers.get_assignment_stmt_by_name(meta_node, name)
def is_abstract_model(self) -> bool:
is_abstract_expr = self.get_meta_attribute('abstract')
@@ -46,29 +49,74 @@ class ModelClassInitializer(metaclass=ABCMeta):
var.is_initialized_in_class = True
self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True)
def model_has_name_defined(self, name: str) -> bool:
return name in self.model_classdef.info.names
@abstractmethod
def run(self) -> None:
raise NotImplementedError()
def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExpr, CallExpr]]:
for lvalue, rvalue in helpers.iter_call_assignments(klass):
if (isinstance(lvalue, NameExpr)
and isinstance(rvalue.callee, MemberExpr)):
if rvalue.callee.fullname in {fullnames.FOREIGN_KEY_FULLNAME,
fullnames.ONETOONE_FIELD_FULLNAME}:
yield lvalue, rvalue
for field_name, field_init in helpers.iter_over_field_inits_in_class(klass):
field_info = field_init.callee.node
assert isinstance(field_info, TypeInfo)
if helpers.has_any_of_bases(field_init.callee.node, {fullnames.FOREIGN_KEY_FULLNAME,
fullnames.ONETOONE_FIELD_FULLNAME}):
yield field_name, field_init
class SetIdAttrsForRelatedFields(ModelClassInitializer):
class AddReferencesToRelatedModels(ModelClassInitializer):
"""
For every
attr1 = models.ForeignKey(to=MyModel)
sets `attr1_id` attribute to the current model.
"""
def run(self) -> None:
for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef):
node_name = lvalue.name + '_id'
self.add_new_node_to_model_class(name=node_name,
typ=self.api.builtin_type('builtins.int'))
for field_name, field_init_expr in helpers.iter_over_field_inits_in_class(self.model_classdef):
ref_id_name = field_name + '_id'
field_info = field_init_expr.callee.node
assert isinstance(field_info, TypeInfo)
if not self.model_has_name_defined(ref_id_name):
if helpers.has_any_of_bases(field_info, {fullnames.FOREIGN_KEY_FULLNAME,
fullnames.ONETOONE_FIELD_FULLNAME}):
self.add_new_node_to_model_class(name=ref_id_name,
typ=self.api.builtin_type('builtins.int'))
# field_init_expr.callee.node
#
# for field_name, field_init_expr in helpers.iter_call_assignments_in_class(self.model_classdef):
# ref_id_name = field_name + '_id'
# if not self.model_has_name_defined(ref_id_name):
# field_class_info = field_init_expr.callee.node
# if not field_class_info:
#
# if not field_init_expr.callee.node:
#
# if isinstance(field_init_expr.callee.node, TypeInfo) \
# and helpers.has_any_of_bases(field_init_expr.callee.node,
# {fullnames.FOREIGN_KEY_FULLNAME,
# fullnames.ONETOONE_FIELD_FULLNAME}):
# self.add_new_node_to_model_class(name=ref_id_name,
# typ=self.api.builtin_type('builtins.int'))
class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):
"""
Replaces
class MyModel(models.Model):
class Meta:
pass
with
class MyModel(models.Model):
class Meta(Any):
pass
to get around incompatible Meta inner classes for different models.
"""
def run(self) -> None:
meta_node = helpers.get_nested_meta_node_for_current_class(self.model_classdef.info)
if meta_node is None:
@@ -77,24 +125,24 @@ class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):
class AddDefaultObjectsManager(ModelClassInitializer):
def add_new_manager(self, name: str, manager_type: Optional[Instance]) -> None:
def _add_new_manager(self, name: str, manager_type: Optional[Instance]) -> None:
if manager_type is None:
return None
self.add_new_node_to_model_class(name, manager_type)
def add_private_default_manager(self, manager_type: Optional[Instance]) -> None:
def _add_private_default_manager(self, manager_type: Optional[Instance]) -> None:
if manager_type is None:
return None
self.add_new_node_to_model_class('_default_manager', manager_type)
def get_existing_managers(self) -> List[Tuple[str, TypeInfo]]:
def _get_existing_managers(self) -> List[Tuple[str, TypeInfo]]:
managers = []
for base in self.model_classdef.info.mro:
for name_expr, member_expr in helpers.iter_call_assignments(base.defn):
manager_name = name_expr.name
callee_expr = member_expr.callee
for manager_name, call_expr in helpers.iter_call_assignments_in_class(base.defn):
callee_expr = call_expr.callee
if isinstance(callee_expr, IndexExpr):
callee_expr = callee_expr.analyzed.expr
if isinstance(callee_expr, (MemberExpr, NameExpr)) \
and isinstance(callee_expr.node, TypeInfo) \
and callee_expr.node.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
@@ -102,12 +150,12 @@ class AddDefaultObjectsManager(ModelClassInitializer):
return managers
def run(self) -> None:
existing_managers = self.get_existing_managers()
existing_managers = self._get_existing_managers()
if existing_managers:
first_manager_type = None
for manager_name, manager_type_info in existing_managers:
manager_type = Instance(manager_type_info, args=[Instance(self.model_classdef.info, [])])
self.add_new_manager(name=manager_name, manager_type=manager_type)
self._add_new_manager(name=manager_name, manager_type=manager_type)
if first_manager_type is None:
first_manager_type = manager_type
else:
@@ -117,33 +165,46 @@ class AddDefaultObjectsManager(ModelClassInitializer):
first_manager_type = self.api.named_type_or_none(fullnames.MANAGER_CLASS_FULLNAME,
args=[Instance(self.model_classdef.info, [])])
self.add_new_manager('objects', manager_type=first_manager_type)
self._add_new_manager('objects', manager_type=first_manager_type)
if self.is_abstract_model():
return None
default_manager_name_expr = self.get_meta_attribute('default_manager_name')
if isinstance(default_manager_name_expr, StrExpr):
self.add_private_default_manager(self.model_classdef.info.get(default_manager_name_expr.value).type)
self._add_private_default_manager(self.model_classdef.info.get(default_manager_name_expr.value).type)
else:
self.add_private_default_manager(first_manager_type)
self._add_private_default_manager(first_manager_type)
class AddIdAttributeIfPrimaryKeyTrueIsNotSet(ModelClassInitializer):
class AddDefaultPrimaryKey(ModelClassInitializer):
"""
Sets default integer `id` attribute, if:
* model is not abstract (abstract = False)
* there's no field with primary_key=True
"""
def run(self) -> None:
if self.is_abstract_model():
# no need for .id attr
# abstract models cannot be instantiated, and do not need `id` attribute
return None
for _, rvalue in helpers.iter_call_assignments(self.model_classdef):
if ('primary_key' in rvalue.arg_names
and self.api.parse_bool(rvalue.args[rvalue.arg_names.index('primary_key')])):
for _, field_init_expr in helpers.iter_over_field_inits_in_class(self.model_classdef):
if ('primary_key' in field_init_expr.arg_names
and self.api.parse_bool(field_init_expr.args[field_init_expr.arg_names.index('primary_key')])):
break
else:
self.add_new_node_to_model_class('id', self.api.builtin_type('builtins.object'))
self.add_new_node_to_model_class('id', self.api.builtin_type('builtins.int'))
def _get_to_expr(field_init_expr) -> Expression:
if 'to' in field_init_expr.arg_names:
return field_init_expr.args[field_init_expr.arg_names.index('to')]
else:
return field_init_expr.args[0]
class AddRelatedManagers(ModelClassInitializer):
def add_related_manager_variable(self, manager_name: str, related_field_type_data: Dict[str, Any]) -> None:
def _add_related_manager_variable(self, manager_name: str, related_field_type_data: Dict[str, Any]) -> None:
# add dummy related manager for use later
self.add_new_node_to_model_class(manager_name, self.api.builtin_type('builtins.object'))
@@ -153,24 +214,41 @@ class AddRelatedManagers(ModelClassInitializer):
def run(self) -> None:
for module_name, module_file in self.api.modules.items():
for model_defn in helpers.iter_over_classdefs(module_file):
if not model_defn.info:
self.api.defer()
for model_classdef in helpers.iter_over_toplevel_classes(module_file):
for field_name, field_init in helpers.iter_over_field_inits_in_class(model_classdef):
field_info = field_init.callee.node
assert isinstance(field_info, TypeInfo)
for lvalue, field_init in helpers.iter_call_assignments(model_defn):
if is_related_field(field_init, module_file):
try:
referenced_model_fullname = extract_referenced_model_fullname(field_init,
module_file=module_file,
all_modules=self.api.modules)
except helpers.SelfReference:
referenced_model_fullname = model_defn.fullname
if helpers.has_any_of_bases(field_info, fullnames.RELATED_FIELDS_CLASSES):
# try:
to_arg_expr = _get_to_expr(field_init)
if isinstance(to_arg_expr, NameExpr):
referenced_model_fullname = module_file.names[to_arg_expr.name].fullname
else:
assert isinstance(to_arg_expr, StrExpr)
value = to_arg_expr.value
if value == 'self':
# reference to the same model class
referenced_model_fullname = model_classdef.fullname
elif '.' not in value:
# reference to class in the current module
referenced_model_fullname = module_name + '.' + value
else:
referenced_model_fullname = self.app_models_mapping[value]
except helpers.SameFileModel as exc:
referenced_model_fullname = module_name + '.' + exc.model_cls_name
# referenced_model_fullname = extract_referenced_model_fullname(field_init,
# module_file=module_file,
# all_modules=self.api.modules)
# if not referenced_model_fullname:
# raise helpers.IncompleteDefnException('Cannot parse referenced model fullname')
# except helpers.SelfReference:
# referenced_model_fullname = model_classdef.fullname
#
# except helpers.SameFileModel as exc:
# referenced_model_fullname = module_name + '.' + exc.model_cls_name
if self.model_classdef.fullname == referenced_model_fullname:
related_name = model_defn.name.lower() + '_set'
if 'related_name' in field_init.arg_names:
related_name_expr = field_init.args[field_init.arg_names.index('related_name')]
if not isinstance(related_name_expr, StrExpr):
@@ -180,9 +258,10 @@ class AddRelatedManagers(ModelClassInitializer):
if related_name == '+':
# No backwards relation is desired
continue
else:
related_name = model_classdef.name.lower() + '_set'
# Default related_query_name to related_name
related_query_name = related_name
if 'related_query_name' in field_init.arg_names:
related_query_name_expr = field_init.args[field_init.arg_names.index('related_query_name')]
if isinstance(related_query_name_expr, StrExpr):
@@ -191,20 +270,24 @@ class AddRelatedManagers(ModelClassInitializer):
# not string 'related_query_name=' is not yet supported
related_query_name = None
# TODO: Handle defaulting to model name if related_name is not set
# as long as Model is not a Generic, one level depth is fine
if field_init.callee.name in {'ForeignKey', 'ManyToManyField'}:
field_type_data = {
'manager': fullnames.RELATED_MANAGER_CLASS_FULLNAME,
'of': [model_defn.info.fullname()]
}
else:
field_type_data = {
'manager': model_defn.info.fullname(),
'of': []
}
related_query_name = related_name
self.add_related_manager_variable(related_name, related_field_type_data=field_type_data)
# if helpers.has_any_of_bases(field_info, {fullnames.FOREIGN_KEY_FULLNAME,
# fullnames.MANYTOMANY_FIELD_FULLNAME}):
# # as long as Model is not a Generic, one level depth is fine
# field_type_data = {
# 'manager': fullnames.RELATED_MANAGER_CLASS_FULLNAME,
# 'of': [model_classdef.info.fullname()]
# }
# else:
# field_type_data = {
# 'manager': model_classdef.info.fullname(),
# 'of': []
# }
self.add_new_node_to_model_class(related_name, self.api.builtin_type('builtins.object'))
# self._add_related_manager_variable(related_name, related_field_type_data=field_type_data)
if related_query_name is not None:
# Only create related_query_name if it is a string literal
@@ -239,22 +322,24 @@ def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool:
return False
def extract_referenced_model_fullname(rvalue_expr: CallExpr,
def extract_referenced_model_fullname(field_init_expr: CallExpr,
module_file: MypyFile,
all_modules: Dict[str, MypyFile]) -> Optional[str]:
""" Returns fullname of a Model referenced in "to=" argument of the CallExpr"""
if 'to' in rvalue_expr.arg_names:
to_expr = rvalue_expr.args[rvalue_expr.arg_names.index('to')]
if 'to' in field_init_expr.arg_names:
to_expr = field_init_expr.args[field_init_expr.arg_names.index('to')]
else:
to_expr = rvalue_expr.args[0]
to_expr = field_init_expr.args[0]
if isinstance(to_expr, NameExpr):
return module_file.names[to_expr.name].fullname
elif isinstance(to_expr, StrExpr):
typ_fullname = helpers.get_model_fullname_from_string(to_expr.value, all_modules)
if typ_fullname is None:
return None
return typ_fullname
return None
@@ -284,16 +369,18 @@ def add_get_set_attr_fallback_to_any(ctx: ClassDefContext):
add_method(ctx, '__setattr__', [name_arg, value_arg], any)
def process_model_class(ctx: ClassDefContext, ignore_unknown_attributes: bool) -> None:
def process_model_class(ctx: ClassDefContext,
ignore_unknown_attributes: bool,
app_models_mapping: Optional[Dict[str, str]]) -> None:
initializers = [
InjectAnyAsBaseForNestedMeta,
AddDefaultPrimaryKey,
AddReferencesToRelatedModels,
AddDefaultObjectsManager,
AddIdAttributeIfPrimaryKeyTrueIsNotSet,
SetIdAttrsForRelatedFields,
AddRelatedManagers,
]
for initializer_cls in initializers:
initializer_cls.from_ctx(ctx).run()
initializer_cls.from_ctx(ctx, app_models_mapping).run()
add_dummy_init_method(ctx)

View File

@@ -110,8 +110,8 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext) -> Type:
column_names.append(None)
only_strings_as_fields_expressions = False
flat = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'flat'))
named = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'named'))
flat = helpers.parse_bool(helpers.get_call_argument_by_name(ctx, 'flat'))
named = helpers.parse_bool(helpers.get_call_argument_by_name(ctx, 'named'))
api = cast(TypeChecker, ctx.api)
if named and flat:

View File

@@ -53,7 +53,7 @@ def return_user_model_hook(ctx: FunctionContext, settings_modules: List[str]) ->
setting_module = api.modules[setting_module_name]
model_path = None
for name_expr, rvalue_expr in helpers.iter_over_assignments(setting_module):
for name_expr, rvalue_expr in helpers.iter_over_assignments_in_class(setting_module):
if isinstance(name_expr, NameExpr) and isinstance(rvalue_expr, StrExpr):
if name_expr.name == 'AUTH_USER_MODEL':
model_path = rvalue_expr.value