add proper __init__, create() support

This commit is contained in:
Maxim Kurnikov
2019-07-16 16:49:49 +03:00
parent b11a9a85f9
commit 2cb1f257eb
26 changed files with 306 additions and 463 deletions

3
.gitignore vendored
View File

@@ -8,4 +8,5 @@ out/
django-sources django-sources
build/ build/
dist/ dist/
pip-wheel-metadata/ pip-wheel-metadata/
.pytest_cache/

View File

@@ -84,6 +84,7 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
limit_choices_to: Optional[Any] = ..., limit_choices_to: Optional[Any] = ...,
ordering: Sequence[str] = ..., ordering: Sequence[str] = ...,
) -> Sequence[Union[_Choice, _ChoiceNamedGroup]]: ... ) -> Sequence[Union[_Choice, _ChoiceNamedGroup]]: ...
def has_default(self) -> bool: ...
def get_default(self) -> Any: ... def get_default(self) -> Any: ...
class IntegerField(Field[_ST, _GT]): class IntegerField(Field[_ST, _GT]):

View File

@@ -1,80 +0,0 @@
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type
from django.db.models.base import Model
from django.utils.functional import cached_property
from pytest_mypy.utils import temp_environ
if TYPE_CHECKING:
from django.apps.registry import Apps
from django.conf import LazySettings
@dataclass
class DjangoPluginConfig:
ignore_missing_settings: bool = False
ignore_missing_model_attributes: bool = False
def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']:
with temp_environ():
os.environ['DJANGO_SETTINGS_MODULE'] = settings_module
def noop_class_getitem(cls, key):
return cls
from django.db import models
models.QuerySet.__class_getitem__ = classmethod(noop_class_getitem)
models.Manager.__class_getitem__ = classmethod(noop_class_getitem)
from django.conf import settings
from django.apps import apps
apps.get_models.cache_clear()
apps.get_swappable_settings_name.cache_clear()
apps.populate(settings.INSTALLED_APPS)
assert apps.apps_ready
assert settings.configured
return apps, settings
class DjangoContext:
def __init__(self, plugin_toml_config: Optional[Dict[str, Any]]) -> None:
self.config = DjangoPluginConfig()
django_settings_module = None
if plugin_toml_config:
self.config.ignore_missing_settings = plugin_toml_config.get('ignore_missing_settings', False)
self.config.ignore_missing_model_attributes = plugin_toml_config.get('ignore_missing_model_attributes', False)
django_settings_module = plugin_toml_config.get('django_settings_module', None)
self.apps_registry: Optional[Dict[str, str]] = None
self.settings: LazySettings = None
if django_settings_module:
apps, settings = initialize_django(django_settings_module)
self.apps_registry = apps
self.settings = settings
@cached_property
def model_modules(self) -> Dict[str, List[Type[Model]]]:
""" All modules that contain Django models. """
if self.apps_registry is None:
return {}
modules: Dict[str, List[Type[Model]]] = defaultdict(list)
for model_cls in self.apps_registry.get_models():
modules[model_cls.__module__].append(model_cls)
return modules
def get_model_class_by_fullname(self, fullname: str) -> Optional[Type[Model]]:
# Returns None if Model is abstract
module, _, model_cls_name = fullname.rpartition('.')
for model_cls in self.model_modules.get(module, []):
if model_cls.__name__ == model_cls_name:
return model_cls

View File

