mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-06 12:14:28 +08:00
add proper __init__, create() support
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -8,4 +8,5 @@ out/
|
||||
django-sources
|
||||
build/
|
||||
dist/
|
||||
pip-wheel-metadata/
|
||||
pip-wheel-metadata/
|
||||
.pytest_cache/
|
||||
@@ -84,6 +84,7 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
|
||||
limit_choices_to: Optional[Any] = ...,
|
||||
ordering: Sequence[str] = ...,
|
||||
) -> Sequence[Union[_Choice, _ChoiceNamedGroup]]: ...
|
||||
def has_default(self) -> bool: ...
|
||||
def get_default(self) -> Any: ...
|
||||
|
||||
class IntegerField(Field[_ST, _GT]):
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type
|
||||
|
||||
from django.db.models.base import Model
|
||||
from django.utils.functional import cached_property
|
||||
from pytest_mypy.utils import temp_environ
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.apps.registry import Apps
|
||||
from django.conf import LazySettings
|
||||
|
||||
|
||||
@dataclass
|
||||
class DjangoPluginConfig:
|
||||
ignore_missing_settings: bool = False
|
||||
ignore_missing_model_attributes: bool = False
|
||||
|
||||
|
||||
def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']:
|
||||
with temp_environ():
|
||||
os.environ['DJANGO_SETTINGS_MODULE'] = settings_module
|
||||
|
||||
def noop_class_getitem(cls, key):
|
||||
return cls
|
||||
|
||||
from django.db import models
|
||||
|
||||
models.QuerySet.__class_getitem__ = classmethod(noop_class_getitem)
|
||||
models.Manager.__class_getitem__ = classmethod(noop_class_getitem)
|
||||
|
||||
from django.conf import settings
|
||||
from django.apps import apps
|
||||
|
||||
apps.get_models.cache_clear()
|
||||
apps.get_swappable_settings_name.cache_clear()
|
||||
|
||||
apps.populate(settings.INSTALLED_APPS)
|
||||
|
||||
assert apps.apps_ready
|
||||
assert settings.configured
|
||||
|
||||
return apps, settings
|
||||
|
||||
|
||||
class DjangoContext:
|
||||
def __init__(self, plugin_toml_config: Optional[Dict[str, Any]]) -> None:
|
||||
self.config = DjangoPluginConfig()
|
||||
|
||||
django_settings_module = None
|
||||
if plugin_toml_config:
|
||||
self.config.ignore_missing_settings = plugin_toml_config.get('ignore_missing_settings', False)
|
||||
self.config.ignore_missing_model_attributes = plugin_toml_config.get('ignore_missing_model_attributes', False)
|
||||
django_settings_module = plugin_toml_config.get('django_settings_module', None)
|
||||
|
||||
self.apps_registry: Optional[Dict[str, str]] = None
|
||||
self.settings: LazySettings = None
|
||||
if django_settings_module:
|
||||
apps, settings = initialize_django(django_settings_module)
|
||||
self.apps_registry = apps
|
||||
self.settings = settings
|
||||
|
||||
@cached_property
|
||||
def model_modules(self) -> Dict[str, List[Type[Model]]]:
|
||||
""" All modules that contain Django models. """
|
||||
if self.apps_registry is None:
|
||||
return {}
|
||||
|
||||
modules: Dict[str, List[Type[Model]]] = defaultdict(list)
|
||||
for model_cls in self.apps_registry.get_models():
|
||||
modules[model_cls.__module__].append(model_cls)
|
||||
return modules
|
||||
|
||||
def get_model_class_by_fullname(self, fullname: str) -> Optional[Type[Model]]:
|
||||
# Returns None if Model is abstract
|
||||
module, _, model_cls_name = fullname.rpartition('.')
|
||||
for model_cls in self.model_modules.get(module, []):
|
||||
if model_cls.__name__ == model_cls_name:
|
||||
return model_cls
|
||||
156
mypy_django_plugin_newsemanal/django/context.py
Normal file
156
mypy_django_plugin_newsemanal/django/context.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING, Tuple, Type
|
||||
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.fields.related import ForeignKey
|
||||
from django.utils.functional import cached_property
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.types import Instance, Type as MypyType
|
||||
from pytest_mypy.utils import temp_environ
|
||||
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.db.models.fields import CharField, Field
|
||||
from mypy_django_plugin_newsemanal.lib import helpers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.apps.registry import Apps
|
||||
from django.conf import LazySettings
|
||||
|
||||
|
||||
@dataclass
|
||||
class DjangoPluginConfig:
|
||||
ignore_missing_settings: bool = False
|
||||
ignore_missing_model_attributes: bool = False
|
||||
|
||||
|
||||
def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']:
|
||||
with temp_environ():
|
||||
os.environ['DJANGO_SETTINGS_MODULE'] = settings_module
|
||||
|
||||
def noop_class_getitem(cls, key):
|
||||
return cls
|
||||
|
||||
from django.db import models
|
||||
|
||||
models.QuerySet.__class_getitem__ = classmethod(noop_class_getitem)
|
||||
models.Manager.__class_getitem__ = classmethod(noop_class_getitem)
|
||||
|
||||
from django.conf import settings
|
||||
from django.apps import apps
|
||||
|
||||
apps.get_models.cache_clear()
|
||||
apps.get_swappable_settings_name.cache_clear()
|
||||
|
||||
apps.populate(settings.INSTALLED_APPS)
|
||||
|
||||
assert apps.apps_ready
|
||||
assert settings.configured
|
||||
|
||||
return apps, settings
|
||||
|
||||
|
||||
class DjangoFieldsContext:
|
||||
def get_attname(self, field: Field) -> str:
|
||||
attname = field.attname
|
||||
return attname
|
||||
|
||||
def get_field_nullability(self, field: Field, method: Optional[str]) -> bool:
|
||||
nullable = field.null
|
||||
if not nullable and isinstance(field, CharField) and field.blank:
|
||||
return True
|
||||
if method == '__init__':
|
||||
if field.primary_key or isinstance(field, ForeignKey):
|
||||
return True
|
||||
if field.has_default():
|
||||
return True
|
||||
return nullable
|
||||
|
||||
def get_field_set_type(self, api: TypeChecker, field: Field, method: str) -> MypyType:
|
||||
target_field = field
|
||||
if isinstance(field, ForeignKey):
|
||||
target_field = field.target_field
|
||||
|
||||
field_info = helpers.lookup_class_typeinfo(api, target_field.__class__)
|
||||
field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
|
||||
is_nullable=self.get_field_nullability(field, method))
|
||||
if isinstance(target_field, ArrayField):
|
||||
argument_field_type = self.get_field_set_type(api, target_field.base_field, method)
|
||||
field_set_type = helpers.convert_any_to_type(field_set_type, argument_field_type)
|
||||
return field_set_type
|
||||
|
||||
|
||||
class DjangoContext:
|
||||
def __init__(self, plugin_toml_config: Optional[Dict[str, Any]]) -> None:
|
||||
self.config = DjangoPluginConfig()
|
||||
self.fields_context = DjangoFieldsContext()
|
||||
|
||||
django_settings_module = None
|
||||
if plugin_toml_config:
|
||||
self.config.ignore_missing_settings = plugin_toml_config.get('ignore_missing_settings', False)
|
||||
self.config.ignore_missing_model_attributes = plugin_toml_config.get('ignore_missing_model_attributes', False)
|
||||
django_settings_module = plugin_toml_config.get('django_settings_module', None)
|
||||
|
||||
self.apps_registry: Optional[Dict[str, str]] = None
|
||||
self.settings: LazySettings = None
|
||||
if django_settings_module:
|
||||
apps, settings = initialize_django(django_settings_module)
|
||||
self.apps_registry = apps
|
||||
self.settings = settings
|
||||
|
||||
@cached_property
|
||||
def model_modules(self) -> Dict[str, List[Type[Model]]]:
|
||||
""" All modules that contain Django models. """
|
||||
if self.apps_registry is None:
|
||||
return {}
|
||||
|
||||
modules: Dict[str, List[Type[Model]]] = defaultdict(list)
|
||||
for model_cls in self.apps_registry.get_models():
|
||||
modules[model_cls.__module__].append(model_cls)
|
||||
return modules
|
||||
|
||||
def get_model_class_by_fullname(self, fullname: str) -> Optional[Type[Model]]:
|
||||
# Returns None if Model is abstract
|
||||
module, _, model_cls_name = fullname.rpartition('.')
|
||||
for model_cls in self.model_modules.get(module, []):
|
||||
if model_cls.__name__ == model_cls_name:
|
||||
return model_cls
|
||||
|
||||
def get_model_fields(self, model_cls: Type[Model]) -> Iterator[Field]:
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, Field):
|
||||
yield field
|
||||
|
||||
def get_primary_key_field(self, model_cls: Type[Model]) -> Field:
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, Field):
|
||||
if field.primary_key:
|
||||
return field
|
||||
raise ValueError('No primary key defined')
|
||||
|
||||
def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], method: str) -> Dict[str, MypyType]:
|
||||
expected_types = {}
|
||||
if method == '__init__':
|
||||
# add pk
|
||||
primary_key_field = self.get_primary_key_field(model_cls)
|
||||
field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method)
|
||||
expected_types['pk'] = field_set_type
|
||||
|
||||
for field in self.get_model_fields(model_cls):
|
||||
field_name = field.attname
|
||||
field_set_type = self.fields_context.get_field_set_type(api, field, method)
|
||||
expected_types[field_name] = field_set_type
|
||||
|
||||
if isinstance(field, ForeignKey):
|
||||
field_name = field.name
|
||||
foreign_key_info = helpers.lookup_class_typeinfo(api, field.__class__)
|
||||
related_model_info = helpers.lookup_class_typeinfo(api, field.related_model)
|
||||
is_nullable = self.fields_context.get_field_nullability(field, method)
|
||||
foreign_key_set_type = helpers.get_private_descriptor_type(foreign_key_info,
|
||||
'_pyi_private_set_type',
|
||||
is_nullable=is_nullable)
|
||||
model_set_type = helpers.convert_any_to_type(foreign_key_set_type,
|
||||
Instance(related_model_info, []))
|
||||
expected_types[field_name] = model_set_type
|
||||
return expected_types
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import Expression, MypyFile, NameExpr, SymbolNode, TypeInfo, Var
|
||||
from mypy.plugin import FunctionContext, MethodContext
|
||||
from mypy.types import AnyType, Instance, NoneTyp, Type as MypyType, TypeOfAny, UnionType
|
||||
@@ -23,6 +24,19 @@ def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile])
|
||||
return sym.node
|
||||
|
||||
|
||||
def lookup_fully_qualified_typeinfo(api: TypeChecker, fullname: str) -> Optional[TypeInfo]:
|
||||
node = lookup_fully_qualified_generic(fullname, api.modules)
|
||||
if not isinstance(node, TypeInfo):
|
||||
return None
|
||||
return node
|
||||
|
||||
|
||||
def lookup_class_typeinfo(api: TypeChecker, klass: type) -> TypeInfo:
|
||||
fullname = get_class_fullname(klass)
|
||||
field_info = lookup_fully_qualified_typeinfo(api, fullname)
|
||||
return field_info
|
||||
|
||||
|
||||
def reparametrize_instance(instance: Instance, new_args: List[MypyType]) -> Instance:
|
||||
return Instance(instance.type, args=new_args,
|
||||
line=instance.line, column=instance.column)
|
||||
@@ -97,3 +111,25 @@ def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]
|
||||
if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo):
|
||||
return metaclass_sym.node
|
||||
return None
|
||||
|
||||
|
||||
def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType:
|
||||
if isinstance(typ, UnionType):
|
||||
converted_items = []
|
||||
for item in typ.items:
|
||||
converted_items.append(convert_any_to_type(item, referred_to_type))
|
||||
return UnionType.make_union(converted_items,
|
||||
line=typ.line, column=typ.column)
|
||||
if isinstance(typ, Instance):
|
||||
args = []
|
||||
for default_arg in typ.args:
|
||||
if isinstance(default_arg, AnyType):
|
||||
args.append(referred_to_type)
|
||||
else:
|
||||
args.append(default_arg)
|
||||
return reparametrize_instance(typ, args)
|
||||
|
||||
if isinstance(typ, AnyType):
|
||||
return referred_to_type
|
||||
|
||||
return typ
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Django's command-line utility for administrative tasks."""
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def main():
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'sample_django_project.settings')
|
||||
try:
|
||||
from django.core.management import execute_from_command_line
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Couldn't import Django. Are you sure it's installed and "
|
||||
"available on your PYTHONPATH environment variable? Did you "
|
||||
"forget to activate a virtual environment?"
|
||||
) from exc
|
||||
execute_from_command_line(sys.argv)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,3 +0,0 @@
|
||||
from django.contrib import admin
|
||||
|
||||
# Register your models here.
|
||||
@@ -1,5 +0,0 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class MyappConfig(AppConfig):
|
||||
label = 'myapp22'
|
||||
@@ -1,6 +0,0 @@
|
||||
from django.db import models
|
||||
|
||||
|
||||
# Create your models here.
|
||||
class MyModel(models.Model):
|
||||
pass
|
||||
@@ -1,3 +0,0 @@
|
||||
from django.test import TestCase
|
||||
|
||||
# Create your tests here.
|
||||
@@ -1,3 +0,0 @@
|
||||
from django.shortcuts import render
|
||||
|
||||
# Create your views here.
|
||||
@@ -1,121 +0,0 @@
|
||||
"""
|
||||
Django settings for sample_django_project project.
|
||||
|
||||
Generated by 'django-admin startproject' using Django 2.2.3.
|
||||
|
||||
For more information on this file, see
|
||||
https://docs.djangoproject.com/en/2.2/topics/settings/
|
||||
|
||||
For the full list of settings and their values, see
|
||||
https://docs.djangoproject.com/en/2.2/ref/settings/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
# Quick-start development settings - unsuitable for production
|
||||
# See https://docs.djangoproject.com/en/2.2/howto/deployment/checklist/
|
||||
|
||||
# SECURITY WARNING: keep the secret key used in production secret!
|
||||
SECRET_KEY = 'e6gj!2x(*odqwmjafrn7#35%)&rnn&^*0x-f&j0prgr--&xf+%'
|
||||
|
||||
# SECURITY WARNING: don't run with debug turned on in production!
|
||||
DEBUG = True
|
||||
|
||||
ALLOWED_HOSTS = []
|
||||
|
||||
|
||||
# Application definition
|
||||
|
||||
INSTALLED_APPS = [
|
||||
'django.contrib.admin',
|
||||
'django.contrib.auth',
|
||||
'django.contrib.contenttypes',
|
||||
'django.contrib.sessions',
|
||||
'django.contrib.messages',
|
||||
'django.contrib.staticfiles',
|
||||
'mypy_django_plugin.lib.tests.sample_django_project.myapp'
|
||||
]
|
||||
|
||||
MIDDLEWARE = [
|
||||
'django.middleware.security.SecurityMiddleware',
|
||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
'django.middleware.common.CommonMiddleware',
|
||||
'django.middleware.csrf.CsrfViewMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
'django.contrib.messages.middleware.MessageMiddleware',
|
||||
'django.middleware.clickjacking.XFrameOptionsMiddleware',
|
||||
]
|
||||
|
||||
ROOT_URLCONF = 'sample_django_project.urls'
|
||||
|
||||
TEMPLATES = [
|
||||
{
|
||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||
'DIRS': [],
|
||||
'APP_DIRS': True,
|
||||
'OPTIONS': {
|
||||
'context_processors': [
|
||||
'django.template.context_processors.debug',
|
||||
'django.template.context_processors.request',
|
||||
'django.contrib.auth.context_processors.auth',
|
||||
'django.contrib.messages.context_processors.messages',
|
||||
],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
WSGI_APPLICATION = 'sample_django_project.wsgi.application'
|
||||
|
||||
|
||||
# Database
|
||||
# https://docs.djangoproject.com/en/2.2/ref/settings/#databases
|
||||
|
||||
DATABASES = {
|
||||
'default': {
|
||||
'ENGINE': 'django.db.backends.sqlite3',
|
||||
'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Password validation
|
||||
# https://docs.djangoproject.com/en/2.2/ref/settings/#auth-password-validators
|
||||
|
||||
AUTH_PASSWORD_VALIDATORS = [
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# Internationalization
|
||||
# https://docs.djangoproject.com/en/2.2/topics/i18n/
|
||||
|
||||
LANGUAGE_CODE = 'en-us'
|
||||
|
||||
TIME_ZONE = 'UTC'
|
||||
|
||||
USE_I18N = True
|
||||
|
||||
USE_L10N = True
|
||||
|
||||
USE_TZ = True
|
||||
|
||||
|
||||
# Static files (CSS, JavaScript, Images)
|
||||
# https://docs.djangoproject.com/en/2.2/howto/static-files/
|
||||
|
||||
STATIC_URL = '/static/'
|
||||
@@ -1,21 +0,0 @@
|
||||
"""sample_django_project URL Configuration
|
||||
|
||||
The `urlpatterns` list routes URLs to views. For more information please see:
|
||||
https://docs.djangoproject.com/en/2.2/topics/http/urls/
|
||||
Examples:
|
||||
Function views
|
||||
1. Add an import: from my_app import views
|
||||
2. Add a URL to urlpatterns: path('', views.home, name='home')
|
||||
Class-based views
|
||||
1. Add an import: from other_app.views import Home
|
||||
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
|
||||
Including another URLconf
|
||||
1. Import the include() function: from django.urls import include, path
|
||||
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
|
||||
"""
|
||||
from django.contrib import admin
|
||||
from django.urls import path
|
||||
|
||||
urlpatterns = [
|
||||
path('admin/', admin.site.urls),
|
||||
]
|
||||
@@ -1,16 +0,0 @@
|
||||
"""
|
||||
WSGI config for sample_django_project project.
|
||||
|
||||
It exposes the WSGI callable as a module-level variable named ``application``.
|
||||
|
||||
For more information on this file, see
|
||||
https://docs.djangoproject.com/en/2.2/howto/deployment/wsgi/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from django.core.wsgi import get_wsgi_application
|
||||
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'sample_django_project.settings')
|
||||
|
||||
application = get_wsgi_application()
|
||||
@@ -1,14 +0,0 @@
|
||||
from mypy.options import Options
|
||||
|
||||
from mypy_django_plugin.lib.config import extract_app_model_aliases
|
||||
from mypy_django_plugin.main import DjangoPlugin
|
||||
|
||||
|
||||
def test_parse_django_settings():
|
||||
app_model_mapping = extract_app_model_aliases('mypy_django_plugin.lib.tests.sample_django_project.root.settings')
|
||||
assert app_model_mapping['myapp.MyModel'] == 'mypy_django_plugin.lib.tests.sample_django_project.myapp.models.MyModel'
|
||||
|
||||
|
||||
def test_instantiate_plugin_with_config():
|
||||
plugin = DjangoPlugin(Options())
|
||||
|
||||
@@ -8,7 +8,7 @@ from mypy.options import Options
|
||||
from mypy.plugin import ClassDefContext, FunctionContext, Plugin, MethodContext
|
||||
from mypy.types import Type as MypyType
|
||||
|
||||
from mypy_django_plugin_newsemanal.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.django.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.lib import fullnames, metadata
|
||||
from mypy_django_plugin_newsemanal.transformers import fields, settings, querysets, init_create
|
||||
from mypy_django_plugin_newsemanal.transformers.models import process_model_class
|
||||
@@ -28,7 +28,7 @@ def transform_model_class(ctx: ClassDefContext,
|
||||
process_model_class(ctx, django_context)
|
||||
|
||||
|
||||
def transform_manager_class(ctx: ClassDefContext) -> None:
|
||||
def add_new_manager_base(ctx: ClassDefContext) -> None:
|
||||
sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
|
||||
if sym is not None and isinstance(sym.node, TypeInfo):
|
||||
metadata.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1
|
||||
@@ -116,14 +116,15 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
if info.has_base(fullnames.FIELD_FULLNAME):
|
||||
return partial(fields.process_field_instantiation, django_context=self.django_context)
|
||||
|
||||
# if info.has_base(fullnames.MODEL_CLASS_FULLNAME):
|
||||
# return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)
|
||||
if info.has_base(fullnames.MODEL_CLASS_FULLNAME):
|
||||
return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)
|
||||
|
||||
# def get_method_hook(self, fullname: str
|
||||
# ) -> Optional[Callable[[MethodContext], Type]]:
|
||||
# class_name, _, method_name = fullname.rpartition('.')
|
||||
#
|
||||
#
|
||||
def get_method_hook(self, fullname: str
|
||||
) -> Optional[Callable[[MethodContext], Type]]:
|
||||
manager_classes = self._get_current_manager_bases()
|
||||
class_fullname, _, method_name = fullname.rpartition('.')
|
||||
if class_fullname in manager_classes and method_name == 'create':
|
||||
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
|
||||
|
||||
def get_base_class_hook(self, fullname: str
|
||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
@@ -131,7 +132,7 @@ class NewSemanalDjangoPlugin(Plugin):
|
||||
return partial(transform_model_class, django_context=self.django_context)
|
||||
|
||||
if fullname in self._get_current_manager_bases():
|
||||
return transform_manager_class
|
||||
return add_new_manager_base
|
||||
|
||||
# def get_attribute_hook(self, fullname: str
|
||||
# ) -> Optional[Callable[[AttributeContext], MypyType]]:
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from typing import Optional, Tuple, cast
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import Expression, ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo
|
||||
from mypy.nodes import StrExpr, TypeInfo
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import AnyType, CallableType, Instance, TupleType, Type as MypyType, UnionType
|
||||
from mypy.types import AnyType, CallableType, Instance, Type as MypyType, UnionType
|
||||
|
||||
from mypy_django_plugin_newsemanal.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.lib import fullnames, helpers, metadata
|
||||
from mypy_django_plugin_newsemanal.django.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.lib import fullnames, helpers
|
||||
|
||||
|
||||
def extract_referred_to_type(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Instance]:
|
||||
|
||||
@@ -1,35 +1,69 @@
|
||||
from typing import cast
|
||||
from typing import List, Tuple, Type, Union
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import Argument, Var, ARG_NAMED
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Type as MypyType, Instance
|
||||
from django.db.models.base import Model
|
||||
from mypy.plugin import FunctionContext, MethodContext
|
||||
from mypy.types import Instance, Type as MypyType
|
||||
|
||||
from mypy_django_plugin_newsemanal.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.lib import helpers
|
||||
from mypy_django_plugin_newsemanal.django.context import DjangoContext
|
||||
|
||||
|
||||
def get_actual_types(ctx: Union[MethodContext, FunctionContext], expected_keys: List[str]) -> List[Tuple[str, MypyType]]:
|
||||
actual_types = []
|
||||
# positionals
|
||||
for pos, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
|
||||
if actual_name is None:
|
||||
if ctx.callee_arg_names[0] == 'kwargs':
|
||||
# unpacked dict as kwargs is not supported
|
||||
continue
|
||||
actual_name = expected_keys[pos]
|
||||
actual_types.append((actual_name, actual_type))
|
||||
# kwargs
|
||||
if len(ctx.callee_arg_names) > 1:
|
||||
for actual_name, actual_type in zip(ctx.arg_names[1], ctx.arg_types[1]):
|
||||
if actual_name is None:
|
||||
# unpacked dict as kwargs is not supported
|
||||
continue
|
||||
actual_types.append((actual_name, actual_type))
|
||||
return actual_types
|
||||
|
||||
|
||||
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
|
||||
model_cls: Type[Model], method: str) -> MypyType:
|
||||
expected_types = django_context.get_expected_types(ctx.api, model_cls, method)
|
||||
expected_keys = [key for key in expected_types.keys() if key != 'pk']
|
||||
|
||||
for actual_name, actual_type in get_actual_types(ctx, expected_keys):
|
||||
if actual_name not in expected_types:
|
||||
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
|
||||
model_cls.__name__),
|
||||
ctx.context)
|
||||
continue
|
||||
ctx.api.check_subtype(actual_type, expected_types[actual_name],
|
||||
ctx.context,
|
||||
'Incompatible type for "{}" of "{}"'.format(actual_name,
|
||||
model_cls.__name__),
|
||||
'got', 'expected')
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def redefine_and_typecheck_model_init(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
|
||||
assert isinstance(ctx.default_return_type, Instance)
|
||||
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
model_fullname = ctx.default_return_type.type.fullname()
|
||||
model_cls = django_context.get_model_class_by_fullname(model_fullname)
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
model_info = ctx.default_return_type.type
|
||||
model_cls = django_context.get_model_class_by_fullname(model_info.fullname())
|
||||
return typecheck_model_method(ctx, django_context, model_cls, '__init__')
|
||||
|
||||
# expected_types = {}
|
||||
# for field in model_cls._meta.get_fields():
|
||||
# field_fullname = helpers.get_class_fullname(field.__class__)
|
||||
# field_info = api.lookup_typeinfo(field_fullname)
|
||||
# field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
|
||||
# is_nullable=False)
|
||||
# field_kwarg = Argument(variable=Var(field.attname, field_set_type),
|
||||
# type_annotation=field_set_type,
|
||||
# initializer=None,
|
||||
# kind=ARG_NAMED)
|
||||
# expected_types[field.attname] = field_set_type
|
||||
# for field_name, field in model_cls._meta.fields_map.items():
|
||||
# print()
|
||||
|
||||
# print()
|
||||
return ctx.default_return_type
|
||||
def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||
isinstance(ctx.default_return_type, Instance)
|
||||
|
||||
model_fullname = ctx.default_return_type.type.fullname()
|
||||
model_cls = django_context.get_model_class_by_fullname(model_fullname)
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
return typecheck_model_method(ctx, django_context, model_cls, 'create')
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
import dataclasses
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Optional, Type, cast
|
||||
from typing import cast
|
||||
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.fields.related import ForeignKey
|
||||
from mypy.newsemanal.semanal import NewSemanticAnalyzer
|
||||
from mypy.nodes import ARG_NAMED_OPT, Argument, ClassDef, MDEF, SymbolTableNode, TypeInfo, Var
|
||||
from mypy.nodes import ClassDef, MDEF, SymbolTableNode, TypeInfo, Var
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.plugins import common
|
||||
from mypy.types import AnyType, Instance, NoneType, Type as MypyType, UnionType
|
||||
from mypy.types import Instance
|
||||
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.db.models.fields import CharField, Field
|
||||
from mypy_django_plugin_newsemanal.context import DjangoContext
|
||||
from django.db.models.fields import Field
|
||||
from mypy_django_plugin_newsemanal.django.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.lib import helpers
|
||||
from mypy_django_plugin_newsemanal.transformers import fields
|
||||
from mypy_django_plugin_newsemanal.transformers.fields import get_field_descriptor_types
|
||||
@@ -52,101 +49,6 @@ class ModelClassInitializer(metaclass=ABCMeta):
|
||||
var.is_initialized_in_class = True
|
||||
var.is_inferred = True
|
||||
self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True)
|
||||
# assert self.model_classdef.info == self.api.type
|
||||
# self.api.add_symbol_table_node(name, SymbolTableNode(MDEF, var, plugin_generated=True))
|
||||
|
||||
def convert_any_to_type(self, typ: MypyType, referred_to_type: MypyType) -> MypyType:
|
||||
if isinstance(typ, UnionType):
|
||||
converted_items = []
|
||||
for item in typ.items:
|
||||
converted_items.append(self.convert_any_to_type(item, referred_to_type))
|
||||
return UnionType.make_union(converted_items,
|
||||
line=typ.line, column=typ.column)
|
||||
if isinstance(typ, Instance):
|
||||
args = []
|
||||
for default_arg in typ.args:
|
||||
if isinstance(default_arg, AnyType):
|
||||
args.append(referred_to_type)
|
||||
else:
|
||||
args.append(default_arg)
|
||||
return helpers.reparametrize_instance(typ, args)
|
||||
|
||||
if isinstance(typ, AnyType):
|
||||
return referred_to_type
|
||||
|
||||
return typ
|
||||
|
||||
def get_field_set_type(self, field: Field, method: str) -> MypyType:
|
||||
target_field = field
|
||||
if isinstance(field, ForeignKey):
|
||||
target_field = field.target_field
|
||||
field_fullname = helpers.get_class_fullname(target_field.__class__)
|
||||
field_info = self.lookup_typeinfo_or_incomplete_defn_error(field_fullname)
|
||||
field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
|
||||
is_nullable=self.get_field_nullability(field, method))
|
||||
if isinstance(target_field, ArrayField):
|
||||
argument_field_type = self.get_field_set_type(target_field.base_field, method)
|
||||
field_set_type = self.convert_any_to_type(field_set_type, argument_field_type)
|
||||
return field_set_type
|
||||
|
||||
def get_field_nullability(self, field: Field, method: Optional[str]) -> bool:
|
||||
nullable = field.null
|
||||
if not nullable and isinstance(field, CharField) and field.blank:
|
||||
return True
|
||||
if method == '__init__':
|
||||
if field.primary_key or isinstance(field, ForeignKey):
|
||||
return True
|
||||
return nullable
|
||||
|
||||
def get_field_kind(self, field: Field, method: str):
|
||||
if method == '__init__':
|
||||
# all arguments are optional in __init__
|
||||
return ARG_NAMED_OPT
|
||||
|
||||
def get_primary_key_field(self, model_cls: Type[Model]) -> Field:
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, Field):
|
||||
if field.primary_key:
|
||||
return field
|
||||
raise ValueError('No primary key defined')
|
||||
|
||||
def make_field_kwarg(self, name: str, field: Field, method: str) -> Argument:
|
||||
field_set_type = self.get_field_set_type(field, method)
|
||||
kind = self.get_field_kind(field, method)
|
||||
field_kwarg = Argument(variable=Var(name, field_set_type),
|
||||
type_annotation=field_set_type,
|
||||
initializer=None,
|
||||
kind=kind)
|
||||
return field_kwarg
|
||||
|
||||
def get_field_kwargs(self, model_cls: Type[Model], method: str):
|
||||
field_kwargs = []
|
||||
if method == '__init__':
|
||||
# add primary key `pk`
|
||||
primary_key_field = self.get_primary_key_field(model_cls)
|
||||
field_kwarg = self.make_field_kwarg('pk', primary_key_field, method)
|
||||
field_kwargs.append(field_kwarg)
|
||||
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, Field):
|
||||
field_kwarg = self.make_field_kwarg(field.attname, field, method)
|
||||
field_kwargs.append(field_kwarg)
|
||||
|
||||
if isinstance(field, ForeignKey):
|
||||
attname = field.name
|
||||
related_model_fullname = helpers.get_class_fullname(field.related_model)
|
||||
model_info = self.lookup_typeinfo_or_incomplete_defn_error(related_model_fullname)
|
||||
is_nullable = self.get_field_nullability(field, method)
|
||||
field_set_type = Instance(model_info, [])
|
||||
if is_nullable:
|
||||
field_set_type = helpers.make_optional(field_set_type)
|
||||
kind = self.get_field_kind(field, method)
|
||||
field_kwarg = Argument(variable=Var(attname, field_set_type),
|
||||
type_annotation=field_set_type,
|
||||
initializer=None,
|
||||
kind=kind)
|
||||
field_kwargs.append(field_kwarg)
|
||||
return field_kwargs
|
||||
|
||||
@abstractmethod
|
||||
def run(self) -> None:
|
||||
@@ -198,9 +100,9 @@ class AddRelatedModelsId(ModelClassInitializer):
|
||||
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, ForeignKey):
|
||||
rel_primary_key_field = self.get_primary_key_field(field.related_model)
|
||||
rel_primary_key_field = self.django_context.get_primary_key_field(field.related_model)
|
||||
field_info = self.lookup_field_typeinfo_or_incomplete_defn_error(rel_primary_key_field)
|
||||
is_nullable = self.get_field_nullability(field, None)
|
||||
is_nullable = self.django_context.fields_context.get_field_nullability(field, None)
|
||||
set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
|
||||
self.add_new_node_to_model_class(field.attname,
|
||||
Instance(field_info, [set_type, get_type]))
|
||||
@@ -228,16 +130,6 @@ class AddManagers(ModelClassInitializer):
|
||||
self.add_new_node_to_model_class('_default_manager', default_manager)
|
||||
|
||||
|
||||
class AddInitMethod(ModelClassInitializer):
|
||||
def run(self):
|
||||
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.info.fullname())
|
||||
if model_cls is None:
|
||||
return
|
||||
|
||||
field_kwargs = self.get_field_kwargs(model_cls, '__init__')
|
||||
common.add_method(self.ctx, '__init__', field_kwargs, NoneType())
|
||||
|
||||
|
||||
def process_model_class(ctx: ClassDefContext,
|
||||
django_context: DjangoContext) -> None:
|
||||
initializers = [
|
||||
@@ -245,11 +137,10 @@ def process_model_class(ctx: ClassDefContext,
|
||||
AddDefaultPrimaryKey,
|
||||
AddRelatedModelsId,
|
||||
AddManagers,
|
||||
AddInitMethod
|
||||
]
|
||||
for initializer_cls in initializers:
|
||||
try:
|
||||
initializer_cls.from_ctx(ctx, django_context).run()
|
||||
except helpers.IncompleteDefnException:
|
||||
if not ctx.api.final_iteration:
|
||||
ctx.api.defer()
|
||||
ctx.api.defer()
|
||||
|
||||
@@ -2,7 +2,7 @@ from mypy.nodes import TypeInfo
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Type as MypyType, TypeType, Instance
|
||||
|
||||
from mypy_django_plugin_newsemanal.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.django.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.lib import helpers
|
||||
|
||||
|
||||
|
||||
@@ -58,9 +58,9 @@
|
||||
- case: optional_id_fields_for_create_is_error
|
||||
main: |
|
||||
from myapp.models import Publisher, Book
|
||||
Book.objects.create(id=None) # E: Incompatible type for "id" of "MyModel" (got "None", expected "int")
|
||||
Book.objects.create(publisher=None) # E: Incompatible type for "id" of "MyModel" (got "None", expected "int")
|
||||
Book.objects.create(publisher_id=None) # E: Incompatible type for "id" of "MyModel" (got "None", expected "int")
|
||||
Book.objects.create(id=None) # E: Incompatible type for "id" of "Book" (got "None", expected "Union[Combinable, int, str]")
|
||||
Book.objects.create(publisher=None) # E: Incompatible type for "publisher" of "Book" (got "None", expected "Union[Publisher, Combinable]")
|
||||
Book.objects.create(publisher_id=None) # E: Incompatible type for "publisher_id" of "Book" (got "None", expected "Union[Combinable, int, str]")
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
@@ -80,7 +80,7 @@
|
||||
MyModel.objects.create(id=None)
|
||||
|
||||
from myapp.models import MyModel2
|
||||
MyModel2(id=None) # E: Incompatible type for "id" of "MyModel2" (got "None", expected "Union[float, int, str, Combinable]")
|
||||
MyModel2(id=None)
|
||||
MyModel2.objects.create(id=None) # E: Incompatible type for "id" of "MyModel2" (got "None", expected "Union[float, int, str, Combinable]")
|
||||
installed_apps:
|
||||
- myapp
|
||||
@@ -94,4 +94,4 @@
|
||||
class MyModel(models.Model):
|
||||
id = models.IntegerField(primary_key=True, default=return_int)
|
||||
class MyModel2(models.Model):
|
||||
id = models.IntegerField(primary_key=True, default=None)
|
||||
id = models.IntegerField(primary_key=True)
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
from myapp.models import MyUser
|
||||
user = MyUser(name=1, age=12)
|
||||
out: |
|
||||
main:2: error: Unexpected keyword argument "name" for "MyUser"
|
||||
main:2: error: Unexpected keyword argument "age" for "MyUser"
|
||||
main:2: error: Unexpected attribute "name" for model "MyUser"
|
||||
main:2: error: Unexpected attribute "age" for model "MyUser"
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
@@ -16,12 +16,28 @@
|
||||
class MyUser(models.Model):
|
||||
pass
|
||||
|
||||
- case: plain_function_which_returns_model
|
||||
main: |
|
||||
from myapp.models import MyUser
|
||||
def func(i: int) -> MyUser:
|
||||
pass
|
||||
func("hello") # E: Argument 1 to "func" has incompatible type "str"; expected "int"
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
class MyUser(models.Model):
|
||||
pass
|
||||
|
||||
- case: arguments_to_init_from_class_incompatible_type
|
||||
main: |
|
||||
from myapp.models import MyUser
|
||||
user = MyUser(name='hello', age=[])
|
||||
out: |
|
||||
main:2: error: Argument "age" to "MyUser" has incompatible type "List[<nothing>]"; expected "Union[float, int, str, Combinable]"
|
||||
main:2: error: Incompatible type for "age" of "MyUser" (got "List[Any]", expected "Union[float, int, str, Combinable]")
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
@@ -91,7 +107,7 @@
|
||||
- case: typechecking_of_pk
|
||||
main: |
|
||||
from myapp.models import MyUser1
|
||||
user = MyUser1(pk=[]) # E: Argument "pk" to "MyUser1" has incompatible type "List[<nothing>]"; expected "Union[float, int, str, Combinable, None]"
|
||||
user = MyUser1(pk=[]) # E: Incompatible type for "pk" of "MyUser1" (got "List[Any]", expected "Union[float, int, str, Combinable, None]")
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
@@ -110,8 +126,8 @@
|
||||
|
||||
from myapp.models import Publisher, PublisherDatetime, Book
|
||||
Book(publisher_id=1, publisher_dt_id=now)
|
||||
Book(publisher_id=[], publisher_dt_id=now) # E: Argument "publisher_id" to "Book" has incompatible type "List[<nothing>]"; expected "Union[Combinable, int, str, None]"
|
||||
Book(publisher_id=1, publisher_dt_id=1) # E: Argument "publisher_dt_id" to "Book" has incompatible type "int"; expected "Union[str, date, Combinable, None]"
|
||||
Book(publisher_id=[], publisher_dt_id=now) # E: Incompatible type for "publisher_id" of "Book" (got "List[Any]", expected "Union[Combinable, int, str, None]")
|
||||
Book(publisher_id=1, publisher_dt_id=1) # E: Incompatible type for "publisher_dt_id" of "Book" (got "int", expected "Union[str, date, Combinable, None]")
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
@@ -139,7 +155,7 @@
|
||||
class NotAValid:
|
||||
pass
|
||||
array_val3: List[NotAValid] = [NotAValid()]
|
||||
MyModel(array=array_val3) # E: Argument "array" to "MyModel" has incompatible type "List[NotAValid]"; expected "Union[Sequence[Union[float, int, str, Combinable]], Combinable]"
|
||||
MyModel(array=array_val3) # E: Incompatible type for "array" of "MyModel" (got "List[NotAValid]", expected "Union[Sequence[Union[float, int, str, Combinable]], Combinable]")
|
||||
installed_apps:
|
||||
- myapp
|
||||
files:
|
||||
|
||||
Reference in New Issue
Block a user