mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-06 20:24:31 +08:00
latest changes
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -4,3 +4,4 @@ out/
|
||||
/test_sqlite.py
|
||||
/django
|
||||
.idea/
|
||||
.mypy_cache/
|
||||
@@ -1,3 +1,3 @@
|
||||
pytest_plugins = [
|
||||
'test.data'
|
||||
'test.pytest_plugin'
|
||||
]
|
||||
7
django-stubs/apps/__init__.pyi
Normal file
7
django-stubs/apps/__init__.pyi
Normal file
@@ -0,0 +1,7 @@
|
||||
from .config import (
|
||||
AppConfig as AppConfig
|
||||
)
|
||||
|
||||
from .registry import (
|
||||
apps as apps
|
||||
)
|
||||
26
django-stubs/apps/config.pyi
Normal file
26
django-stubs/apps/config.pyi
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import Any, Iterator, Type
|
||||
|
||||
from django.db.models.base import Model
|
||||
|
||||
MODELS_MODULE_NAME: str
|
||||
|
||||
class AppConfig:
|
||||
name: str = ...
|
||||
module: Any = ...
|
||||
apps: None = ...
|
||||
label: str = ...
|
||||
verbose_name: str = ...
|
||||
path: str = ...
|
||||
models_module: None = ...
|
||||
models: None = ...
|
||||
def __init__(self, app_name: str, app_module: None) -> None: ...
|
||||
@classmethod
|
||||
def create(cls, entry: str) -> AppConfig: ...
|
||||
def get_model(
|
||||
self, model_name: str, require_ready: bool = ...
|
||||
) -> Type[Model]: ...
|
||||
def get_models(
|
||||
self, include_auto_created: bool = ..., include_swapped: bool = ...
|
||||
) -> Iterator[Type[Model]]: ...
|
||||
def import_models(self) -> None: ...
|
||||
def ready(self) -> None: ...
|
||||
62
django-stubs/apps/registry.pyi
Normal file
62
django-stubs/apps/registry.pyi
Normal file
@@ -0,0 +1,62 @@
|
||||
import collections
|
||||
from typing import Any, Callable, List, Optional, Tuple, Type, Union, Iterable
|
||||
|
||||
from django.apps.config import AppConfig
|
||||
from django.db.migrations.state import AppConfigStub
|
||||
from django.db.models.base import Model
|
||||
|
||||
from .config import AppConfig
|
||||
|
||||
|
||||
class Apps:
|
||||
all_models: collections.defaultdict = ...
|
||||
app_configs: collections.OrderedDict = ...
|
||||
stored_app_configs: List[Any] = ...
|
||||
apps_ready: bool = ...
|
||||
loading: bool = ...
|
||||
def __init__(
|
||||
self,
|
||||
installed_apps: Optional[
|
||||
Union[List[AppConfigStub], List[str], Tuple]
|
||||
] = ...,
|
||||
) -> None: ...
|
||||
models_ready: bool = ...
|
||||
ready: bool = ...
|
||||
def populate(
|
||||
self, installed_apps: Union[List[AppConfigStub], List[str], Tuple] = ...
|
||||
) -> None: ...
|
||||
def check_apps_ready(self) -> None: ...
|
||||
def check_models_ready(self) -> None: ...
|
||||
def get_app_configs(self) -> Iterable[AppConfig]: ...
|
||||
def get_app_config(self, app_label: str) -> AppConfig: ...
|
||||
def get_models(
|
||||
self, include_auto_created: bool = ..., include_swapped: bool = ...
|
||||
) -> List[Type[Model]]: ...
|
||||
def get_model(
|
||||
self,
|
||||
app_label: str,
|
||||
model_name: Optional[str] = ...,
|
||||
require_ready: bool = ...,
|
||||
) -> Type[Model]: ...
|
||||
def register_model(self, app_label: str, model: Type[Model]) -> None: ...
|
||||
def is_installed(self, app_name: str) -> bool: ...
|
||||
def get_containing_app_config(
|
||||
self, object_name: str
|
||||
) -> Optional[AppConfig]: ...
|
||||
def get_registered_model(
|
||||
self, app_label: str, model_name: str
|
||||
) -> Type[Model]: ...
|
||||
def get_swappable_settings_name(self, to_string: str) -> Optional[str]: ...
|
||||
def set_available_apps(self, available: List[str]) -> None: ...
|
||||
def unset_available_apps(self) -> None: ...
|
||||
def set_installed_apps(
|
||||
self, installed: Union[List[str], Tuple[str]]
|
||||
) -> None: ...
|
||||
def unset_installed_apps(self) -> None: ...
|
||||
def clear_cache(self) -> None: ...
|
||||
def lazy_model_operation(
|
||||
self, function: Callable, *model_keys: Any
|
||||
) -> None: ...
|
||||
def do_pending_operations(self, model: Type[Model]) -> None: ...
|
||||
|
||||
apps: Apps
|
||||
@@ -1,2 +1,18 @@
|
||||
from typing import Any
|
||||
|
||||
from .utils import (ProgrammingError as ProgrammingError,
|
||||
IntegrityError as IntegrityError)
|
||||
IntegrityError as IntegrityError,
|
||||
OperationalError as OperationalError,
|
||||
DatabaseError as DatabaseError,
|
||||
DataError as DataError,
|
||||
NotSupportedError as NotSupportedError)
|
||||
|
||||
connections: Any
|
||||
router: Any
|
||||
|
||||
class DefaultConnectionProxy:
|
||||
def __getattr__(self, item: str) -> Any: ...
|
||||
def __setattr__(self, name: str, value: Any) -> None: ...
|
||||
def __delattr__(self, name: str) -> None: ...
|
||||
|
||||
connection: Any
|
||||
|
||||
@@ -3,14 +3,21 @@ from .base import Model as Model
|
||||
from .fields import (AutoField as AutoField,
|
||||
IntegerField as IntegerField,
|
||||
SmallIntegerField as SmallIntegerField,
|
||||
BigIntegerField as BigIntegerField,
|
||||
CharField as CharField,
|
||||
Field as Field,
|
||||
SlugField as SlugField,
|
||||
TextField as TextField,
|
||||
BooleanField as BooleanField)
|
||||
BooleanField as BooleanField,
|
||||
FileField as FileField,
|
||||
DateField as DateField,
|
||||
DateTimeField as DateTimeField,
|
||||
IPAddressField as IPAddressField,
|
||||
GenericIPAddressField as GenericIPAddressField)
|
||||
|
||||
from .fields.related import (ForeignKey as ForeignKey,
|
||||
OneToOneField as OneToOneField)
|
||||
OneToOneField as OneToOneField,
|
||||
ManyToManyField as ManyToManyField)
|
||||
|
||||
from .deletion import (CASCADE as CASCADE,
|
||||
SET_DEFAULT as SET_DEFAULT,
|
||||
@@ -24,4 +31,11 @@ from .query_utils import Q as Q
|
||||
|
||||
from .lookups import Lookup as Lookup
|
||||
|
||||
from .expressions import F as F
|
||||
from .expressions import (F as F,
|
||||
Subquery as Subquery,
|
||||
Exists as Exists,
|
||||
OrderBy as OrderBy,
|
||||
OuterRef as OuterRef)
|
||||
|
||||
from .manager import (BaseManager as BaseManager,
|
||||
Manager as Manager)
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime, timedelta
|
||||
from typing import (Any, Callable, Dict, Iterator, List, Optional, Set, Tuple,
|
||||
Type, Union)
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.lookups import Lookup
|
||||
from django.db.models.sql.compiler import SQLCompiler
|
||||
@@ -180,4 +181,50 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
|
||||
|
||||
|
||||
class F(Combinable):
|
||||
name: str
|
||||
def __init__(self, name: str): ...
|
||||
def resolve_expression(
|
||||
self,
|
||||
query: Any = ...,
|
||||
allow_joins: bool = ...,
|
||||
reuse: Optional[Set[str]] = ...,
|
||||
summarize: bool = ...,
|
||||
for_save: bool = ...,
|
||||
) -> Expression: ...
|
||||
|
||||
|
||||
class OuterRef(F): ...
|
||||
|
||||
class Subquery(Expression):
|
||||
template: str = ...
|
||||
queryset: QuerySet = ...
|
||||
extra: Dict[Any, Any] = ...
|
||||
def __init__(
|
||||
self,
|
||||
queryset: QuerySet,
|
||||
output_field: Optional[Field] = ...,
|
||||
**extra: Any
|
||||
) -> None: ...
|
||||
|
||||
class Exists(Subquery):
|
||||
extra: Dict[Any, Any]
|
||||
template: str = ...
|
||||
negated: bool = ...
|
||||
def __init__(
|
||||
self, *args: Any, negated: bool = ..., **kwargs: Any
|
||||
) -> None: ...
|
||||
def __invert__(self) -> Exists: ...
|
||||
|
||||
class OrderBy(BaseExpression):
|
||||
template: str = ...
|
||||
nulls_first: bool = ...
|
||||
nulls_last: bool = ...
|
||||
descending: bool = ...
|
||||
expression: Expression = ...
|
||||
def __init__(
|
||||
self,
|
||||
expression: Combinable,
|
||||
descending: bool = ...,
|
||||
nulls_first: bool = ...,
|
||||
nulls_last: bool = ...,
|
||||
) -> None: ...
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from django.db.models.query_utils import RegisterLookupMixin
|
||||
|
||||
@@ -7,6 +7,7 @@ class Field(RegisterLookupMixin):
|
||||
def __init__(self,
|
||||
primary_key: bool = False,
|
||||
**kwargs): ...
|
||||
|
||||
def __get__(self, instance, owner) -> Any: ...
|
||||
|
||||
|
||||
@@ -14,8 +15,10 @@ class IntegerField(Field):
|
||||
def __get__(self, instance, owner) -> int: ...
|
||||
|
||||
|
||||
class SmallIntegerField(IntegerField):
|
||||
pass
|
||||
class SmallIntegerField(IntegerField): ...
|
||||
|
||||
|
||||
class BigIntegerField(IntegerField): ...
|
||||
|
||||
|
||||
class AutoField(Field):
|
||||
@@ -26,11 +29,11 @@ class CharField(Field):
|
||||
def __init__(self,
|
||||
max_length: int,
|
||||
**kwargs): ...
|
||||
|
||||
def __get__(self, instance, owner) -> str: ...
|
||||
|
||||
|
||||
class SlugField(CharField):
|
||||
pass
|
||||
class SlugField(CharField): ...
|
||||
|
||||
|
||||
class TextField(Field):
|
||||
@@ -39,3 +42,29 @@ class TextField(Field):
|
||||
|
||||
class BooleanField(Field):
|
||||
def __get__(self, instance, owner) -> bool: ...
|
||||
|
||||
|
||||
class FileField(Field): ...
|
||||
|
||||
|
||||
class IPAddressField(Field): ...
|
||||
|
||||
|
||||
class GenericIPAddressField(Field):
|
||||
default_error_messages: Any = ...
|
||||
unpack_ipv4: Any = ...
|
||||
protocol: Any = ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name: Optional[Any] = ...,
|
||||
name: Optional[Any] = ...,
|
||||
protocol: str = ...,
|
||||
unpack_ipv4: bool = ...,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> None: ...
|
||||
|
||||
class DateField(Field): ...
|
||||
|
||||
class DateTimeField(DateField): ...
|
||||
@@ -22,3 +22,12 @@ class OneToOneField(Field, Generic[_T]):
|
||||
related_name: str = ...,
|
||||
**kwargs): ...
|
||||
def __get__(self, instance, owner) -> _T: ...
|
||||
|
||||
|
||||
class ManyToManyField(Field, Generic[_T]):
|
||||
def __init__(self,
|
||||
to: Union[Type[_T], str],
|
||||
on_delete: Any,
|
||||
related_name: str = ...,
|
||||
**kwargs): ...
|
||||
def __get__(self, instance, owner) -> _T: ...
|
||||
|
||||
140
django-stubs/db/models/manager.pyi
Normal file
140
django-stubs/db/models/manager.pyi
Normal file
@@ -0,0 +1,140 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union, TypeVar, Set, Generic, Iterator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from django.db.models import Q
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.query import QuerySet, RawQuerySet
|
||||
|
||||
_T = TypeVar('_T', bound=Model)
|
||||
|
||||
|
||||
class BaseManager:
|
||||
creation_counter: int = ...
|
||||
auto_created: bool = ...
|
||||
use_in_migrations: bool = ...
|
||||
def __new__(cls: Type[BaseManager], *args: Any, **kwargs: Any) -> BaseManager: ...
|
||||
model: Any = ...
|
||||
name: Any = ...
|
||||
def __init__(self) -> None: ...
|
||||
def deconstruct(self) -> Tuple[bool, str, None, Tuple, Dict[str, int]]: ...
|
||||
def check(self, **kwargs: Any) -> List[Any]: ...
|
||||
@classmethod
|
||||
def from_queryset(
|
||||
cls, queryset_class: Any, class_name: Optional[Any] = ...
|
||||
): ...
|
||||
def contribute_to_class(self, model: Type[Model], name: str) -> None: ...
|
||||
def db_manager(
|
||||
self,
|
||||
using: Optional[str] = ...,
|
||||
hints: Optional[Dict[str, Model]] = ...,
|
||||
) -> Manager: ...
|
||||
@property
|
||||
def db(self) -> str: ...
|
||||
def get_queryset(self) -> QuerySet: ...
|
||||
def all(self) -> QuerySet: ...
|
||||
def __eq__(self, other: Optional[Any]) -> bool: ...
|
||||
def __hash__(self): ...
|
||||
|
||||
class Manager(Generic[_T]):
|
||||
def exists(self) -> bool: ...
|
||||
def explain(
|
||||
self, *, format: Optional[Any] = ..., **options: Any
|
||||
) -> str: ...
|
||||
def raw(
|
||||
self,
|
||||
raw_query: str,
|
||||
params: Optional[
|
||||
Union[
|
||||
Dict[str, str],
|
||||
List[datetime],
|
||||
List[Decimal],
|
||||
List[str],
|
||||
Set[str],
|
||||
Tuple[int],
|
||||
]
|
||||
] = ...,
|
||||
translations: Optional[Dict[str, str]] = ...,
|
||||
using: None = ...,
|
||||
) -> RawQuerySet: ...
|
||||
def values(self, *fields: Any, **expressions: Any) -> QuerySet: ...
|
||||
def values_list(
|
||||
self, *fields: Any, flat: bool = ..., named: bool = ...
|
||||
) -> QuerySet: ...
|
||||
def dates(
|
||||
self, field_name: str, kind: str, order: str = ...
|
||||
) -> QuerySet: ...
|
||||
def datetimes(
|
||||
self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...
|
||||
) -> QuerySet: ...
|
||||
def none(self) -> QuerySet[_T]: ...
|
||||
def all(self) -> QuerySet[_T]: ...
|
||||
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
|
||||
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
|
||||
def complex_filter(
|
||||
self,
|
||||
filter_obj: Union[
|
||||
Dict[str, datetime], Dict[str, QuerySet], Q, MagicMock
|
||||
],
|
||||
) -> QuerySet[_T]: ...
|
||||
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
|
||||
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def select_for_update(
|
||||
self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...
|
||||
) -> QuerySet: ...
|
||||
|
||||
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
|
||||
|
||||
def extra(
|
||||
self,
|
||||
select: Optional[
|
||||
Union[Dict[str, int], Dict[str, str], OrderedDict]
|
||||
] = ...,
|
||||
where: Optional[List[str]] = ...,
|
||||
params: Optional[Union[List[int], List[str]]] = ...,
|
||||
tables: Optional[List[str]] = ...,
|
||||
order_by: Optional[Union[List[str], Tuple[str]]] = ...,
|
||||
select_params: Optional[Union[List[int], List[str], Tuple[int]]] = ...,
|
||||
) -> QuerySet[_T]: ...
|
||||
|
||||
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
|
||||
|
||||
def aggregate(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> Dict[str, Optional[Union[datetime, float]]]: ...
|
||||
|
||||
def count(self) -> int: ...
|
||||
|
||||
def get(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> _T: ...
|
||||
|
||||
def create(self, **kwargs: Any) -> _T: ...
|
||||
|
||||
|
||||
class ManagerDescriptor:
|
||||
manager: Manager = ...
|
||||
def __init__(self, manager: Manager) -> None: ...
|
||||
def __get__(
|
||||
self, instance: Optional[Model], cls: Type[Model] = ...
|
||||
) -> Manager: ...
|
||||
|
||||
class EmptyManager(Manager):
|
||||
creation_counter: int
|
||||
name: None
|
||||
model: Optional[Type[Model]] = ...
|
||||
def __init__(self, model: Type[Model]) -> None: ...
|
||||
def get_queryset(self) -> QuerySet: ...
|
||||
@@ -7,3 +7,8 @@ from .conf import (include as include,
|
||||
from .resolvers import (ResolverMatch as ResolverMatch,
|
||||
get_ns_resolver as get_ns_resolver,
|
||||
get_resolver as get_resolver)
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
from .converters import (
|
||||
register_converter as register_converter
|
||||
)
|
||||
30
django-stubs/utils/baseconv.pyi
Normal file
30
django-stubs/utils/baseconv.pyi
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import Any, Tuple, Union
|
||||
|
||||
BASE2_ALPHABET: str
|
||||
BASE16_ALPHABET: str
|
||||
BASE56_ALPHABET: str
|
||||
BASE36_ALPHABET: str
|
||||
BASE62_ALPHABET: str
|
||||
BASE64_ALPHABET: Any
|
||||
|
||||
class BaseConverter:
|
||||
decimal_digits: str = ...
|
||||
sign: str = ...
|
||||
digits: str = ...
|
||||
def __init__(self, digits: str, sign: str = ...) -> None: ...
|
||||
def encode(self, i: int) -> str: ...
|
||||
def decode(self, s: str) -> int: ...
|
||||
def convert(
|
||||
self,
|
||||
number: Union[int, str],
|
||||
from_digits: str,
|
||||
to_digits: str,
|
||||
sign: str,
|
||||
) -> Tuple[int, str]: ...
|
||||
|
||||
base2: Any
|
||||
base16: Any
|
||||
base36: Any
|
||||
base56: Any
|
||||
base62: Any
|
||||
base64: Any
|
||||
@@ -1,15 +1,15 @@
|
||||
from typing import Dict, Optional, NamedTuple
|
||||
import typing
|
||||
from typing import Dict, Optional, NamedTuple, Any
|
||||
|
||||
from mypy.semanal import SemanticAnalyzerPass2
|
||||
from mypy.types import Type
|
||||
from mypy.nodes import SymbolTableNode, Var, Expression
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Instance, UnionType, NoneTyp
|
||||
from mypy.types import Type, Instance, UnionType, NoneTyp
|
||||
|
||||
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
|
||||
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
|
||||
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
|
||||
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
|
||||
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
|
||||
|
||||
|
||||
def create_new_symtable_node(name: str, kind: int, instance: Instance) -> SymbolTableNode:
|
||||
@@ -26,12 +26,10 @@ Argument = NamedTuple('Argument', fields=[
|
||||
|
||||
|
||||
def get_call_signature_or_none(ctx: FunctionContext) -> Optional[Dict[str, Argument]]:
|
||||
arg_names = ctx.context.arg_names
|
||||
|
||||
result: Dict[str, Argument] = {}
|
||||
positional_args_only = []
|
||||
positional_arg_types_only = []
|
||||
for arg, arg_name, arg_type in zip(ctx.args, arg_names, ctx.arg_types):
|
||||
for arg, arg_name, arg_type in zip(ctx.args, ctx.arg_names, ctx.arg_types):
|
||||
if arg_name is None:
|
||||
positional_args_only.append(arg)
|
||||
positional_arg_types_only.append(arg_type)
|
||||
@@ -65,3 +63,7 @@ def make_required(typ: Type) -> Type:
|
||||
return typ
|
||||
items = [item for item in typ.items if not isinstance(item, NoneTyp)]
|
||||
return UnionType.make_union(items)
|
||||
|
||||
|
||||
def get_obj_type_name(typ: typing.Type) -> str:
|
||||
return typ.__module__ + '.' + typ.__qualname__
|
||||
|
||||
@@ -1,46 +1,79 @@
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, List
|
||||
|
||||
from django.apps.registry import Apps
|
||||
from django.conf import Settings
|
||||
from mypy import build
|
||||
from mypy.build import BuildManager
|
||||
from mypy.options import Options
|
||||
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
|
||||
from mypy.types import Type
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
from mypy_django_plugin import helpers, monkeypatch
|
||||
from mypy_django_plugin.plugins.objects_queryset import set_objects_queryset_to_model_class
|
||||
from mypy_django_plugin.plugins.postgres_fields import determine_type_of_array_field
|
||||
from mypy_django_plugin.plugins.related_fields import set_related_name_instance_for_onetoonefield, \
|
||||
set_related_name_manager_for_foreign_key, set_fieldname_attrs_for_related_fields
|
||||
from mypy_django_plugin.plugins.related_fields import OneToOneFieldHook, \
|
||||
ForeignKeyHook, set_fieldname_attrs_for_related_fields
|
||||
from mypy_django_plugin.plugins.setup_settings import DjangoConfSettingsInitializerHook
|
||||
|
||||
|
||||
base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
|
||||
|
||||
|
||||
def transform_model_class(ctx: ClassDefContext) -> None:
|
||||
class TransformModelClassHook(object):
|
||||
def __init__(self, settings: Settings, apps: Apps):
|
||||
self.settings = settings
|
||||
self.apps = apps
|
||||
|
||||
def __call__(self, ctx: ClassDefContext) -> None:
|
||||
base_model_classes.add(ctx.cls.fullname)
|
||||
|
||||
set_fieldname_attrs_for_related_fields(ctx)
|
||||
set_objects_queryset_to_model_class(ctx)
|
||||
|
||||
|
||||
def always_return_none(manager: BuildManager):
|
||||
return None
|
||||
|
||||
|
||||
build.read_plugins_snapshot = always_return_none
|
||||
|
||||
|
||||
class DjangoPlugin(Plugin):
|
||||
def __init__(self,
|
||||
options: Options) -> None:
|
||||
super().__init__(options)
|
||||
self.django_settings = None
|
||||
self.apps = None
|
||||
|
||||
monkeypatch.replace_apply_function_plugin_method()
|
||||
|
||||
django_settings_module = os.environ.get('DJANGO_SETTINGS_MODULE')
|
||||
if django_settings_module:
|
||||
self.django_settings = Settings(django_settings_module)
|
||||
# import django
|
||||
# django.setup()
|
||||
#
|
||||
# from django.apps import apps
|
||||
# self.apps = apps
|
||||
#
|
||||
# models_modules = []
|
||||
# for app_config in self.apps.app_configs.values():
|
||||
# models_modules.append(app_config.module.__name__ + '.' + 'models')
|
||||
#
|
||||
# monkeypatch.state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(django_settings_module,
|
||||
# models_modules)
|
||||
monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(django_settings_module)
|
||||
|
||||
def get_function_hook(self, fullname: str
|
||||
) -> Optional[Callable[[FunctionContext], Type]]:
|
||||
if fullname == helpers.FOREIGN_KEY_FULLNAME:
|
||||
return set_related_name_manager_for_foreign_key
|
||||
return ForeignKeyHook(settings=self.django_settings,
|
||||
apps=self.apps)
|
||||
|
||||
if fullname == helpers.ONETOONE_FIELD_FULLNAME:
|
||||
return set_related_name_instance_for_onetoonefield
|
||||
return OneToOneFieldHook(settings=self.django_settings,
|
||||
apps=self.apps)
|
||||
|
||||
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
|
||||
return determine_type_of_array_field
|
||||
@@ -49,9 +82,11 @@ class DjangoPlugin(Plugin):
|
||||
def get_base_class_hook(self, fullname: str
|
||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
if fullname in base_model_classes:
|
||||
return transform_model_class
|
||||
if fullname == 'django.conf._DjangoConfLazyObject':
|
||||
return TransformModelClassHook(self.django_settings, self.apps)
|
||||
|
||||
if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS:
|
||||
return DjangoConfSettingsInitializerHook(settings=self.django_settings)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
112
mypy_django_plugin/monkeypatch.py
Normal file
112
mypy_django_plugin/monkeypatch.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from typing import Optional, List, Sequence
|
||||
|
||||
from mypy.build import BuildManager, Graph, State
|
||||
from mypy.modulefinder import BuildSource
|
||||
from mypy.nodes import Expression, Context
|
||||
from mypy.plugin import FunctionContext, MethodContext
|
||||
from mypy.types import Type, CallableType, Instance
|
||||
|
||||
|
||||
def state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(settings_module: str,
|
||||
models_py_modules: List[str]):
|
||||
from mypy.build import State
|
||||
|
||||
old_compute_dependencies = State.compute_dependencies
|
||||
|
||||
def patched_compute_dependencies(self: State):
|
||||
old_compute_dependencies(self)
|
||||
if self.id == settings_module:
|
||||
self.dependencies.extend(models_py_modules)
|
||||
|
||||
State.compute_dependencies = patched_compute_dependencies
|
||||
|
||||
|
||||
def load_graph_to_add_settings_file_as_a_source_seed(settings_module: str):
|
||||
from mypy import build
|
||||
|
||||
old_load_graph = build.load_graph
|
||||
|
||||
def patched_load_graph(sources: List[BuildSource], manager: BuildManager,
|
||||
old_graph: Optional[Graph] = None,
|
||||
new_modules: Optional[List[State]] = None):
|
||||
if all([source.module != settings_module for source in sources]):
|
||||
sources.append(BuildSource(None, settings_module, None))
|
||||
|
||||
return old_load_graph(sources=sources, manager=manager,
|
||||
old_graph=old_graph,
|
||||
new_modules=new_modules)
|
||||
|
||||
build.load_graph = patched_load_graph
|
||||
|
||||
|
||||
def replace_apply_function_plugin_method():
|
||||
def apply_function_plugin(self,
|
||||
arg_types: List[Type],
|
||||
inferred_ret_type: Type,
|
||||
arg_names: Optional[Sequence[Optional[str]]],
|
||||
formal_to_actual: List[List[int]],
|
||||
args: List[Expression],
|
||||
num_formals: int,
|
||||
fullname: str,
|
||||
object_type: Optional[Type],
|
||||
context: Context) -> Type:
|
||||
"""Use special case logic to infer the return type of a specific named function/method.
|
||||
|
||||
Caller must ensure that a plugin hook exists. There are two different cases:
|
||||
|
||||
- If object_type is None, the caller must ensure that a function hook exists
|
||||
for fullname.
|
||||
- If object_type is not None, the caller must ensure that a method hook exists
|
||||
for fullname.
|
||||
|
||||
Return the inferred return type.
|
||||
"""
|
||||
formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]]
|
||||
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
|
||||
formal_arg_names = [None for _ in range(num_formals)] # type: List[Optional[str]]
|
||||
for formal, actuals in enumerate(formal_to_actual):
|
||||
for actual in actuals:
|
||||
formal_arg_types[formal].append(arg_types[actual])
|
||||
formal_arg_exprs[formal].append(args[actual])
|
||||
if arg_names:
|
||||
formal_arg_names[formal] = arg_names[actual]
|
||||
|
||||
num_passed_positionals = sum([1 if name is None else 0
|
||||
for name in formal_arg_names])
|
||||
if arg_names and num_passed_positionals > 0:
|
||||
object_type_info = None
|
||||
if object_type is not None:
|
||||
if isinstance(object_type, CallableType):
|
||||
# class object, convert to corresponding Instance
|
||||
object_type = object_type.ret_type
|
||||
if isinstance(object_type, Instance):
|
||||
# skip TypedDictType and others
|
||||
object_type_info = object_type.type
|
||||
|
||||
defn_arg_names = self._get_defn_arg_names(fullname, object_type=object_type_info)
|
||||
if defn_arg_names:
|
||||
if num_formals < len(defn_arg_names):
|
||||
# self/cls argument has been passed implicitly
|
||||
defn_arg_names = defn_arg_names[1:]
|
||||
formal_arg_names[:num_passed_positionals] = defn_arg_names[:num_passed_positionals]
|
||||
|
||||
if object_type is None:
|
||||
# Apply function plugin
|
||||
callback = self.plugin.get_function_hook(fullname)
|
||||
assert callback is not None # Assume that caller ensures this
|
||||
return callback(
|
||||
FunctionContext(formal_arg_names, formal_arg_types,
|
||||
inferred_ret_type, formal_arg_exprs,
|
||||
context, self.chk))
|
||||
else:
|
||||
# Apply method plugin
|
||||
method_callback = self.plugin.get_method_hook(fullname)
|
||||
assert method_callback is not None # Assume that caller ensures this
|
||||
return method_callback(
|
||||
MethodContext(object_type, formal_arg_names, formal_arg_types,
|
||||
inferred_ret_type, formal_arg_exprs,
|
||||
context, self.chk))
|
||||
|
||||
from mypy.checkexpr import ExpressionChecker
|
||||
ExpressionChecker.apply_function_plugin = apply_function_plugin
|
||||
|
||||
@@ -9,19 +9,28 @@ from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None:
|
||||
if 'objects' in ctx.cls.info.names:
|
||||
return
|
||||
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||
# search over mro
|
||||
objects_sym = ctx.cls.info.get('objects')
|
||||
if objects_sym is not None:
|
||||
return None
|
||||
|
||||
metaclass_node = ctx.cls.info.names.get('Meta')
|
||||
if metaclass_node is not None:
|
||||
for stmt in metaclass_node.node.defn.defs.body:
|
||||
# only direct Meta class
|
||||
metaclass_sym = ctx.cls.info.names.get('Meta')
|
||||
# skip if abstract
|
||||
if metaclass_sym is not None:
|
||||
for stmt in metaclass_sym.node.defn.defs.body:
|
||||
if (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1
|
||||
and stmt.lvalues[0].name == 'abstract'):
|
||||
is_abstract = api.parse_bool(stmt.rvalue)
|
||||
is_abstract = ctx.api.parse_bool(stmt.rvalue)
|
||||
if is_abstract:
|
||||
return
|
||||
return None
|
||||
|
||||
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, args=[Instance(ctx.cls.info, [])])
|
||||
new_objects_node = helpers.create_new_symtable_node('objects', MDEF, instance=typ)
|
||||
ctx.cls.info.names['objects'] = new_objects_node
|
||||
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME,
|
||||
args=[Instance(ctx.cls.info, [])])
|
||||
if not typ:
|
||||
return None
|
||||
|
||||
ctx.cls.info.names['objects'] = helpers.create_new_symtable_node('objects',
|
||||
kind=MDEF,
|
||||
instance=typ)
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Type
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
|
||||
signature = helpers.get_call_signature_or_none(ctx)
|
||||
if signature is None:
|
||||
if 'base_field' not in ctx.arg_names:
|
||||
return ctx.default_return_type
|
||||
|
||||
_, base_field_arg_type = signature['base_field']
|
||||
base_field_arg_type = ctx.arg_types[ctx.arg_names.index('base_field')][0]
|
||||
return ctx.api.named_generic_type(ctx.context.callee.fullname,
|
||||
args=[base_field_arg_type.type.names['__get__'].type.ret_type])
|
||||
|
||||
@@ -1,23 +1,63 @@
|
||||
from typing import Optional, cast
|
||||
import typing
|
||||
from typing import Optional, cast, Tuple, Any
|
||||
|
||||
from django.apps.registry import Apps
|
||||
from django.conf import Settings
|
||||
from django.db import models
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, MemberExpr
|
||||
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, StrExpr
|
||||
from mypy.plugin import FunctionContext, ClassDefContext
|
||||
from mypy.types import Type, CallableType, Instance, AnyType
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def extract_to_value_type(ctx: FunctionContext) -> Optional[Instance]:
|
||||
signature = helpers.get_call_signature_or_none(ctx)
|
||||
if signature is None or 'to' not in signature:
|
||||
return None
|
||||
def get_instance_type_for_class(klass: typing.Type[models.Model],
|
||||
api: TypeChecker) -> Optional[Instance]:
|
||||
model_qualname = helpers.get_obj_type_name(klass)
|
||||
module_name, _, class_name = model_qualname.rpartition('.')
|
||||
module = api.modules.get(module_name)
|
||||
if not module or class_name not in module.names:
|
||||
return
|
||||
|
||||
arg, arg_type = signature['to']
|
||||
if not isinstance(arg_type, CallableType):
|
||||
return None
|
||||
sym = module.names[class_name]
|
||||
return Instance(sym.node, [])
|
||||
|
||||
return arg_type.ret_type
|
||||
|
||||
def extract_to_value_type(ctx: FunctionContext,
|
||||
apps: Optional[Apps]) -> Tuple[Optional[Instance], bool]:
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
|
||||
if 'to' not in ctx.arg_names:
|
||||
return None, False
|
||||
arg = ctx.args[ctx.arg_names.index('to')][0]
|
||||
arg_type = ctx.arg_types[ctx.arg_names.index('to')][0]
|
||||
|
||||
if isinstance(arg_type, CallableType):
|
||||
return arg_type.ret_type, False
|
||||
|
||||
if apps:
|
||||
if isinstance(arg, StrExpr):
|
||||
arg_value = arg.value
|
||||
if '.' not in arg_value:
|
||||
return None, False
|
||||
|
||||
app_label, modelname = arg_value.lower().split('.')
|
||||
try:
|
||||
model_cls = apps.get_model(app_label, modelname)
|
||||
except LookupError:
|
||||
# no model class found
|
||||
return None, False
|
||||
try:
|
||||
instance = get_instance_type_for_class(model_cls, api=api)
|
||||
if not instance:
|
||||
return None, False
|
||||
return instance, True
|
||||
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
return None, False
|
||||
|
||||
|
||||
def extract_related_name_value(ctx: FunctionContext) -> str:
|
||||
@@ -30,43 +70,56 @@ def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instanc
|
||||
instance=new_member_instance)
|
||||
|
||||
|
||||
def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
|
||||
class ForeignKeyHook(object):
|
||||
def __init__(self, settings: Settings, apps: Apps):
|
||||
self.settings = settings
|
||||
self.apps = apps
|
||||
|
||||
def __call__(self, ctx: FunctionContext) -> Type:
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
outer_class_info = api.tscope.classes[-1]
|
||||
|
||||
if 'related_name' not in ctx.context.arg_names:
|
||||
return ctx.default_return_type
|
||||
|
||||
referred_to = extract_to_value_type(ctx)
|
||||
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps)
|
||||
if not referred_to:
|
||||
return ctx.default_return_type
|
||||
|
||||
if 'related_name' in ctx.context.arg_names:
|
||||
related_name = extract_related_name_value(ctx)
|
||||
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
|
||||
args=[Instance(outer_class_info, [])])
|
||||
if isinstance(referred_to, AnyType):
|
||||
# referred_to defined as string, which is unsupported for now
|
||||
return ctx.default_return_type
|
||||
|
||||
add_new_class_member(referred_to.type,
|
||||
related_name, queryset_type)
|
||||
if is_string_based:
|
||||
return referred_to
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def set_related_name_instance_for_onetoonefield(ctx: FunctionContext) -> Type:
|
||||
class OneToOneFieldHook(object):
|
||||
def __init__(self, settings: Optional[Settings], apps: Optional[Apps]):
|
||||
self.settings = settings
|
||||
self.apps = apps
|
||||
|
||||
def __call__(self, ctx: FunctionContext) -> Type:
|
||||
if 'related_name' not in ctx.context.arg_names:
|
||||
return ctx.default_return_type
|
||||
|
||||
referred_to = extract_to_value_type(ctx)
|
||||
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps)
|
||||
if referred_to is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
if 'related_name' in ctx.context.arg_names:
|
||||
related_name = extract_related_name_value(ctx)
|
||||
outer_class_info = ctx.api.tscope.classes[-1]
|
||||
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
add_new_class_member(referred_to.type, related_name,
|
||||
new_member_instance=api.named_type(outer_class_info.fullname()))
|
||||
new_member_instance=Instance(outer_class_info, []))
|
||||
|
||||
if is_string_based:
|
||||
return referred_to
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import cast, Any
|
||||
from typing import cast
|
||||
|
||||
from django.conf import Settings
|
||||
from mypy.nodes import MDEF, TypeInfo, SymbolTable
|
||||
from mypy.nodes import MDEF
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.semanal import SemanticAnalyzerPass2
|
||||
from mypy.types import Instance, AnyType, TypeOfAny
|
||||
@@ -9,24 +9,22 @@ from mypy.types import Instance, AnyType, TypeOfAny
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def get_obj_type_name(value: Any) -> str:
|
||||
return type(value).__module__ + '.' + type(value).__qualname__
|
||||
|
||||
|
||||
class DjangoConfSettingsInitializerHook(object):
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
|
||||
def __call__(self, ctx: ClassDefContext) -> None:
|
||||
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||
if self.settings:
|
||||
for name, value in self.settings.__dict__.items():
|
||||
if name.isupper():
|
||||
if value is None:
|
||||
# TODO: change to Optional[Any] later
|
||||
ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF,
|
||||
instance=api.builtin_type('builtins.object'))
|
||||
continue
|
||||
|
||||
type_fullname = get_obj_type_name(value)
|
||||
type_fullname = helpers.get_obj_type_name(type(value))
|
||||
sym = api.lookup_fully_qualified_or_none(type_fullname)
|
||||
if sym is not None:
|
||||
args = len(sym.node.type_vars) * [AnyType(TypeOfAny.from_omitted_generics)]
|
||||
|
||||
@@ -5,3 +5,4 @@ python_files = test*.py
|
||||
addopts =
|
||||
--tb=native
|
||||
-v
|
||||
-s
|
||||
@@ -147,6 +147,12 @@ class DjangoDataDrivenTestCase(DataDrivenTestCase):
|
||||
self.old_cwd = os.getcwd()
|
||||
|
||||
self.tmpdir = tempfile.TemporaryDirectory(prefix='mypy-test-')
|
||||
tmpdir_root = os.path.join(self.tmpdir.name, 'tmp')
|
||||
|
||||
new_files = []
|
||||
for path, contents in self.files:
|
||||
new_files.append((path, contents.replace('<TMP>', tmpdir_root)))
|
||||
self.files = new_files
|
||||
|
||||
os.chdir(self.tmpdir.name)
|
||||
os.mkdir(test_temp_dir)
|
||||
@@ -179,7 +185,7 @@ class DjangoDataDrivenTestCase(DataDrivenTestCase):
|
||||
self.clean_up.append((True, d))
|
||||
self.clean_up.append((False, path))
|
||||
|
||||
sys.path.insert(0, os.path.join(self.tmpdir.name, 'tmp'))
|
||||
sys.path.insert(0, tmpdir_root)
|
||||
|
||||
def teardown(self):
|
||||
if hasattr(self, 'old_environ'):
|
||||
|
||||
253
test/helpers.py
Normal file
253
test/helpers.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
from typing import List, Callable, Optional, Tuple
|
||||
|
||||
import pytest # type: ignore # no pytest in typeshed
|
||||
|
||||
skip = pytest.mark.skip
|
||||
|
||||
# AssertStringArraysEqual displays special line alignment helper messages if
|
||||
# the first different line has at least this many characters,
|
||||
MIN_LINE_LENGTH_FOR_ALIGNMENT = 5
|
||||
|
||||
|
||||
class TypecheckAssertionError(AssertionError):
|
||||
def __init__(self, error_message: str, lineno: int):
|
||||
self.error_message = error_message
|
||||
self.lineno = lineno
|
||||
|
||||
def first_line(self):
|
||||
return self.__class__.__name__ + '(message="Invalid output")'
|
||||
|
||||
def __str__(self):
|
||||
return self.error_message
|
||||
|
||||
|
||||
def _clean_up(a: List[str]) -> List[str]:
|
||||
"""Remove common directory prefix from all strings in a.
|
||||
|
||||
This uses a naive string replace; it seems to work well enough. Also
|
||||
remove trailing carriage returns.
|
||||
"""
|
||||
res = []
|
||||
for s in a:
|
||||
prefix = os.sep
|
||||
ss = s
|
||||
for p in prefix, prefix.replace(os.sep, '/'):
|
||||
if p != '/' and p != '//' and p != '\\' and p != '\\\\':
|
||||
ss = ss.replace(p, '')
|
||||
# Ignore spaces at end of line.
|
||||
ss = re.sub(' +$', '', ss)
|
||||
res.append(re.sub('\\r$', '', ss))
|
||||
return res
|
||||
|
||||
|
||||
def _num_skipped_prefix_lines(a1: List[str], a2: List[str]) -> int:
|
||||
num_eq = 0
|
||||
while num_eq < min(len(a1), len(a2)) and a1[num_eq] == a2[num_eq]:
|
||||
num_eq += 1
|
||||
return max(0, num_eq - 4)
|
||||
|
||||
|
||||
def _num_skipped_suffix_lines(a1: List[str], a2: List[str]) -> int:
|
||||
num_eq = 0
|
||||
while (num_eq < min(len(a1), len(a2))
|
||||
and a1[-num_eq - 1] == a2[-num_eq - 1]):
|
||||
num_eq += 1
|
||||
return max(0, num_eq - 4)
|
||||
|
||||
|
||||
def _add_aligned_message(s1: str, s2: str, error_message: str) -> str:
|
||||
"""Align s1 and s2 so that the their first difference is highlighted.
|
||||
|
||||
For example, if s1 is 'foobar' and s2 is 'fobar', display the
|
||||
following lines:
|
||||
|
||||
E: foobar
|
||||
A: fobar
|
||||
^
|
||||
|
||||
If s1 and s2 are long, only display a fragment of the strings around the
|
||||
first difference. If s1 is very short, do nothing.
|
||||
"""
|
||||
|
||||
# Seeing what went wrong is trivial even without alignment if the expected
|
||||
# string is very short. In this case do nothing to simplify output.
|
||||
if len(s1) < 4:
|
||||
return error_message
|
||||
|
||||
maxw = 72 # Maximum number of characters shown
|
||||
|
||||
error_message += 'Alignment of first line difference:\n'
|
||||
# sys.stderr.write('Alignment of first line difference:\n')
|
||||
|
||||
trunc = False
|
||||
while s1[:30] == s2[:30]:
|
||||
s1 = s1[10:]
|
||||
s2 = s2[10:]
|
||||
trunc = True
|
||||
|
||||
if trunc:
|
||||
s1 = '...' + s1
|
||||
s2 = '...' + s2
|
||||
|
||||
max_len = max(len(s1), len(s2))
|
||||
extra = ''
|
||||
if max_len > maxw:
|
||||
extra = '...'
|
||||
|
||||
# Write a chunk of both lines, aligned.
|
||||
error_message += ' E: {}{}\n'.format(s1[:maxw], extra)
|
||||
# sys.stderr.write(' E: {}{}\n'.format(s1[:maxw], extra))
|
||||
error_message += ' A: {}{}\n'.format(s2[:maxw], extra)
|
||||
# sys.stderr.write(' A: {}{}\n'.format(s2[:maxw], extra))
|
||||
# Write an indicator character under the different columns.
|
||||
error_message += ' '
|
||||
# sys.stderr.write(' ')
|
||||
for j in range(min(maxw, max(len(s1), len(s2)))):
|
||||
if s1[j:j + 1] != s2[j:j + 1]:
|
||||
error_message += '^'
|
||||
# sys.stderr.write('^') # Difference
|
||||
break
|
||||
else:
|
||||
error_message += ' '
|
||||
# sys.stderr.write(' ') # Equal
|
||||
error_message += '\n'
|
||||
return error_message
|
||||
# sys.stderr.write('\n')
|
||||
|
||||
|
||||
def assert_string_arrays_equal(expected: List[str], actual: List[str]) -> None:
|
||||
"""Assert that two string arrays are equal.
|
||||
|
||||
Display any differences in a human-readable form.
|
||||
"""
|
||||
|
||||
actual = _clean_up(actual)
|
||||
error_message = ''
|
||||
|
||||
if actual != expected:
|
||||
num_skip_start = _num_skipped_prefix_lines(expected, actual)
|
||||
num_skip_end = _num_skipped_suffix_lines(expected, actual)
|
||||
|
||||
error_message += 'Expected:\n'
|
||||
# sys.stderr.write('Expected:\n')
|
||||
|
||||
# If omit some lines at the beginning, indicate it by displaying a line
|
||||
# with '...'.
|
||||
if num_skip_start > 0:
|
||||
error_message += ' ...\n'
|
||||
# sys.stderr.write(' ...\n')
|
||||
|
||||
# Keep track of the first different line.
|
||||
first_diff = -1
|
||||
|
||||
# Display only this many first characters of identical lines.
|
||||
width = 75
|
||||
|
||||
for i in range(num_skip_start, len(expected) - num_skip_end):
|
||||
if i >= len(actual) or expected[i] != actual[i]:
|
||||
if first_diff < 0:
|
||||
first_diff = i
|
||||
error_message += ' {:<45} (diff)'.format(expected[i])
|
||||
# sys.stderr.write(' {:<45} (diff)'.format(expected[i]))
|
||||
else:
|
||||
e = expected[i]
|
||||
error_message += ' ' + e[:width]
|
||||
# sys.stderr.write(' ' + e[:width])
|
||||
if len(e) > width:
|
||||
error_message += '...'
|
||||
# sys.stderr.write('...')
|
||||
error_message += '\n'
|
||||
# sys.stderr.write('\n')
|
||||
if num_skip_end > 0:
|
||||
error_message += ' ...\n'
|
||||
# sys.stderr.write(' ...\n')
|
||||
|
||||
error_message += 'Actual:\n'
|
||||
# sys.stderr.write('Actual:\n')
|
||||
|
||||
if num_skip_start > 0:
|
||||
error_message += ' ...\n'
|
||||
# sys.stderr.write(' ...\n')
|
||||
|
||||
for j in range(num_skip_start, len(actual) - num_skip_end):
|
||||
if j >= len(expected) or expected[j] != actual[j]:
|
||||
error_message += ' {:<45} (diff)'.format(actual[j])
|
||||
# sys.stderr.write(' {:<45} (diff)'.format(actual[j]))
|
||||
else:
|
||||
a = actual[j]
|
||||
error_message += ' ' + a[:width]
|
||||
# sys.stderr.write(' ' + a[:width])
|
||||
if len(a) > width:
|
||||
error_message += '...'
|
||||
# sys.stderr.write('...')
|
||||
error_message += '\n'
|
||||
# sys.stderr.write('\n')
|
||||
if actual == []:
|
||||
error_message += ' (empty)\n'
|
||||
# sys.stderr.write(' (empty)\n')
|
||||
if num_skip_end > 0:
|
||||
error_message += ' ...\n'
|
||||
# sys.stderr.write(' ...\n')
|
||||
|
||||
error_message += '\n'
|
||||
# sys.stderr.write('\n')
|
||||
|
||||
if first_diff >= 0 and first_diff < len(actual) and (
|
||||
len(expected[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT
|
||||
or len(actual[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT):
|
||||
# Display message that helps visualize the differences between two
|
||||
# long lines.
|
||||
error_message = _add_aligned_message(expected[first_diff], actual[first_diff],
|
||||
error_message)
|
||||
|
||||
first_failure = expected[first_diff]
|
||||
if first_failure:
|
||||
lineno = int(first_failure.split(' ')[0].strip(':').split(':')[1])
|
||||
raise TypecheckAssertionError(error_message=f'Invalid output: \n{error_message}',
|
||||
lineno=lineno)
|
||||
|
||||
|
||||
def build_output_line(fname: str, lnum: int, severity: str, message: str, col=None) -> str:
|
||||
if col is None:
|
||||
return f'{fname}:{lnum + 1}: {severity}: {message}'
|
||||
else:
|
||||
return f'{fname}:{lnum + 1}:{col}: {severity}: {message}'
|
||||
|
||||
|
||||
def expand_errors(input_lines: List[str], fname: str) -> List[str]:
|
||||
"""Transform comments such as '# E: message' or
|
||||
'# E:3: message' in input.
|
||||
|
||||
The result is lines like 'fnam:line: error: message'.
|
||||
"""
|
||||
output_lines = []
|
||||
for lnum, line in enumerate(input_lines):
|
||||
# The first in the split things isn't a comment
|
||||
for possible_err_comment in line.split(' # ')[1:]:
|
||||
m = re.search(
|
||||
r'^([ENW]):((?P<col>\d+):)? (?P<message>.*)$',
|
||||
possible_err_comment.strip())
|
||||
if m:
|
||||
if m.group(1) == 'E':
|
||||
severity = 'error'
|
||||
elif m.group(1) == 'N':
|
||||
severity = 'note'
|
||||
elif m.group(1) == 'W':
|
||||
severity = 'warning'
|
||||
col = m.group('col')
|
||||
output_lines.append(build_output_line(fname, lnum, severity,
|
||||
message=m.group("message"),
|
||||
col=col))
|
||||
return output_lines
|
||||
|
||||
|
||||
def get_func_first_lnum(attr: Callable[..., None]) -> Optional[Tuple[int, List[str]]]:
|
||||
lines, _ = inspect.getsourcelines(attr)
|
||||
for lnum, line in enumerate(lines):
|
||||
no_space_line = line.strip()
|
||||
if f'def {attr.__name__}' in no_space_line:
|
||||
return lnum, lines[lnum + 1:]
|
||||
raise ValueError(f'No line "def {attr.__name__}" found')
|
||||
299
test/pytest_plugin.py
Normal file
299
test/pytest_plugin.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import dataclasses
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Any, Optional, cast, List, Type, Callable, Dict
|
||||
|
||||
import pytest
|
||||
from _pytest._code.code import ReprFileLocation, ReprEntry, ExceptionInfo
|
||||
from decorator import decorate
|
||||
from mypy import api as mypy_api
|
||||
|
||||
from test import vistir
|
||||
from test.helpers import assert_string_arrays_equal, TypecheckAssertionError, expand_errors, get_func_first_lnum
|
||||
|
||||
|
||||
def reveal_type(obj: Any) -> None:
|
||||
# noop method, just to get rid of "method is not resolved" errors
|
||||
pass
|
||||
|
||||
|
||||
def output(output_lines: str):
|
||||
def decor(func: Callable[..., None]):
|
||||
func.out = output_lines
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorate(func, wrapper)
|
||||
|
||||
return decor
|
||||
|
||||
|
||||
def get_class_that_defined_method(meth) -> Type['MypyTypecheckTestCase']:
|
||||
if inspect.ismethod(meth):
|
||||
for cls in inspect.getmro(meth.__self__.__class__):
|
||||
if cls.__dict__.get(meth.__name__) is meth:
|
||||
return cls
|
||||
meth = meth.__func__ # fallback to __qualname__ parsing
|
||||
if inspect.isfunction(meth):
|
||||
cls = getattr(inspect.getmodule(meth),
|
||||
meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0])
|
||||
if issubclass(cls, MypyTypecheckTestCase):
|
||||
return cls
|
||||
return getattr(meth, '__objclass__', None) # handle special descriptor objects
|
||||
|
||||
|
||||
def file(filename: str, make_parent_packages=False):
|
||||
def decor(func: Callable[..., None]):
|
||||
func.filename = filename
|
||||
func.make_parent_packages = make_parent_packages
|
||||
return func
|
||||
|
||||
return decor
|
||||
|
||||
|
||||
def env(**environ):
|
||||
def decor(func: Callable[..., None]):
|
||||
func.env = environ
|
||||
return func
|
||||
|
||||
return decor
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CreateFile:
|
||||
sources: str
|
||||
make_parent_packages: bool = False
|
||||
|
||||
|
||||
class MypyTypecheckMeta(type):
|
||||
def __new__(mcs, name, bases, attrs):
|
||||
cls = super().__new__(mcs, name, bases, attrs)
|
||||
cls.files: Dict[str, CreateFile] = {}
|
||||
|
||||
for name, attr in attrs.items():
|
||||
if inspect.isfunction(attr):
|
||||
filename = getattr(attr, 'filename', None)
|
||||
if not filename:
|
||||
continue
|
||||
make_parent_packages = getattr(attr, 'make_parent_packages', False)
|
||||
sources = textwrap.dedent(''.join(get_func_first_lnum(attr)[1]))
|
||||
if sources.strip() == 'pass':
|
||||
sources = ''
|
||||
cls.files[filename] = CreateFile(sources, make_parent_packages)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
class MypyTypecheckTestCase(metaclass=MypyTypecheckMeta):
|
||||
files = None
|
||||
|
||||
def ini_file(self) -> str:
|
||||
return """
|
||||
[mypy]
|
||||
"""
|
||||
|
||||
def _get_ini_file_contents(self) -> Optional[str]:
|
||||
raw_ini_file = self.ini_file()
|
||||
if not raw_ini_file:
|
||||
return raw_ini_file
|
||||
return raw_ini_file.strip() + '\n'
|
||||
|
||||
|
||||
class TraceLastReprEntry(ReprEntry):
|
||||
def toterminal(self, tw):
|
||||
self.reprfileloc.toterminal(tw)
|
||||
for line in self.lines:
|
||||
red = line.startswith("E ")
|
||||
tw.line(line, bold=True, red=red)
|
||||
return
|
||||
|
||||
|
||||
def fname_to_module(fpath: Path, root_path: Path) -> Optional[str]:
|
||||
try:
|
||||
relpath = fpath.relative_to(root_path).with_suffix('')
|
||||
return str(relpath).replace(os.sep, '.')
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
class MypyTypecheckItem(pytest.Item):
|
||||
root_directory = '/run/testdata'
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
parent: 'MypyTestsCollector',
|
||||
klass: Type[MypyTypecheckTestCase],
|
||||
source_code: str,
|
||||
first_lineno: int,
|
||||
ini_file_contents: Optional[str] = None,
|
||||
expected_output_lines: Optional[List[str]] = None,
|
||||
files: Optional[Dict[str, CreateFile]] = None,
|
||||
custom_environment: Optional[Dict[str, Any]] = None):
|
||||
super().__init__(name=name, parent=parent)
|
||||
self.klass = klass
|
||||
self.source_code = source_code
|
||||
self.first_lineno = first_lineno
|
||||
self.ini_file_contents = ini_file_contents
|
||||
self.expected_output_lines = expected_output_lines
|
||||
self.files = files
|
||||
self.custom_environment = custom_environment
|
||||
|
||||
@contextmanager
|
||||
def temp_directory(self) -> Path:
|
||||
with tempfile.TemporaryDirectory(prefix='mypy-pytest-',
|
||||
dir=self.root_directory) as tmpdir_name:
|
||||
yield Path(self.root_directory) / tmpdir_name
|
||||
|
||||
def runtest(self):
|
||||
with self.temp_directory() as tmpdir_path:
|
||||
if not self.source_code:
|
||||
return
|
||||
|
||||
if self.ini_file_contents:
|
||||
mypy_ini_fpath = tmpdir_path / 'mypy.ini'
|
||||
mypy_ini_fpath.write_text(self.ini_file_contents)
|
||||
|
||||
test_specific_modules = []
|
||||
for fname, create_file in self.files.items():
|
||||
fpath = tmpdir_path / fname
|
||||
if create_file.make_parent_packages:
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
for parent in fpath.parents:
|
||||
try:
|
||||
parent.relative_to(tmpdir_path)
|
||||
if parent != tmpdir_path:
|
||||
parent_init_file = parent / '__init__.py'
|
||||
parent_init_file.write_text('')
|
||||
test_specific_modules.append(fname_to_module(parent,
|
||||
root_path=tmpdir_path))
|
||||
except ValueError:
|
||||
break
|
||||
|
||||
fpath.write_text(create_file.sources)
|
||||
test_specific_modules.append(fname_to_module(fpath,
|
||||
root_path=tmpdir_path))
|
||||
|
||||
with vistir.temp_environ(), vistir.temp_path():
|
||||
for key, val in (self.custom_environment or {}).items():
|
||||
os.environ[key] = val
|
||||
sys.path.insert(0, str(tmpdir_path))
|
||||
|
||||
mypy_cmd_options = self.prepare_mypy_cmd_options(config_file_path=mypy_ini_fpath)
|
||||
main_fpath = tmpdir_path / 'main.py'
|
||||
main_fpath.write_text(self.source_code)
|
||||
mypy_cmd_options.append(str(main_fpath))
|
||||
|
||||
stdout, _, _ = mypy_api.run(mypy_cmd_options)
|
||||
output_lines = []
|
||||
for line in stdout.splitlines():
|
||||
if ':' not in line:
|
||||
continue
|
||||
out_fpath, res_line = line.split(':', 1)
|
||||
line = os.path.relpath(out_fpath, start=tmpdir_path) + ':' + res_line
|
||||
output_lines.append(line.strip().replace('.py', ''))
|
||||
|
||||
for module in test_specific_modules:
|
||||
if module in sys.modules:
|
||||
del sys.modules[module]
|
||||
raise ValueError
|
||||
assert_string_arrays_equal(expected=self.expected_output_lines,
|
||||
actual=output_lines)
|
||||
|
||||
def prepare_mypy_cmd_options(self, config_file_path: Path) -> List[str]:
|
||||
mypy_cmd_options = [
|
||||
'--show-traceback',
|
||||
'--no-silence-site-packages'
|
||||
]
|
||||
python_version = '.'.join([str(part) for part in sys.version_info[:2]])
|
||||
mypy_cmd_options.append(f'--python-version={python_version}')
|
||||
if self.ini_file_contents:
|
||||
mypy_cmd_options.append(f'--config-file={config_file_path}')
|
||||
return mypy_cmd_options
|
||||
|
||||
def repr_failure(self, excinfo: ExceptionInfo) -> str:
|
||||
if excinfo.errisinstance(SystemExit):
|
||||
# We assume that before doing exit() (which raises SystemExit) we've printed
|
||||
# enough context about what happened so that a stack trace is not useful.
|
||||
# In particular, uncaught exceptions during semantic analysis or type checking
|
||||
# call exit() and they already print out a stack trace.
|
||||
return excinfo.exconly(tryshort=True)
|
||||
elif excinfo.errisinstance(TypecheckAssertionError):
|
||||
# with traceback removed
|
||||
exception_repr = excinfo.getrepr(style='short')
|
||||
exception_repr.reprcrash.message = ''
|
||||
repr_file_location = ReprFileLocation(path=inspect.getfile(self.klass),
|
||||
lineno=self.first_lineno + excinfo.value.lineno,
|
||||
message='')
|
||||
repr_tb_entry = TraceLastReprEntry(filelocrepr=repr_file_location,
|
||||
lines=exception_repr.reprtraceback.reprentries[-1].lines[1:],
|
||||
style='short',
|
||||
reprlocals=None,
|
||||
reprfuncargs=None)
|
||||
exception_repr.reprtraceback.reprentries = [repr_tb_entry]
|
||||
return exception_repr
|
||||
else:
|
||||
return super().repr_failure(excinfo, style='short')
|
||||
|
||||
def reportinfo(self):
|
||||
return self.fspath, None, get_class_qualname(self.klass) + '::' + self.name
|
||||
|
||||
|
||||
def get_class_qualname(klass: type) -> str:
|
||||
return klass.__module__ + '.' + klass.__name__
|
||||
|
||||
|
||||
def extract_test_output(attr: Callable[..., None]) -> List[str]:
|
||||
out_data: str = getattr(attr, 'out', None)
|
||||
out_lines = []
|
||||
if out_data:
|
||||
for line in out_data.split('\n'):
|
||||
line = line.strip()
|
||||
out_lines.append(line)
|
||||
return out_lines
|
||||
|
||||
|
||||
class MypyTestsCollector(pytest.Class):
|
||||
def get_ini_file_contents(self, contents: str) -> str:
|
||||
return contents.strip() + '\n'
|
||||
|
||||
def collect(self) -> Iterator[pytest.Item]:
|
||||
current_testcase = cast(MypyTypecheckTestCase, self.obj())
|
||||
ini_file_contents = self.get_ini_file_contents(current_testcase.ini_file())
|
||||
for attr_name in dir(current_testcase):
|
||||
if attr_name.startswith('_test_'):
|
||||
attr = getattr(self.obj, attr_name)
|
||||
if inspect.isfunction(attr):
|
||||
first_line_lnum, source_lines = get_func_first_lnum(attr)
|
||||
func_first_line_in_file = inspect.getsourcelines(attr)[1] + first_line_lnum
|
||||
|
||||
output_from_decorator = extract_test_output(attr)
|
||||
output_from_comments = expand_errors(source_lines, 'main')
|
||||
custom_env = getattr(attr, 'env', None)
|
||||
main_source_code = textwrap.dedent(''.join(source_lines))
|
||||
yield MypyTypecheckItem(name=attr_name,
|
||||
parent=self,
|
||||
klass=current_testcase.__class__,
|
||||
source_code=main_source_code,
|
||||
first_lineno=func_first_line_in_file,
|
||||
ini_file_contents=ini_file_contents,
|
||||
expected_output_lines=output_from_comments
|
||||
+ output_from_decorator,
|
||||
files=current_testcase.__class__.files,
|
||||
custom_environment=custom_env)
|
||||
|
||||
|
||||
def pytest_pycollect_makeitem(collector: Any, name: str, obj: Any) -> Optional[MypyTestsCollector]:
|
||||
# Only classes derived from DataSuite contain test cases, not the DataSuite class itself
|
||||
if (isinstance(obj, type)
|
||||
and issubclass(obj, MypyTypecheckTestCase)
|
||||
and obj is not MypyTypecheckTestCase):
|
||||
# Non-None result means this obj is a test case.
|
||||
# The collect method of the returned DataSuiteCollector instance will be called later,
|
||||
# with self.obj being obj.
|
||||
return MypyTestsCollector(name, parent=collector)
|
||||
0
test/pytest_tests/__init__.py
Normal file
0
test/pytest_tests/__init__.py
Normal file
9
test/pytest_tests/base.py
Normal file
9
test/pytest_tests/base.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from test.pytest_plugin import MypyTypecheckTestCase
|
||||
|
||||
|
||||
class BaseDjangoPluginTestCase(MypyTypecheckTestCase):
|
||||
def ini_file(self):
|
||||
return """
|
||||
[mypy]
|
||||
plugins = mypy_django_plugin.main
|
||||
"""
|
||||
21
test/pytest_tests/test_model_fields.py
Normal file
21
test/pytest_tests/test_model_fields.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from test.pytest_plugin import reveal_type
|
||||
from test.pytest_tests.base import BaseDjangoPluginTestCase
|
||||
|
||||
|
||||
class TestBasicModelFields(BaseDjangoPluginTestCase):
|
||||
def test_model_field_classes_present_as_primitives(self):
|
||||
from django.db import models
|
||||
|
||||
class User(models.Model):
|
||||
id = models.AutoField(primary_key=True)
|
||||
small_int = models.SmallIntegerField()
|
||||
name = models.CharField(max_length=255)
|
||||
slug = models.SlugField(max_length=255)
|
||||
text = models.TextField()
|
||||
|
||||
user = User()
|
||||
reveal_type(user.id) # E: Revealed type is 'builtins.int'
|
||||
reveal_type(user.small_int) # E: Revealed type is 'builtins.int'
|
||||
reveal_type(user.name) # E: Revealed type is 'builtins.str'
|
||||
reveal_type(user.slug) # E: Revealed type is 'builtins.str'
|
||||
reveal_type(user.text) # E: Revealed type is 'builtins.str'
|
||||
@@ -1,16 +1,9 @@
|
||||
from test.pytest_plugin import MypyTypecheckTestCase, reveal_type
|
||||
from test.pytest_plugin import reveal_type
|
||||
from test.pytest_tests.base import BaseDjangoPluginTestCase
|
||||
|
||||
|
||||
class BaseDjangoPluginTestCase(MypyTypecheckTestCase):
|
||||
def ini_file(self):
|
||||
return """
|
||||
[mypy]
|
||||
plugins = mypy_django_plugin.main
|
||||
"""
|
||||
|
||||
|
||||
class MyTestCase(BaseDjangoPluginTestCase):
|
||||
def check_foreign_key_field(self):
|
||||
class TestForeignKey(BaseDjangoPluginTestCase):
|
||||
def test_foreign_key_field(self):
|
||||
from django.db import models
|
||||
|
||||
class Publisher(models.Model):
|
||||
@@ -26,7 +19,7 @@ class MyTestCase(BaseDjangoPluginTestCase):
|
||||
publisher = Publisher()
|
||||
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
|
||||
|
||||
def check_every_foreign_key_creates_field_name_with_appended_id(self):
|
||||
def test_every_foreign_key_creates_field_name_with_appended_id(self):
|
||||
from django.db import models
|
||||
|
||||
class Publisher(models.Model):
|
||||
@@ -39,7 +32,7 @@ class MyTestCase(BaseDjangoPluginTestCase):
|
||||
book = Book()
|
||||
reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
|
||||
|
||||
def check_foreign_key_different_order_of_params(self):
|
||||
def test_foreign_key_different_order_of_params(self):
|
||||
from django.db import models
|
||||
|
||||
class Publisher(models.Model):
|
||||
@@ -54,3 +47,43 @@ class MyTestCase(BaseDjangoPluginTestCase):
|
||||
|
||||
publisher = Publisher()
|
||||
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
|
||||
|
||||
|
||||
class TestOneToOneField(BaseDjangoPluginTestCase):
|
||||
def test_onetoone_field(self):
|
||||
from django.db import models
|
||||
|
||||
class User(models.Model):
|
||||
pass
|
||||
|
||||
class Profile(models.Model):
|
||||
user = models.OneToOneField(to=User, on_delete=models.CASCADE, related_name='profile')
|
||||
|
||||
profile = Profile()
|
||||
reveal_type(profile.user) # E: Revealed type is 'main.User*'
|
||||
|
||||
user = User()
|
||||
reveal_type(user.profile) # E: Revealed type is 'main.Profile'
|
||||
|
||||
def test_onetoone_field_with_underscore_id(self):
|
||||
from django.db import models
|
||||
|
||||
class User(models.Model):
|
||||
pass
|
||||
|
||||
class Profile(models.Model):
|
||||
user = models.OneToOneField(to=User, on_delete=models.CASCADE, related_name='profile')
|
||||
|
||||
profile = Profile()
|
||||
reveal_type(profile.user_id) # E: Revealed type is 'builtins.int'
|
||||
|
||||
def test_parameter_to_keyword_may_be_absent(self):
|
||||
from django.db import models
|
||||
|
||||
class User(models.Model):
|
||||
pass
|
||||
|
||||
class Profile(models.Model):
|
||||
user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile')
|
||||
|
||||
reveal_type(User().profile) # E: Revealed type is 'main.Profile'
|
||||
|
||||
28
test/pytest_tests/test_objects_queryset.py
Normal file
28
test/pytest_tests/test_objects_queryset.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from test.pytest_plugin import reveal_type, output
|
||||
from test.pytest_tests.base import BaseDjangoPluginTestCase
|
||||
|
||||
|
||||
class TestObjectsQueryset(BaseDjangoPluginTestCase):
|
||||
def test_every_model_has_objects_queryset_available(self):
|
||||
from django.db import models
|
||||
|
||||
class User(models.Model):
|
||||
pass
|
||||
|
||||
reveal_type(User.objects) # E: Revealed type is 'django.db.models.query.QuerySet[main.User]'
|
||||
|
||||
@output("""
|
||||
main:10: error: Revealed type is 'Any'
|
||||
main:10: error: "Type[ModelMixin]" has no attribute "objects"
|
||||
""")
|
||||
def test_objects_get_returns_model_instance(self):
|
||||
from django.db import models
|
||||
|
||||
class ModelMixin(models.Model):
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
class User(ModelMixin):
|
||||
pass
|
||||
|
||||
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'
|
||||
37
test/pytest_tests/test_parse_settings.py
Normal file
37
test/pytest_tests/test_parse_settings.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from test.pytest_plugin import reveal_type, file, env
|
||||
from test.pytest_tests.base import BaseDjangoPluginTestCase
|
||||
|
||||
|
||||
class TestParseSettingsFromFile(BaseDjangoPluginTestCase):
|
||||
@env(DJANGO_SETTINGS_MODULE='mysettings')
|
||||
def test_case(self):
|
||||
from django.conf import settings
|
||||
|
||||
reveal_type(settings.ROOT_DIR) # E: Revealed type is 'builtins.str'
|
||||
reveal_type(settings.OBJ) # E: Revealed type is 'django.utils.functional.LazyObject'
|
||||
reveal_type(settings.NUMBERS) # E: Revealed type is 'builtins.list[Any]'
|
||||
reveal_type(settings.DICT) # E: Revealed type is 'builtins.dict[Any, Any]'
|
||||
|
||||
@file('mysettings.py')
|
||||
def mysettings_py_file(self):
|
||||
SECRET_KEY = 112233
|
||||
ROOT_DIR = '/etc'
|
||||
NUMBERS = ['one', 'two']
|
||||
DICT = {} # type: ignore
|
||||
|
||||
from django.utils.functional import LazyObject
|
||||
|
||||
OBJ = LazyObject()
|
||||
|
||||
|
||||
class TestSettingInitializableToNone(BaseDjangoPluginTestCase):
|
||||
@env(DJANGO_SETTINGS_MODULE='mysettings')
|
||||
def test_case(self):
|
||||
from django.conf import settings
|
||||
|
||||
reveal_type(settings.NONE_SETTING) # E: Revealed type is 'builtins.object'
|
||||
|
||||
@file('mysettings.py')
|
||||
def mysettings_py_file(self):
|
||||
SECRET_KEY = 112233
|
||||
NONE_SETTING = None
|
||||
27
test/pytest_tests/test_postgres_fields.py
Normal file
27
test/pytest_tests/test_postgres_fields.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from test.pytest_plugin import reveal_type
|
||||
from test.pytest_tests.base import BaseDjangoPluginTestCase
|
||||
|
||||
|
||||
class TestArrayField(BaseDjangoPluginTestCase):
|
||||
def test_descriptor_access(self):
|
||||
from django.db import models
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
|
||||
class User(models.Model):
|
||||
array = ArrayField(base_field=models.Field())
|
||||
|
||||
user = User()
|
||||
reveal_type(user.array) # E: Revealed type is 'builtins.list[Any]'
|
||||
|
||||
def test_base_field_parsed_into_generic_attribute(self):
|
||||
from django.db import models
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
|
||||
class User(models.Model):
|
||||
members = ArrayField(base_field=models.IntegerField())
|
||||
members_as_text = ArrayField(base_field=models.CharField(max_length=255))
|
||||
|
||||
user = User()
|
||||
reveal_type(user.members) # E: Revealed type is 'builtins.list[builtins.int*]'
|
||||
reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list[builtins.str*]'
|
||||
|
||||
74
test/pytest_tests/test_to_attr_as_string.py
Normal file
74
test/pytest_tests/test_to_attr_as_string.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from test.pytest_plugin import file, reveal_type, env
|
||||
from test.pytest_tests.base import BaseDjangoPluginTestCase
|
||||
|
||||
|
||||
class TestForeignKey(BaseDjangoPluginTestCase):
|
||||
@env(DJANGO_SETTINGS_MODULE='mysettings')
|
||||
def _test_to_parameter_could_be_specified_as_string(self):
|
||||
from apps.myapp.models import Publisher
|
||||
|
||||
publisher = Publisher()
|
||||
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[apps.myapp2.models.Book]'
|
||||
|
||||
# @env(DJANGO_SETTINGS_MODULE='mysettings')
|
||||
# def _test_creates_underscore_id_attr(self):
|
||||
# from apps.myapp2.models import Book
|
||||
#
|
||||
# book = Book()
|
||||
# reveal_type(book.publisher) # E: Revealed type is 'apps.myapp.models.Publisher'
|
||||
# reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
|
||||
|
||||
@file('mysettings.py')
|
||||
def mysettings(self):
|
||||
SECRET_KEY = '112233'
|
||||
ROOT_DIR = '<TMP>'
|
||||
APPS_DIR = '<TMP>/apps'
|
||||
|
||||
INSTALLED_APPS = ('apps.myapp', 'apps.myapp2')
|
||||
|
||||
@file('apps/myapp/models.py', make_parent_packages=True)
|
||||
def apps_myapp_models(self):
|
||||
from django.db import models
|
||||
|
||||
class Publisher(models.Model):
|
||||
pass
|
||||
|
||||
@file('apps/myapp2/models.py', make_parent_packages=True)
|
||||
def apps_myapp2_models(self):
|
||||
from django.db import models
|
||||
|
||||
class Book(models.Model):
|
||||
publisher = models.ForeignKey(to='myapp.Publisher', on_delete=models.CASCADE,
|
||||
related_name='books')
|
||||
|
||||
|
||||
class TestOneToOneField(BaseDjangoPluginTestCase):
|
||||
@env(DJANGO_SETTINGS_MODULE='mysettings')
|
||||
def test_to_parameter_could_be_specified_as_string(self):
|
||||
from apps.myapp.models import User
|
||||
|
||||
user = User()
|
||||
reveal_type(user.profile) # E: Revealed type is 'apps.myapp2.models.Profile'
|
||||
|
||||
@file('mysettings.py')
|
||||
def mysettings(self):
|
||||
SECRET_KEY = '112233'
|
||||
ROOT_DIR = '<TMP>'
|
||||
APPS_DIR = '<TMP>/apps'
|
||||
|
||||
INSTALLED_APPS = ('apps.myapp', 'apps.myapp2')
|
||||
|
||||
@file('apps/myapp/models.py', make_parent_packages=True)
|
||||
def apps_myapp_models(self):
|
||||
from django.db import models
|
||||
|
||||
class User(models.Model):
|
||||
pass
|
||||
|
||||
@file('apps/myapp2/models.py', make_parent_packages=True)
|
||||
def apps_myapp2_models(self):
|
||||
from django.db import models
|
||||
|
||||
class Profile(models.Model):
|
||||
user = models.OneToOneField(to='myapp.User', on_delete=models.CASCADE,
|
||||
related_name='profile')
|
||||
@@ -14,11 +14,14 @@ MYPY_INI_PATH = ROOT_DIR / 'test' / 'plugins.ini'
|
||||
|
||||
class DjangoTestSuite(DataSuite):
|
||||
files = [
|
||||
'check-objects-queryset.test',
|
||||
'check-model-fields.test',
|
||||
'check-postgres-fields.test',
|
||||
'check-model-relations.test',
|
||||
'check-parse-settings.test'
|
||||
# 'check-objects-queryset.test',
|
||||
# 'check-model-fields.test',
|
||||
# 'check-postgres-fields.test',
|
||||
# 'check-model-relations.test',
|
||||
# 'check-parse-settings.test',
|
||||
# 'check-to-attr-as-string-one-to-one-field.test',
|
||||
'check-to-attr-as-string-foreign-key.test',
|
||||
# 'check-foreign-key-as-string-creates-underscore-id-attr.test'
|
||||
]
|
||||
data_prefix = str(TEST_DATA_DIR)
|
||||
|
||||
|
||||
43
test/vistir.py
Normal file
43
test/vistir.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Borrowed from Pew.
|
||||
# See https://github.com/berdario/pew/blob/master/pew/_utils.py#L82
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from decorator import contextmanager
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temp_environ():
|
||||
"""Allow the ability to set os.environ temporarily"""
|
||||
environ = dict(os.environ)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.environ.clear()
|
||||
os.environ.update(environ)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temp_path():
|
||||
"""A context manager which allows the ability to set sys.path temporarily"""
|
||||
path = [p for p in sys.path]
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
sys.path = [p for p in path]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cd(path):
|
||||
"""Context manager to temporarily change working directories"""
|
||||
if not path:
|
||||
return
|
||||
prev_cwd = Path.cwd().as_posix()
|
||||
if isinstance(path, Path):
|
||||
path = path.as_posix()
|
||||
os.chdir(str(path))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.chdir(prev_cwd)
|
||||
Reference in New Issue
Block a user