@@ -0,0 +1,156 @@
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type
from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey
from django.utils.functional import cached_property
from mypy.checker import TypeChecker
from mypy.types import Instance, Type as MypyType
from pytest_mypy.utils import temp_environ
from django.contrib.postgres.fields import ArrayField
from django.db.models.fields import CharField, Field
from mypy_django_plugin_newsemanal.lib import helpers
if TYPE_CHECKING:
from django.apps.registry import Apps
from django.conf import LazySettings
@dataclass
class DjangoPluginConfig:
ignore_missing_settings: bool = False
ignore_missing_model_attributes: bool = False
def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']:
with temp_environ():
os.environ['DJANGO_SETTINGS_MODULE'] = settings_module
def noop_class_getitem(cls, key):
return cls
from django.db import models
models.QuerySet.__class_getitem__ = classmethod(noop_class_getitem)
models.Manager.__class_getitem__ = classmethod(noop_class_getitem)
from django.conf import settings
from django.apps import apps
apps.get_models.cache_clear()
apps.get_swappable_settings_name.cache_clear()
apps.populate(settings.INSTALLED_APPS)
assert apps.apps_ready
assert settings.configured
return apps, settings
class DjangoFieldsContext:
def get_attname(self, field: Field) -> str:
attname = field.attname
return attname
def get_field_nullability(self, field: Field, method: Optional[str]) -> bool:
nullable = field.null
if not nullable and isinstance(field, CharField) and field.blank:
return True
if method == '__init__':
if field.primary_key or isinstance(field, ForeignKey):
return True
if field.has_default():
return True
return nullable
def get_field_set_type(self, api: TypeChecker, field: Field, method: str) -> MypyType:
target_field = field
if isinstance(field, ForeignKey):
target_field = field.target_field
field_info = helpers.lookup_class_typeinfo(api, target_field.__class__)
field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
is_nullable=self.get_field_nullability(field, method))
if isinstance(target_field, ArrayField):
argument_field_type = self.get_field_set_type(api, target_field.base_field, method)
field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type)
return field_set_type
class DjangoContext:
def __init__(self, plugin_toml_config: Optional[Dict[str, Any]]) -> None:
self.config = DjangoPluginConfig()
self.fields_context = DjangoFieldsContext()
django_settings_module = None
if plugin_toml_config:
self.config.ignore_missing_settings = plugin_toml_config.get('ignore_missing_settings', False)
self.config.ignore_missing_model_attributes = plugin_toml_config.get('ignore_missing_model_attributes', False)
django_settings_module = plugin_toml_config.get('django_settings_module', None)
self.apps_registry: Optional[Dict[str, str]] = None
self.settings: LazySettings = None
if django_settings_module:
apps, settings = initialize_django(django_settings_module)
self.apps_registry = apps
self.settings = settings
@cached_property
def model_modules(self) -> Dict[str, List[Type[Model]]]:
""" All modules that contain Django models. """
if self.apps_registry is None:
return {}
modules: Dict[str, List[Type[Model]]] = defaultdict(list)
for model_cls in self.apps_registry.get_models():
modules[model_cls.__module__].append(model_cls)
return modules
def get_model_class_by_fullname(self, fullname: str) -> Optional[Type[Model]]:
# Returns None if Model is abstract
module, _, model_cls_name = fullname.rpartition('.')
for model_cls in self.model_modules.get(module, []):
if model_cls.__name__ == model_cls_name:
return model_cls
def get_model_fields(self, model_cls: Type[Model]) -> Iterator[Field]:
for field in model_cls._meta.get_fields():
if isinstance(field, Field):
yield field
def get_primary_key_field(self, model_cls: Type[Model]) -> Field:
for field in model_cls._meta.get_fields():
if isinstance(field, Field):
if field.primary_key:
return field
raise ValueError('No primary key defined')
def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], method: str) -> Dict[str, MypyType]:
expected_types = {}
if method == '__init__':
# add pk
primary_key_field = self.get_primary_key_field(model_cls)
field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method)
expected_types['pk'] = field_set_type
for field in self.get_model_fields(model_cls):
field_name = field.attname
field_set_type = self.fields_context.get_field_set_type(api, field, method)
expected_types[field_name] = field_set_type
if isinstance(field, ForeignKey):
field_name = field.name
foreign_key_info = helpers.lookup_class_typeinfo(api, field.__class__)
related_model_info = helpers.lookup_class_typeinfo(api, field.related_model)
is_nullable = self.fields_context.get_field_nullability(field, method)
foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info,
'_pyi_private_set_type',
is_nullable=is_nullable)
model_set_type = helpers.convert_any_to_type(foreign_key_set_type,
Instance(related_model_info, []))
expected_types[field_name] = model_set_type
return expected_types

View File

@@ -1,5 +1,6 @@
from typing import Dict, List, Optional, Set, Union from typing import Dict, List, Optional, Set, Union
from mypy.checker import TypeChecker
from mypy.nodes import Expression, MypyFile, NameExpr, SymbolNode, TypeInfo, Var from mypy.nodes import Expression, MypyFile, NameExpr, SymbolNode, TypeInfo, Var
from mypy.plugin import FunctionContext, MethodContext from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance, NoneTyp, Type as MypyType, TypeOfAny, UnionType from mypy.types import AnyType, Instance, NoneTyp, Type as MypyType, TypeOfAny, UnionType
@@ -23,6 +24,19 @@ def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile])
return sym.node return sym.node
def lookup_fully_qualified_typeinfo(api: TypeChecker, fullname: str) -> Optional[TypeInfo]:
node = lookup_fully_qualified_generic(fullname, api.modules)
if not isinstance(node, TypeInfo):
return None
return node
def lookup_class_typeinfo(api: TypeChecker, klass: type) -> TypeInfo:
fullname = get_class_fullname(klass)
field_info = lookup_fully_qualified_typeinfo(api, fullname)
return field_info
def reparametrize_instance(instance: Instance, new_args: List[MypyType]) -> Instance: def reparametrize_instance(instance: Instance, new_args: List[MypyType]) -> Instance:
return Instance(instance.type, args=new_args, return Instance(instance.type, args=new_args,
line=instance.line, column=instance.column) line=instance.line, column=instance.column)
@@ -97,3 +111,25 @@ def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo): if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
return metaclass_sym.node return metaclass_sym.node
return None return None
def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
if isinstance(typ, UnionType):
converted_items = []
for item in typ.items:
converted_items.append(convert_any_to_type(item, referred_to_type))
return UnionType.make_union(converted_items,
line=typ.line, column=typ.column)
if isinstance(typ, Instance):
args = []
for default_arg in typ.args:
if isinstance(default_arg, AnyType):
args.append(referred_to_type)
else:
args.append(default_arg)
return reparametrize_instance(typ, args)
if isinstance(typ, AnyType):
return referred_to_type
return typ

View File

@@ -1,21 +0,0 @@
#!/usr/bin/env python
"""Django's command-line utility for administrative tasks."""
import os
import sys
def main():
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'sample_django_project.settings')
try:
from django.core.management import execute_from_command_line
except ImportError as exc:
raise ImportError(
"Couldn't import Django. Are you sure it's installed and "
"available on your PYTHONPATH environment variable? Did you "
"forget to activate a virtual environment?"
) from exc
execute_from_command_line(sys.argv)
if __name__ == '__main__':
main()

View File

@@ -1,3 +0,0 @@
from django.contrib import admin
# Register your models here.

View File

@@ -1,5 +0,0 @@
from django.apps import AppConfig
class MyappConfig(AppConfig):
label = 'myapp22'

View File

@@ -1,6 +0,0 @@
from django.db import models
# Create your models here.
class MyModel(models.Model):
pass

View File

@@ -1,3 +0,0 @@
from django.test import TestCase
# Create your tests here.

View File

@@ -1,3 +0,0 @@
from django.shortcuts import render
# Create your views here.

View File

@@ -1,121 +0,0 @@
"""
Django settings for sample_django_project project.
Generated by 'django-admin startproject' using Django 2.2.3.
For more information on this file, see
https://docs.djangoproject.com/en/2.2/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/2.2/ref/settings/
"""
import os
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/2.2/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'e6gj!2x(*odqwmjafrn7#35%)&rnn&^*0x-f&j0prgr--&xf+%'
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True
ALLOWED_HOSTS = []
# Application definition
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'mypy_django_plugin.lib.tests.sample_django_project.myapp'
]
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
ROOT_URLCONF = 'sample_django_project.urls'
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
]
WSGI_APPLICATION = 'sample_django_project.wsgi.application'
# Database
# https://docs.djangoproject.com/en/2.2/ref/settings/#databases
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
}
}
# Password validation
# https://docs.djangoproject.com/en/2.2/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
]
# Internationalization
# https://docs.djangoproject.com/en/2.2/topics/i18n/
LANGUAGE_CODE = 'en-us'
TIME_ZONE = 'UTC'
USE_I18N = True
USE_L10N = True
USE_TZ = True
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/2.2/howto/static-files/
STATIC_URL = '/static/'

View File

@@ -1,21 +0,0 @@
"""sample_django_project URL Configuration
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/2.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.contrib import admin
from django.urls import path
urlpatterns = [
path('admin/', admin.site.urls),
]

View File

@@ -1,16 +0,0 @@
"""
WSGI config for sample_django_project project.
It exposes the WSGI callable as a module-level variable named ``application``.
For more information on this file, see
https://docs.djangoproject.com/en/2.2/howto/deployment/wsgi/
"""
import os
from django.core.wsgi import get_wsgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'sample_django_project.settings')
application = get_wsgi_application()

View File

@@ -1,14 +0,0 @@
from mypy.options import Options
from mypy_django_plugin.lib.config import extract_app_model_aliases
from mypy_django_plugin.main import DjangoPlugin
def test_parse_django_settings():
app_model_mapping = extract_app_model_aliases('mypy_django_plugin.lib.tests.sample_django_project.root.settings')
assert app_model_mapping['myapp.MyModel'] == 'mypy_django_plugin.lib.tests.sample_django_project.myapp.models.MyModel'
def test_instantiate_plugin_with_config():
plugin = DjangoPlugin(Options())

View File

@@ -8,7 +8,7 @@ from mypy.options import Options
from mypy.plugin import ClassDefContext, FunctionContext, Plugin, MethodContext from mypy.plugin import ClassDefContext, FunctionContext, Plugin, MethodContext
from mypy.types import Type as MypyType from mypy.types import Type as MypyType
from mypy_django_plugin_newsemanal.context import DjangoContext from mypy_django_plugin_newsemanal.django.context import DjangoContext
from mypy_django_plugin_newsemanal.lib import fullnames, metadata from mypy_django_plugin_newsemanal.lib import fullnames, metadata
from mypy_django_plugin_newsemanal.transformers import fields, settings, querysets, init_create from mypy_django_plugin_newsemanal.transformers import fields, settings, querysets, init_create
from mypy_django_plugin_newsemanal.transformers.models import process_model_class from mypy_django_plugin_newsemanal.transformers.models import process_model_class
@@ -28,7 +28,7 @@ def transform_model_class(ctx: ClassDefContext,
process_model_class(ctx, django_context) process_model_class(ctx, django_context)
def transform_manager_class(ctx: ClassDefContext) -> None: def add_new_manager_base(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME) sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo): if sym is not None and isinstance(sym.node, TypeInfo):
metadata.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1 metadata.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1
@@ -116,14 +116,15 @@ class NewSemanalDjangoPlugin(Plugin):
if info.has_base(fullnames.FIELD_FULLNAME): if info.has_base(fullnames.FIELD_FULLNAME):
return partial(fields.process_field_instantiation, django_context=self.django_context) return partial(fields.process_field_instantiation, django_context=self.django_context)
# if info.has_base(fullnames.MODEL_CLASS_FULLNAME): if info.has_base(fullnames.MODEL_CLASS_FULLNAME):
# return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context) return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)
# def get_method_hook(self, fullname: str def get_method_hook(self, fullname: str
# ) -> Optional[Callable[[MethodContext], Type]]: ) -> Optional[Callable[[MethodContext], Type]]:
# class_name, _, method_name = fullname.rpartition('.') manager_classes = self._get_current_manager_bases()
# class_fullname, _, method_name = fullname.rpartition('.')
# if class_fullname in manager_classes and method_name == 'create':
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
def get_base_class_hook(self, fullname: str def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]: ) -> Optional[Callable[[ClassDefContext], None]]:
@@ -131,7 +132,7 @@ class NewSemanalDjangoPlugin(Plugin):
return partial(transform_model_class, django_context=self.django_context) return partial(transform_model_class, django_context=self.django_context)
if fullname in self._get_current_manager_bases(): if fullname in self._get_current_manager_bases():
return transform_manager_class return add_new_manager_base
# def get_attribute_hook(self, fullname: str # def get_attribute_hook(self, fullname: str
# ) -> Optional[Callable[[AttributeContext], MypyType]]: # ) -> Optional[Callable[[AttributeContext], MypyType]]:

View File

@@ -1,12 +1,12 @@
from typing import Optional, Tuple, cast from typing import Optional, Tuple, cast
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.nodes import Expression, ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo from mypy.nodes import StrExpr, TypeInfo
from mypy.plugin import FunctionContext from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, TupleType, Type as MypyType, UnionType from mypy.types import AnyType, CallableType, Instance, Type as MypyType, UnionType
from mypy_django_plugin_newsemanal.context import DjangoContext from mypy_django_plugin_newsemanal.django.context import DjangoContext
from mypy_django_plugin_newsemanal.lib import fullnames, helpers, metadata from mypy_django_plugin_newsemanal.lib import fullnames, helpers
def extract_referred_to_type(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Instance]: def extract_referred_to_type(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Instance]:

View File

@@ -1,35 +1,69 @@
from typing import cast from typing import List, Tuple, Type, Union
from mypy.checker import TypeChecker from django.db.models.base import Model
from mypy.nodes import Argument, Var, ARG_NAMED from mypy.plugin import FunctionContext, MethodContext
from mypy.plugin import FunctionContext from mypy.types import Instance, Type as MypyType
from mypy.types import Type as MypyType, Instance
from mypy_django_plugin_newsemanal.context import DjangoContext from mypy_django_plugin_newsemanal.django.context import DjangoContext
from mypy_django_plugin_newsemanal.lib import helpers
def get_actual_types(ctx: Union[MethodContext, FunctionContext], expected_keys: List[str]) -> List[Tuple[str, MypyType]]:
actual_types = []
# positionals
for pos, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
if actual_name is None:
if ctx.callee_arg_names[0] == 'kwargs':
# unpacked dict as kwargs is not supported
continue
actual_name = expected_keys[pos]
actual_types.append((actual_name, actual_type))
# kwargs
if len(ctx.callee_arg_names) > 1:
for actual_name, actual_type in zip(ctx.arg_names[1], ctx.arg_types[1]):
if actual_name is None:
# unpacked dict as kwargs is not supported
continue
actual_types.append((actual_name, actual_type))
return actual_types
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
model_cls: Type[Model], method: str) -> MypyType:
expected_types = django_context.get_expected_types(ctx.api, model_cls, method)
expected_keys = [key for key in expected_types.keys() if key != 'pk']
for actual_name, actual_type in get_actual_types(ctx, expected_keys):
if actual_name not in expected_types:
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
model_cls.__name__),
ctx.context)
continue
ctx.api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model_cls.__name__),
'got', 'expected')
return ctx.default_return_type
def redefine_and_typecheck_model_init(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: def redefine_and_typecheck_model_init(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
assert isinstance(ctx.default_return_type, Instance) assert isinstance(ctx.default_return_type, Instance)
api = cast(TypeChecker, ctx.api) model_fullname = ctx.default_return_type.type.fullname()
model_cls = django_context.get_model_class_by_fullname(model_fullname)
if model_cls is None:
return ctx.default_return_type
model_info = ctx.default_return_type.type return typecheck_model_method(ctx, django_context, model_cls, '__init__')
model_cls = django_context.get_model_class_by_fullname(model_info.fullname())
# expected_types = {}
# for field in model_cls._meta.get_fields():
# field_fullname = helpers.get_class_fullname(field.__class__)
# field_info = api.lookup_typeinfo(field_fullname)
# field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
# is_nullable=False)
# field_kwarg = Argument(variable=Var(field.attname, field_set_type),
# type_annotation=field_set_type,
# initializer=None,
# kind=ARG_NAMED)
# expected_types[field.attname] = field_set_type
# for field_name, field in model_cls._meta.fields_map.items():
# print()
# print() def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
return ctx.default_return_type isinstance(ctx.default_return_type, Instance)
model_fullname = ctx.default_return_type.type.fullname()
model_cls = django_context.get_model_class_by_fullname(model_fullname)
if model_cls is None:
return ctx.default_return_type
return typecheck_model_method(ctx, django_context, model_cls, 'create')

View File

@@ -1,18 +1,15 @@
import dataclasses import dataclasses
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import Optional, Type, cast from typing import cast
from django.db.models.base import Model
from django.db.models.fields.related import ForeignKey from django.db.models.fields.related import ForeignKey
from mypy.newsemanal.semanal import NewSemanticAnalyzer from mypy.newsemanal.semanal import NewSemanticAnalyzer
from mypy.nodes import ARG_NAMED_OPT, Argument, ClassDef, MDEF, SymbolTableNode, TypeInfo, Var from mypy.nodes import ClassDef, MDEF, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext from mypy.plugin import ClassDefContext
from mypy.plugins import common from mypy.types import Instance
from mypy.types import AnyType, Instance, NoneType, Type as MypyType, UnionType
from django.contrib.postgres.fields import ArrayField from django.db.models.fields import Field
from django.db.models.fields import CharField, Field from mypy_django_plugin_newsemanal.django.context import DjangoContext
from mypy_django_plugin_newsemanal.context import DjangoContext
from mypy_django_plugin_newsemanal.lib import helpers from mypy_django_plugin_newsemanal.lib import helpers
from mypy_django_plugin_newsemanal.transformers import fields from mypy_django_plugin_newsemanal.transformers import fields
from mypy_django_plugin_newsemanal.transformers.fields import get_field_descriptor_types from mypy_django_plugin_newsemanal.transformers.fields import get_field_descriptor_types
@@ -52,101 +49,6 @@ class ModelClassInitializer(metaclass=ABCMeta):
var.is_initialized_in_class = True var.is_initialized_in_class = True
var.is_inferred = True var.is_inferred = True
self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True) self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True)
# assert self.model_classdef.info == self.api.type
# self.api.add_symbol_table_node(name, SymbolTableNode(MDEF, var, plugin_generated=True))
def convert_any_to_type(self, typ: MypyType, referred_to_type: MypyType) -> MypyType:
if isinstance(typ, UnionType):
converted_items = []
for item in typ.items:
converted_items.append(self.convert_any_to_type(item, referred_to_type))
return UnionType.make_union(converted_items,
line=typ.line, column=typ.column)
if isinstance(typ, Instance):
args = []
for default_arg in typ.args:
if isinstance(default_arg, AnyType):
args.append(referred_to_type)
else:
args.append(default_arg)
return helpers.reparametrize_instance(typ, args)
if isinstance(typ, AnyType):
return referred_to_type
return typ
def get_field_set_type(self, field: Field, method: str) -> MypyType:
target_field = field
if isinstance(field, ForeignKey):
target_field = field.target_field
field_fullname = helpers.get_class_fullname(target_field.__class__)
field_info = self.lookup_typeinfo_or_incomplete_defn_error(field_fullname)
field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
is_nullable=self.get_field_nullability(field, method))
if isinstance(target_field, ArrayField):
argument_field_type = self.get_field_set_type(target_field.base_field, method)
field_set_type = self.convert_any_to_type(field_set_type, argument_field_type)
return field_set_type
def get_field_nullability(self, field: Field, method: Optional[str]) -> bool:
nullable = field.null
if not nullable and isinstance(field, CharField) and field.blank:
return True
if method == '__init__':
if field.primary_key or isinstance(field, ForeignKey):
return True
return nullable
def get_field_kind(self, field: Field, method: str):
if method == '__init__':
# all arguments are optional in __init__
return ARG_NAMED_OPT
def get_primary_key_field(self, model_cls: Type[Model]) -> Field:
for field in model_cls._meta.get_fields():
if isinstance(field, Field):
if field.primary_key:
return field
raise ValueError('No primary key defined')
def make_field_kwarg(self, name: str, field: Field, method: str) -> Argument:
field_set_type = self.get_field_set_type(field, method)
kind = self.get_field_kind(field, method)
field_kwarg = Argument(variable=Var(name, field_set_type),
type_annotation=field_set_type,
initializer=None,
kind=kind)
return field_kwarg
def get_field_kwargs(self, model_cls: Type[Model], method: str):
field_kwargs = []
if method == '__init__':
# add primary key `pk`
primary_key_field = self.get_primary_key_field(model_cls)
field_kwarg = self.make_field_kwarg('pk', primary_key_field, method)
field_kwargs.append(field_kwarg)
for field in model_cls._meta.get_fields():
if isinstance(field, Field):
field_kwarg = self.make_field_kwarg(field.attname, field, method)
field_kwargs.append(field_kwarg)
if isinstance(field, ForeignKey):
attname = field.name
related_model_fullname = helpers.get_class_fullname(field.related_model)
model_info = self.lookup_typeinfo_or_incomplete_defn_error(related_model_fullname)
is_nullable = self.get_field_nullability(field, method)
field_set_type = Instance(model_info, [])
if is_nullable:
field_set_type = helpers.make_optional(field_set_type)
kind = self.get_field_kind(field, method)
field_kwarg = Argument(variable=Var(attname, field_set_type),
type_annotation=field_set_type,
initializer=None,
kind=kind)
field_kwargs.append(field_kwarg)
return field_kwargs
@abstractmethod @abstractmethod
def run(self) -> None: def run(self) -> None:
@@ -198,9 +100,9 @@ class AddRelatedModelsId(ModelClassInitializer):
for field in model_cls._meta.get_fields(): for field in model_cls._meta.get_fields():
if isinstance(field, ForeignKey): if isinstance(field, ForeignKey):
rel_primary_key_field = self.get_primary_key_field(field.related_model) rel_primary_key_field = self.django_context.get_primary_key_field(field.related_model)
field_info = self.lookup_field_typeinfo_or_incomplete_defn_error(rel_primary_key_field) field_info = self.lookup_field_typeinfo_or_incomplete_defn_error(rel_primary_key_field)
is_nullable = self.get_field_nullability(field, None) is_nullable = self.django_context.fields_context.get_field_nullability(field, None)
set_type, get_type = get_field_descriptor_types(field_info, is_nullable) set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
self.add_new_node_to_model_class(field.attname, self.add_new_node_to_model_class(field.attname,
Instance(field_info, [set_type, get_type])) Instance(field_info, [set_type, get_type]))
@@ -228,16 +130,6 @@ class AddManagers(ModelClassInitializer):
self.add_new_node_to_model_class('_default_manager', default_manager) self.add_new_node_to_model_class('_default_manager', default_manager)
class AddInitMethod(ModelClassInitializer):
def run(self):
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.info.fullname())
if model_cls is None:
return
field_kwargs = self.get_field_kwargs(model_cls, '__init__')
common.add_method(self.ctx, '__init__', field_kwargs, NoneType())
def process_model_class(ctx: ClassDefContext, def process_model_class(ctx: ClassDefContext,
django_context: DjangoContext) -> None: django_context: DjangoContext) -> None:
initializers = [ initializers = [
@@ -245,11 +137,10 @@ def process_model_class(ctx: ClassDefContext,
AddDefaultPrimaryKey, AddDefaultPrimaryKey,
AddRelatedModelsId, AddRelatedModelsId,
AddManagers, AddManagers,
AddInitMethod
] ]
for initializer_cls in initializers: for initializer_cls in initializers:
try: try:
initializer_cls.from_ctx(ctx, django_context).run() initializer_cls.from_ctx(ctx, django_context).run()
except helpers.IncompleteDefnException: except helpers.IncompleteDefnException:
if not ctx.api.final_iteration: if not ctx.api.final_iteration:
ctx.api.defer() ctx.api.defer()

View File

@@ -2,7 +2,7 @@ from mypy.nodes import TypeInfo
from mypy.plugin import FunctionContext from mypy.plugin import FunctionContext
from mypy.types import Type as MypyType, TypeType, Instance from mypy.types import Type as MypyType, TypeType, Instance
from mypy_django_plugin_newsemanal.context import DjangoContext from mypy_django_plugin_newsemanal.django.context import DjangoContext
from mypy_django_plugin_newsemanal.lib import helpers from mypy_django_plugin_newsemanal.lib import helpers

View File

@@ -58,9 +58,9 @@
- case: optional_id_fields_for_create_is_error - case: optional_id_fields_for_create_is_error
main: | main: |
from myapp.models import Publisher, Book from myapp.models import Publisher, Book
Book.objects.create(id=None) # E: Incompatible type for "id" of "MyModel" (got "None", expected "int") Book.objects.create(id=None) # E: Incompatible type for "id" of "Book" (got "None", expected "Union[Combinable, int, str]")
Book.objects.create(publisher=None) # E: Incompatible type for "id" of "MyModel" (got "None", expected "int") Book.objects.create(publisher=None) # E: Incompatible type for "publisher" of "Book" (got "None", expected "Union[Publisher, Combinable]")
Book.objects.create(publisher_id=None) # E: Incompatible type for "id" of "MyModel" (got "None", expected "int") Book.objects.create(publisher_id=None) # E: Incompatible type for "publisher_id" of "Book" (got "None", expected "Union[Combinable, int, str]")
installed_apps: installed_apps:
- myapp - myapp
files: files:
@@ -80,7 +80,7 @@
MyModel.objects.create(id=None) MyModel.objects.create(id=None)
from myapp.models import MyModel2 from myapp.models import MyModel2
MyModel2(id=None) # E: Incompatible type for "id" of "MyModel2" (got "None", expected "Union[float, int, str, Combinable]") MyModel2(id=None)
MyModel2.objects.create(id=None) # E: Incompatible type for "id" of "MyModel2" (got "None", expected "Union[float, int, str, Combinable]") MyModel2.objects.create(id=None) # E: Incompatible type for "id" of "MyModel2" (got "None", expected "Union[float, int, str, Combinable]")
installed_apps: installed_apps:
- myapp - myapp
@@ -94,4 +94,4 @@
class MyModel(models.Model): class MyModel(models.Model):
id = models.IntegerField(primary_key=True, default=return_int) id = models.IntegerField(primary_key=True, default=return_int)
class MyModel2(models.Model): class MyModel2(models.Model):
id = models.IntegerField(primary_key=True, default=None) id = models.IntegerField(primary_key=True)

View File

@@ -3,8 +3,8 @@
from myapp.models import MyUser from myapp.models import MyUser
user = MyUser(name=1, age=12) user = MyUser(name=1, age=12)
out: | out: |
main:2: error: Unexpected keyword argument "name" for "MyUser" main:2: error: Unexpected attribute "name" for model "MyUser"
main:2: error: Unexpected keyword argument "age" for "MyUser" main:2: error: Unexpected attribute "age" for model "MyUser"
installed_apps: installed_apps:
- myapp - myapp
files: files:
@@ -16,12 +16,28 @@
class MyUser(models.Model): class MyUser(models.Model):
pass pass
- case: plain_function_which_returns_model
main: |
from myapp.models import MyUser
def func(i: int) -> MyUser:
pass
func("hello") # E: Argument 1 to "func" has incompatible type "str"; expected "int"
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyUser(models.Model):
pass
- case: arguments_to_init_from_class_incompatible_type - case: arguments_to_init_from_class_incompatible_type
main: | main: |
from myapp.models import MyUser from myapp.models import MyUser
user = MyUser(name='hello', age=[]) user = MyUser(name='hello', age=[])
out: | out: |
main:2: error: Argument "age" to "MyUser" has incompatible type "List[<nothing>]"; expected "Union[float, int, str, Combinable]" main:2: error: Incompatible type for "age" of "MyUser" (got "List[Any]", expected "Union[float, int, str, Combinable]")
installed_apps: installed_apps:
- myapp - myapp
files: files:
@@ -91,7 +107,7 @@
- case: typechecking_of_pk - case: typechecking_of_pk
main: | main: |
from myapp.models import MyUser1 from myapp.models import MyUser1
user = MyUser1(pk=[]) # E: Argument "pk" to "MyUser1" has incompatible type "List[<nothing>]"; expected "Union[float, int, str, Combinable, None]" user = MyUser1(pk=[]) # E: Incompatible type for "pk" of "MyUser1" (got "List[Any]", expected "Union[float, int, str, Combinable, None]")
installed_apps: installed_apps:
- myapp - myapp
files: files:
@@ -110,8 +126,8 @@
from myapp.models import Publisher, PublisherDatetime, Book from myapp.models import Publisher, PublisherDatetime, Book
Book(publisher_id=1, publisher_dt_id=now) Book(publisher_id=1, publisher_dt_id=now)
Book(publisher_id=[], publisher_dt_id=now) # E: Argument "publisher_id" to "Book" has incompatible type "List[<nothing>]"; expected "Union[Combinable, int, str, None]" Book(publisher_id=[], publisher_dt_id=now) # E: Incompatible type for "publisher_id" of "Book" (got "List[Any]", expected "Union[Combinable, int, str, None]")
Book(publisher_id=1, publisher_dt_id=1) # E: Argument "publisher_dt_id" to "Book" has incompatible type "int"; expected "Union[str, date, Combinable, None]" Book(publisher_id=1, publisher_dt_id=1) # E: Incompatible type for "publisher_dt_id" of "Book" (got "int", expected "Union[str, date, Combinable, None]")
installed_apps: installed_apps:
- myapp - myapp
files: files:
@@ -139,7 +155,7 @@
class NotAValid: class NotAValid:
pass pass
array_val3: List[NotAValid] = [NotAValid()] array_val3: List[NotAValid] = [NotAValid()]
MyModel(array=array_val3) # E: Argument "array" to "MyModel" has incompatible type "List[NotAValid]"; expected "Union[Sequence[Union[float, int, str, Combinable]], Combinable]" MyModel(array=array_val3) # E: Incompatible type for "array" of "MyModel" (got "List[NotAValid]", expected "Union[Sequence[Union[float, int, str, Combinable]], Combinable]")
installed_apps: installed_apps:
- myapp - myapp
files: files: