mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-08 21:14:49 +08:00
latest changes
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -4,3 +4,4 @@ out/
|
|||||||
/test_sqlite.py
|
/test_sqlite.py
|
||||||
/django
|
/django
|
||||||
.idea/
|
.idea/
|
||||||
|
.mypy_cache/
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
pytest_plugins = [
|
pytest_plugins = [
|
||||||
'test.data'
|
'test.pytest_plugin'
|
||||||
]
|
]
|
||||||
7
django-stubs/apps/__init__.pyi
Normal file
7
django-stubs/apps/__init__.pyi
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from .config import (
|
||||||
|
AppConfig as AppConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
from .registry import (
|
||||||
|
apps as apps
|
||||||
|
)
|
||||||
26
django-stubs/apps/config.pyi
Normal file
26
django-stubs/apps/config.pyi
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from typing import Any, Iterator, Type
|
||||||
|
|
||||||
|
from django.db.models.base import Model
|
||||||
|
|
||||||
|
MODELS_MODULE_NAME: str
|
||||||
|
|
||||||
|
class AppConfig:
|
||||||
|
name: str = ...
|
||||||
|
module: Any = ...
|
||||||
|
apps: None = ...
|
||||||
|
label: str = ...
|
||||||
|
verbose_name: str = ...
|
||||||
|
path: str = ...
|
||||||
|
models_module: None = ...
|
||||||
|
models: None = ...
|
||||||
|
def __init__(self, app_name: str, app_module: None) -> None: ...
|
||||||
|
@classmethod
|
||||||
|
def create(cls, entry: str) -> AppConfig: ...
|
||||||
|
def get_model(
|
||||||
|
self, model_name: str, require_ready: bool = ...
|
||||||
|
) -> Type[Model]: ...
|
||||||
|
def get_models(
|
||||||
|
self, include_auto_created: bool = ..., include_swapped: bool = ...
|
||||||
|
) -> Iterator[Type[Model]]: ...
|
||||||
|
def import_models(self) -> None: ...
|
||||||
|
def ready(self) -> None: ...
|
||||||
62
django-stubs/apps/registry.pyi
Normal file
62
django-stubs/apps/registry.pyi
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import collections
|
||||||
|
from typing import Any, Callable, List, Optional, Tuple, Type, Union, Iterable
|
||||||
|
|
||||||
|
from django.apps.config import AppConfig
|
||||||
|
from django.db.migrations.state import AppConfigStub
|
||||||
|
from django.db.models.base import Model
|
||||||
|
|
||||||
|
from .config import AppConfig
|
||||||
|
|
||||||
|
|
||||||
|
class Apps:
|
||||||
|
all_models: collections.defaultdict = ...
|
||||||
|
app_configs: collections.OrderedDict = ...
|
||||||
|
stored_app_configs: List[Any] = ...
|
||||||
|
apps_ready: bool = ...
|
||||||
|
loading: bool = ...
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
installed_apps: Optional[
|
||||||
|
Union[List[AppConfigStub], List[str], Tuple]
|
||||||
|
] = ...,
|
||||||
|
) -> None: ...
|
||||||
|
models_ready: bool = ...
|
||||||
|
ready: bool = ...
|
||||||
|
def populate(
|
||||||
|
self, installed_apps: Union[List[AppConfigStub], List[str], Tuple] = ...
|
||||||
|
) -> None: ...
|
||||||
|
def check_apps_ready(self) -> None: ...
|
||||||
|
def check_models_ready(self) -> None: ...
|
||||||
|
def get_app_configs(self) -> Iterable[AppConfig]: ...
|
||||||
|
def get_app_config(self, app_label: str) -> AppConfig: ...
|
||||||
|
def get_models(
|
||||||
|
self, include_auto_created: bool = ..., include_swapped: bool = ...
|
||||||
|
) -> List[Type[Model]]: ...
|
||||||
|
def get_model(
|
||||||
|
self,
|
||||||
|
app_label: str,
|
||||||
|
model_name: Optional[str] = ...,
|
||||||
|
require_ready: bool = ...,
|
||||||
|
) -> Type[Model]: ...
|
||||||
|
def register_model(self, app_label: str, model: Type[Model]) -> None: ...
|
||||||
|
def is_installed(self, app_name: str) -> bool: ...
|
||||||
|
def get_containing_app_config(
|
||||||
|
self, object_name: str
|
||||||
|
) -> Optional[AppConfig]: ...
|
||||||
|
def get_registered_model(
|
||||||
|
self, app_label: str, model_name: str
|
||||||
|
) -> Type[Model]: ...
|
||||||
|
def get_swappable_settings_name(self, to_string: str) -> Optional[str]: ...
|
||||||
|
def set_available_apps(self, available: List[str]) -> None: ...
|
||||||
|
def unset_available_apps(self) -> None: ...
|
||||||
|
def set_installed_apps(
|
||||||
|
self, installed: Union[List[str], Tuple[str]]
|
||||||
|
) -> None: ...
|
||||||
|
def unset_installed_apps(self) -> None: ...
|
||||||
|
def clear_cache(self) -> None: ...
|
||||||
|
def lazy_model_operation(
|
||||||
|
self, function: Callable, *model_keys: Any
|
||||||
|
) -> None: ...
|
||||||
|
def do_pending_operations(self, model: Type[Model]) -> None: ...
|
||||||
|
|
||||||
|
apps: Apps
|
||||||
@@ -1,2 +1,18 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from .utils import (ProgrammingError as ProgrammingError,
|
from .utils import (ProgrammingError as ProgrammingError,
|
||||||
IntegrityError as IntegrityError)
|
IntegrityError as IntegrityError,
|
||||||
|
OperationalError as OperationalError,
|
||||||
|
DatabaseError as DatabaseError,
|
||||||
|
DataError as DataError,
|
||||||
|
NotSupportedError as NotSupportedError)
|
||||||
|
|
||||||
|
connections: Any
|
||||||
|
router: Any
|
||||||
|
|
||||||
|
class DefaultConnectionProxy:
|
||||||
|
def __getattr__(self, item: str) -> Any: ...
|
||||||
|
def __setattr__(self, name: str, value: Any) -> None: ...
|
||||||
|
def __delattr__(self, name: str) -> None: ...
|
||||||
|
|
||||||
|
connection: Any
|
||||||
|
|||||||
@@ -3,14 +3,21 @@ from .base import Model as Model
|
|||||||
from .fields import (AutoField as AutoField,
|
from .fields import (AutoField as AutoField,
|
||||||
IntegerField as IntegerField,
|
IntegerField as IntegerField,
|
||||||
SmallIntegerField as SmallIntegerField,
|
SmallIntegerField as SmallIntegerField,
|
||||||
|
BigIntegerField as BigIntegerField,
|
||||||
CharField as CharField,
|
CharField as CharField,
|
||||||
Field as Field,
|
Field as Field,
|
||||||
SlugField as SlugField,
|
SlugField as SlugField,
|
||||||
TextField as TextField,
|
TextField as TextField,
|
||||||
BooleanField as BooleanField)
|
BooleanField as BooleanField,
|
||||||
|
FileField as FileField,
|
||||||
|
DateField as DateField,
|
||||||
|
DateTimeField as DateTimeField,
|
||||||
|
IPAddressField as IPAddressField,
|
||||||
|
GenericIPAddressField as GenericIPAddressField)
|
||||||
|
|
||||||
from .fields.related import (ForeignKey as ForeignKey,
|
from .fields.related import (ForeignKey as ForeignKey,
|
||||||
OneToOneField as OneToOneField)
|
OneToOneField as OneToOneField,
|
||||||
|
ManyToManyField as ManyToManyField)
|
||||||
|
|
||||||
from .deletion import (CASCADE as CASCADE,
|
from .deletion import (CASCADE as CASCADE,
|
||||||
SET_DEFAULT as SET_DEFAULT,
|
SET_DEFAULT as SET_DEFAULT,
|
||||||
@@ -24,4 +31,11 @@ from .query_utils import Q as Q
|
|||||||
|
|
||||||
from .lookups import Lookup as Lookup
|
from .lookups import Lookup as Lookup
|
||||||
|
|
||||||
from .expressions import F as F
|
from .expressions import (F as F,
|
||||||
|
Subquery as Subquery,
|
||||||
|
Exists as Exists,
|
||||||
|
OrderBy as OrderBy,
|
||||||
|
OuterRef as OuterRef)
|
||||||
|
|
||||||
|
from .manager import (BaseManager as BaseManager,
|
||||||
|
Manager as Manager)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from datetime import datetime, timedelta
|
|||||||
from typing import (Any, Callable, Dict, Iterator, List, Optional, Set, Tuple,
|
from typing import (Any, Callable, Dict, Iterator, List, Optional, Set, Tuple,
|
||||||
Type, Union)
|
Type, Union)
|
||||||
|
|
||||||
|
from django.db.models import QuerySet
|
||||||
from django.db.models.fields import Field
|
from django.db.models.fields import Field
|
||||||
from django.db.models.lookups import Lookup
|
from django.db.models.lookups import Lookup
|
||||||
from django.db.models.sql.compiler import SQLCompiler
|
from django.db.models.sql.compiler import SQLCompiler
|
||||||
@@ -180,4 +181,50 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
|
|||||||
|
|
||||||
|
|
||||||
class F(Combinable):
|
class F(Combinable):
|
||||||
|
name: str
|
||||||
def __init__(self, name: str): ...
|
def __init__(self, name: str): ...
|
||||||
|
def resolve_expression(
|
||||||
|
self,
|
||||||
|
query: Any = ...,
|
||||||
|
allow_joins: bool = ...,
|
||||||
|
reuse: Optional[Set[str]] = ...,
|
||||||
|
summarize: bool = ...,
|
||||||
|
for_save: bool = ...,
|
||||||
|
) -> Expression: ...
|
||||||
|
|
||||||
|
|
||||||
|
class OuterRef(F): ...
|
||||||
|
|
||||||
|
class Subquery(Expression):
|
||||||
|
template: str = ...
|
||||||
|
queryset: QuerySet = ...
|
||||||
|
extra: Dict[Any, Any] = ...
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
queryset: QuerySet,
|
||||||
|
output_field: Optional[Field] = ...,
|
||||||
|
**extra: Any
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
class Exists(Subquery):
|
||||||
|
extra: Dict[Any, Any]
|
||||||
|
template: str = ...
|
||||||
|
negated: bool = ...
|
||||||
|
def __init__(
|
||||||
|
self, *args: Any, negated: bool = ..., **kwargs: Any
|
||||||
|
) -> None: ...
|
||||||
|
def __invert__(self) -> Exists: ...
|
||||||
|
|
||||||
|
class OrderBy(BaseExpression):
|
||||||
|
template: str = ...
|
||||||
|
nulls_first: bool = ...
|
||||||
|
nulls_last: bool = ...
|
||||||
|
descending: bool = ...
|
||||||
|
expression: Expression = ...
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
expression: Combinable,
|
||||||
|
descending: bool = ...,
|
||||||
|
nulls_first: bool = ...,
|
||||||
|
nulls_last: bool = ...,
|
||||||
|
) -> None: ...
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
from django.db.models.query_utils import RegisterLookupMixin
|
from django.db.models.query_utils import RegisterLookupMixin
|
||||||
|
|
||||||
@@ -7,6 +7,7 @@ class Field(RegisterLookupMixin):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
primary_key: bool = False,
|
primary_key: bool = False,
|
||||||
**kwargs): ...
|
**kwargs): ...
|
||||||
|
|
||||||
def __get__(self, instance, owner) -> Any: ...
|
def __get__(self, instance, owner) -> Any: ...
|
||||||
|
|
||||||
|
|
||||||
@@ -14,8 +15,10 @@ class IntegerField(Field):
|
|||||||
def __get__(self, instance, owner) -> int: ...
|
def __get__(self, instance, owner) -> int: ...
|
||||||
|
|
||||||
|
|
||||||
class SmallIntegerField(IntegerField):
|
class SmallIntegerField(IntegerField): ...
|
||||||
pass
|
|
||||||
|
|
||||||
|
class BigIntegerField(IntegerField): ...
|
||||||
|
|
||||||
|
|
||||||
class AutoField(Field):
|
class AutoField(Field):
|
||||||
@@ -26,11 +29,11 @@ class CharField(Field):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
max_length: int,
|
max_length: int,
|
||||||
**kwargs): ...
|
**kwargs): ...
|
||||||
|
|
||||||
def __get__(self, instance, owner) -> str: ...
|
def __get__(self, instance, owner) -> str: ...
|
||||||
|
|
||||||
|
|
||||||
class SlugField(CharField):
|
class SlugField(CharField): ...
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TextField(Field):
|
class TextField(Field):
|
||||||
@@ -39,3 +42,29 @@ class TextField(Field):
|
|||||||
|
|
||||||
class BooleanField(Field):
|
class BooleanField(Field):
|
||||||
def __get__(self, instance, owner) -> bool: ...
|
def __get__(self, instance, owner) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
class FileField(Field): ...
|
||||||
|
|
||||||
|
|
||||||
|
class IPAddressField(Field): ...
|
||||||
|
|
||||||
|
|
||||||
|
class GenericIPAddressField(Field):
|
||||||
|
default_error_messages: Any = ...
|
||||||
|
unpack_ipv4: Any = ...
|
||||||
|
protocol: Any = ...
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
verbose_name: Optional[Any] = ...,
|
||||||
|
name: Optional[Any] = ...,
|
||||||
|
protocol: str = ...,
|
||||||
|
unpack_ipv4: bool = ...,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
class DateField(Field): ...
|
||||||
|
|
||||||
|
class DateTimeField(DateField): ...
|
||||||
@@ -22,3 +22,12 @@ class OneToOneField(Field, Generic[_T]):
|
|||||||
related_name: str = ...,
|
related_name: str = ...,
|
||||||
**kwargs): ...
|
**kwargs): ...
|
||||||
def __get__(self, instance, owner) -> _T: ...
|
def __get__(self, instance, owner) -> _T: ...
|
||||||
|
|
||||||
|
|
||||||
|
class ManyToManyField(Field, Generic[_T]):
|
||||||
|
def __init__(self,
|
||||||
|
to: Union[Type[_T], str],
|
||||||
|
on_delete: Any,
|
||||||
|
related_name: str = ...,
|
||||||
|
**kwargs): ...
|
||||||
|
def __get__(self, instance, owner) -> _T: ...
|
||||||
|
|||||||
140
django-stubs/db/models/manager.pyi
Normal file
140
django-stubs/db/models/manager.pyi
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
from datetime import datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union, TypeVar, Set, Generic, Iterator
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from django.db.models import Q
|
||||||
|
from django.db.models.base import Model
|
||||||
|
from django.db.models.query import QuerySet, RawQuerySet
|
||||||
|
|
||||||
|
_T = TypeVar('_T', bound=Model)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseManager:
|
||||||
|
creation_counter: int = ...
|
||||||
|
auto_created: bool = ...
|
||||||
|
use_in_migrations: bool = ...
|
||||||
|
def __new__(cls: Type[BaseManager], *args: Any, **kwargs: Any) -> BaseManager: ...
|
||||||
|
model: Any = ...
|
||||||
|
name: Any = ...
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
def deconstruct(self) -> Tuple[bool, str, None, Tuple, Dict[str, int]]: ...
|
||||||
|
def check(self, **kwargs: Any) -> List[Any]: ...
|
||||||
|
@classmethod
|
||||||
|
def from_queryset(
|
||||||
|
cls, queryset_class: Any, class_name: Optional[Any] = ...
|
||||||
|
): ...
|
||||||
|
def contribute_to_class(self, model: Type[Model], name: str) -> None: ...
|
||||||
|
def db_manager(
|
||||||
|
self,
|
||||||
|
using: Optional[str] = ...,
|
||||||
|
hints: Optional[Dict[str, Model]] = ...,
|
||||||
|
) -> Manager: ...
|
||||||
|
@property
|
||||||
|
def db(self) -> str: ...
|
||||||
|
def get_queryset(self) -> QuerySet: ...
|
||||||
|
def all(self) -> QuerySet: ...
|
||||||
|
def __eq__(self, other: Optional[Any]) -> bool: ...
|
||||||
|
def __hash__(self): ...
|
||||||
|
|
||||||
|
class Manager(Generic[_T]):
|
||||||
|
def exists(self) -> bool: ...
|
||||||
|
def explain(
|
||||||
|
self, *, format: Optional[Any] = ..., **options: Any
|
||||||
|
) -> str: ...
|
||||||
|
def raw(
|
||||||
|
self,
|
||||||
|
raw_query: str,
|
||||||
|
params: Optional[
|
||||||
|
Union[
|
||||||
|
Dict[str, str],
|
||||||
|
List[datetime],
|
||||||
|
List[Decimal],
|
||||||
|
List[str],
|
||||||
|
Set[str],
|
||||||
|
Tuple[int],
|
||||||
|
]
|
||||||
|
] = ...,
|
||||||
|
translations: Optional[Dict[str, str]] = ...,
|
||||||
|
using: None = ...,
|
||||||
|
) -> RawQuerySet: ...
|
||||||
|
def values(self, *fields: Any, **expressions: Any) -> QuerySet: ...
|
||||||
|
def values_list(
|
||||||
|
self, *fields: Any, flat: bool = ..., named: bool = ...
|
||||||
|
) -> QuerySet: ...
|
||||||
|
def dates(
|
||||||
|
self, field_name: str, kind: str, order: str = ...
|
||||||
|
) -> QuerySet: ...
|
||||||
|
def datetimes(
|
||||||
|
self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...
|
||||||
|
) -> QuerySet: ...
|
||||||
|
def none(self) -> QuerySet[_T]: ...
|
||||||
|
def all(self) -> QuerySet[_T]: ...
|
||||||
|
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
|
||||||
|
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
|
||||||
|
def complex_filter(
|
||||||
|
self,
|
||||||
|
filter_obj: Union[
|
||||||
|
Dict[str, datetime], Dict[str, QuerySet], Q, MagicMock
|
||||||
|
],
|
||||||
|
) -> QuerySet[_T]: ...
|
||||||
|
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
|
||||||
|
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
|
||||||
|
|
||||||
|
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
|
||||||
|
|
||||||
|
def select_for_update(
|
||||||
|
self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...
|
||||||
|
) -> QuerySet: ...
|
||||||
|
|
||||||
|
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
|
||||||
|
|
||||||
|
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
|
||||||
|
|
||||||
|
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
|
||||||
|
|
||||||
|
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
|
||||||
|
|
||||||
|
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
|
||||||
|
|
||||||
|
def extra(
|
||||||
|
self,
|
||||||
|
select: Optional[
|
||||||
|
Union[Dict[str, int], Dict[str, str], OrderedDict]
|
||||||
|
] = ...,
|
||||||
|
where: Optional[List[str]] = ...,
|
||||||
|
params: Optional[Union[List[int], List[str]]] = ...,
|
||||||
|
tables: Optional[List[str]] = ...,
|
||||||
|
order_by: Optional[Union[List[str], Tuple[str]]] = ...,
|
||||||
|
select_params: Optional[Union[List[int], List[str], Tuple[int]]] = ...,
|
||||||
|
) -> QuerySet[_T]: ...
|
||||||
|
|
||||||
|
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
|
||||||
|
|
||||||
|
def aggregate(
|
||||||
|
self, *args: Any, **kwargs: Any
|
||||||
|
) -> Dict[str, Optional[Union[datetime, float]]]: ...
|
||||||
|
|
||||||
|
def count(self) -> int: ...
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self, *args: Any, **kwargs: Any
|
||||||
|
) -> _T: ...
|
||||||
|
|
||||||
|
def create(self, **kwargs: Any) -> _T: ...
|
||||||
|
|
||||||
|
|
||||||
|
class ManagerDescriptor:
|
||||||
|
manager: Manager = ...
|
||||||
|
def __init__(self, manager: Manager) -> None: ...
|
||||||
|
def __get__(
|
||||||
|
self, instance: Optional[Model], cls: Type[Model] = ...
|
||||||
|
) -> Manager: ...
|
||||||
|
|
||||||
|
class EmptyManager(Manager):
|
||||||
|
creation_counter: int
|
||||||
|
name: None
|
||||||
|
model: Optional[Type[Model]] = ...
|
||||||
|
def __init__(self, model: Type[Model]) -> None: ...
|
||||||
|
def get_queryset(self) -> QuerySet: ...
|
||||||
@@ -7,3 +7,8 @@ from .conf import (include as include,
|
|||||||
from .resolvers import (ResolverMatch as ResolverMatch,
|
from .resolvers import (ResolverMatch as ResolverMatch,
|
||||||
get_ns_resolver as get_ns_resolver,
|
get_ns_resolver as get_ns_resolver,
|
||||||
get_resolver as get_resolver)
|
get_resolver as get_resolver)
|
||||||
|
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
from .converters import (
|
||||||
|
register_converter as register_converter
|
||||||
|
)
|
||||||
30
django-stubs/utils/baseconv.pyi
Normal file
30
django-stubs/utils/baseconv.pyi
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from typing import Any, Tuple, Union
|
||||||
|
|
||||||
|
BASE2_ALPHABET: str
|
||||||
|
BASE16_ALPHABET: str
|
||||||
|
BASE56_ALPHABET: str
|
||||||
|
BASE36_ALPHABET: str
|
||||||
|
BASE62_ALPHABET: str
|
||||||
|
BASE64_ALPHABET: Any
|
||||||
|
|
||||||
|
class BaseConverter:
|
||||||
|
decimal_digits: str = ...
|
||||||
|
sign: str = ...
|
||||||
|
digits: str = ...
|
||||||
|
def __init__(self, digits: str, sign: str = ...) -> None: ...
|
||||||
|
def encode(self, i: int) -> str: ...
|
||||||
|
def decode(self, s: str) -> int: ...
|
||||||
|
def convert(
|
||||||
|
self,
|
||||||
|
number: Union[int, str],
|
||||||
|
from_digits: str,
|
||||||
|
to_digits: str,
|
||||||
|
sign: str,
|
||||||
|
) -> Tuple[int, str]: ...
|
||||||
|
|
||||||
|
base2: Any
|
||||||
|
base16: Any
|
||||||
|
base36: Any
|
||||||
|
base56: Any
|
||||||
|
base62: Any
|
||||||
|
base64: Any
|
||||||
@@ -1,15 +1,15 @@
|
|||||||
from typing import Dict, Optional, NamedTuple
|
import typing
|
||||||
|
from typing import Dict, Optional, NamedTuple, Any
|
||||||
|
|
||||||
from mypy.semanal import SemanticAnalyzerPass2
|
|
||||||
from mypy.types import Type
|
|
||||||
from mypy.nodes import SymbolTableNode, Var, Expression
|
from mypy.nodes import SymbolTableNode, Var, Expression
|
||||||
from mypy.plugin import FunctionContext
|
from mypy.plugin import FunctionContext
|
||||||
from mypy.types import Instance, UnionType, NoneTyp
|
from mypy.types import Type, Instance, UnionType, NoneTyp
|
||||||
|
|
||||||
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
|
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
|
||||||
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
|
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
|
||||||
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
|
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
|
||||||
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
|
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
|
||||||
|
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
|
||||||
|
|
||||||
|
|
||||||
def create_new_symtable_node(name: str, kind: int, instance: Instance) -> SymbolTableNode:
|
def create_new_symtable_node(name: str, kind: int, instance: Instance) -> SymbolTableNode:
|
||||||
@@ -26,12 +26,10 @@ Argument = NamedTuple('Argument', fields=[
|
|||||||
|
|
||||||
|
|
||||||
def get_call_signature_or_none(ctx: FunctionContext) -> Optional[Dict[str, Argument]]:
|
def get_call_signature_or_none(ctx: FunctionContext) -> Optional[Dict[str, Argument]]:
|
||||||
arg_names = ctx.context.arg_names
|
|
||||||
|
|
||||||
result: Dict[str, Argument] = {}
|
result: Dict[str, Argument] = {}
|
||||||
positional_args_only = []
|
positional_args_only = []
|
||||||
positional_arg_types_only = []
|
positional_arg_types_only = []
|
||||||
for arg, arg_name, arg_type in zip(ctx.args, arg_names, ctx.arg_types):
|
for arg, arg_name, arg_type in zip(ctx.args, ctx.arg_names, ctx.arg_types):
|
||||||
if arg_name is None:
|
if arg_name is None:
|
||||||
positional_args_only.append(arg)
|
positional_args_only.append(arg)
|
||||||
positional_arg_types_only.append(arg_type)
|
positional_arg_types_only.append(arg_type)
|
||||||
@@ -65,3 +63,7 @@ def make_required(typ: Type) -> Type:
|
|||||||
return typ
|
return typ
|
||||||
items = [item for item in typ.items if not isinstance(item, NoneTyp)]
|
items = [item for item in typ.items if not isinstance(item, NoneTyp)]
|
||||||
return UnionType.make_union(items)
|
return UnionType.make_union(items)
|
||||||
|
|
||||||
|
|
||||||
|
def get_obj_type_name(typ: typing.Type) -> str:
|
||||||
|
return typ.__module__ + '.' + typ.__qualname__
|
||||||
|
|||||||
@@ -1,46 +1,79 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional, List
|
||||||
|
|
||||||
|
from django.apps.registry import Apps
|
||||||
from django.conf import Settings
|
from django.conf import Settings
|
||||||
|
from mypy import build
|
||||||
|
from mypy.build import BuildManager
|
||||||
from mypy.options import Options
|
from mypy.options import Options
|
||||||
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
|
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
|
||||||
from mypy.types import Type
|
from mypy.types import Type
|
||||||
|
|
||||||
from mypy_django_plugin import helpers
|
from mypy_django_plugin import helpers, monkeypatch
|
||||||
from mypy_django_plugin.plugins.objects_queryset import set_objects_queryset_to_model_class
|
from mypy_django_plugin.plugins.objects_queryset import set_objects_queryset_to_model_class
|
||||||
from mypy_django_plugin.plugins.postgres_fields import determine_type_of_array_field
|
from mypy_django_plugin.plugins.postgres_fields import determine_type_of_array_field
|
||||||
from mypy_django_plugin.plugins.related_fields import set_related_name_instance_for_onetoonefield, \
|
from mypy_django_plugin.plugins.related_fields import OneToOneFieldHook, \
|
||||||
set_related_name_manager_for_foreign_key, set_fieldname_attrs_for_related_fields
|
ForeignKeyHook, set_fieldname_attrs_for_related_fields
|
||||||
from mypy_django_plugin.plugins.setup_settings import DjangoConfSettingsInitializerHook
|
from mypy_django_plugin.plugins.setup_settings import DjangoConfSettingsInitializerHook
|
||||||
|
|
||||||
|
|
||||||
base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
|
base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
|
||||||
|
|
||||||
|
|
||||||
def transform_model_class(ctx: ClassDefContext) -> None:
|
class TransformModelClassHook(object):
|
||||||
|
def __init__(self, settings: Settings, apps: Apps):
|
||||||
|
self.settings = settings
|
||||||
|
self.apps = apps
|
||||||
|
|
||||||
|
def __call__(self, ctx: ClassDefContext) -> None:
|
||||||
base_model_classes.add(ctx.cls.fullname)
|
base_model_classes.add(ctx.cls.fullname)
|
||||||
|
|
||||||
set_fieldname_attrs_for_related_fields(ctx)
|
set_fieldname_attrs_for_related_fields(ctx)
|
||||||
set_objects_queryset_to_model_class(ctx)
|
set_objects_queryset_to_model_class(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
def always_return_none(manager: BuildManager):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
build.read_plugins_snapshot = always_return_none
|
||||||
|
|
||||||
|
|
||||||
class DjangoPlugin(Plugin):
|
class DjangoPlugin(Plugin):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
options: Options) -> None:
|
options: Options) -> None:
|
||||||
super().__init__(options)
|
super().__init__(options)
|
||||||
self.django_settings = None
|
self.django_settings = None
|
||||||
|
self.apps = None
|
||||||
|
|
||||||
|
monkeypatch.replace_apply_function_plugin_method()
|
||||||
|
|
||||||
django_settings_module = os.environ.get('DJANGO_SETTINGS_MODULE')
|
django_settings_module = os.environ.get('DJANGO_SETTINGS_MODULE')
|
||||||
if django_settings_module:
|
if django_settings_module:
|
||||||
self.django_settings = Settings(django_settings_module)
|
self.django_settings = Settings(django_settings_module)
|
||||||
|
# import django
|
||||||
|
# django.setup()
|
||||||
|
#
|
||||||
|
# from django.apps import apps
|
||||||
|
# self.apps = apps
|
||||||
|
#
|
||||||
|
# models_modules = []
|
||||||
|
# for app_config in self.apps.app_configs.values():
|
||||||
|
# models_modules.append(app_config.module.__name__ + '.' + 'models')
|
||||||
|
#
|
||||||
|
# monkeypatch.state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(django_settings_module,
|
||||||
|
# models_modules)
|
||||||
|
monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(django_settings_module)
|
||||||
|
|
||||||
def get_function_hook(self, fullname: str
|
def get_function_hook(self, fullname: str
|
||||||
) -> Optional[Callable[[FunctionContext], Type]]:
|
) -> Optional[Callable[[FunctionContext], Type]]:
|
||||||
if fullname == helpers.FOREIGN_KEY_FULLNAME:
|
if fullname == helpers.FOREIGN_KEY_FULLNAME:
|
||||||
return set_related_name_manager_for_foreign_key
|
return ForeignKeyHook(settings=self.django_settings,
|
||||||
|
apps=self.apps)
|
||||||
|
|
||||||
if fullname == helpers.ONETOONE_FIELD_FULLNAME:
|
if fullname == helpers.ONETOONE_FIELD_FULLNAME:
|
||||||
return set_related_name_instance_for_onetoonefield
|
return OneToOneFieldHook(settings=self.django_settings,
|
||||||
|
apps=self.apps)
|
||||||
|
|
||||||
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
|
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
|
||||||
return determine_type_of_array_field
|
return determine_type_of_array_field
|
||||||
@@ -49,9 +82,11 @@ class DjangoPlugin(Plugin):
|
|||||||
def get_base_class_hook(self, fullname: str
|
def get_base_class_hook(self, fullname: str
|
||||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||||
if fullname in base_model_classes:
|
if fullname in base_model_classes:
|
||||||
return transform_model_class
|
return TransformModelClassHook(self.django_settings, self.apps)
|
||||||
if fullname == 'django.conf._DjangoConfLazyObject':
|
|
||||||
|
if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS:
|
||||||
return DjangoConfSettingsInitializerHook(settings=self.django_settings)
|
return DjangoConfSettingsInitializerHook(settings=self.django_settings)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
112
mypy_django_plugin/monkeypatch.py
Normal file
112
mypy_django_plugin/monkeypatch.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
from typing import Optional, List, Sequence
|
||||||
|
|
||||||
|
from mypy.build import BuildManager, Graph, State
|
||||||
|
from mypy.modulefinder import BuildSource
|
||||||
|
from mypy.nodes import Expression, Context
|
||||||
|
from mypy.plugin import FunctionContext, MethodContext
|
||||||
|
from mypy.types import Type, CallableType, Instance
|
||||||
|
|
||||||
|
|
||||||
|
def state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(settings_module: str,
|
||||||
|
models_py_modules: List[str]):
|
||||||
|
from mypy.build import State
|
||||||
|
|
||||||
|
old_compute_dependencies = State.compute_dependencies
|
||||||
|
|
||||||
|
def patched_compute_dependencies(self: State):
|
||||||
|
old_compute_dependencies(self)
|
||||||
|
if self.id == settings_module:
|
||||||
|
self.dependencies.extend(models_py_modules)
|
||||||
|
|
||||||
|
State.compute_dependencies = patched_compute_dependencies
|
||||||
|
|
||||||
|
|
||||||
|
def load_graph_to_add_settings_file_as_a_source_seed(settings_module: str):
|
||||||
|
from mypy import build
|
||||||
|
|
||||||
|
old_load_graph = build.load_graph
|
||||||
|
|
||||||
|
def patched_load_graph(sources: List[BuildSource], manager: BuildManager,
|
||||||
|
old_graph: Optional[Graph] = None,
|
||||||
|
new_modules: Optional[List[State]] = None):
|
||||||
|
if all([source.module != settings_module for source in sources]):
|
||||||
|
sources.append(BuildSource(None, settings_module, None))
|
||||||
|
|
||||||
|
return old_load_graph(sources=sources, manager=manager,
|
||||||
|
old_graph=old_graph,
|
||||||
|
new_modules=new_modules)
|
||||||
|
|
||||||
|
build.load_graph = patched_load_graph
|
||||||
|
|
||||||
|
|
||||||
|
def replace_apply_function_plugin_method():
|
||||||
|
def apply_function_plugin(self,
|
||||||
|
arg_types: List[Type],
|
||||||
|
inferred_ret_type: Type,
|
||||||
|
arg_names: Optional[Sequence[Optional[str]]],
|
||||||
|
formal_to_actual: List[List[int]],
|
||||||
|
args: List[Expression],
|
||||||
|
num_formals: int,
|
||||||
|
fullname: str,
|
||||||
|
object_type: Optional[Type],
|
||||||
|
context: Context) -> Type:
|
||||||
|
"""Use special case logic to infer the return type of a specific named function/method.
|
||||||
|
|
||||||
|
Caller must ensure that a plugin hook exists. There are two different cases:
|
||||||
|
|
||||||
|
- If object_type is None, the caller must ensure that a function hook exists
|
||||||
|
for fullname.
|
||||||
|
- If object_type is not None, the caller must ensure that a method hook exists
|
||||||
|
for fullname.
|
||||||
|
|
||||||
|
Return the inferred return type.
|
||||||
|
"""
|
||||||
|
formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]]
|
||||||
|
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
|
||||||
|
formal_arg_names = [None for _ in range(num_formals)] # type: List[Optional[str]]
|
||||||
|
for formal, actuals in enumerate(formal_to_actual):
|
||||||
|
for actual in actuals:
|
||||||
|
formal_arg_types[formal].append(arg_types[actual])
|
||||||
|
formal_arg_exprs[formal].append(args[actual])
|
||||||
|
if arg_names:
|
||||||
|
formal_arg_names[formal] = arg_names[actual]
|
||||||
|
|
||||||
|
num_passed_positionals = sum([1 if name is None else 0
|
||||||
|
for name in formal_arg_names])
|
||||||
|
if arg_names and num_passed_positionals > 0:
|
||||||
|
object_type_info = None
|
||||||
|
if object_type is not None:
|
||||||
|
if isinstance(object_type, CallableType):
|
||||||
|
# class object, convert to corresponding Instance
|
||||||
|
object_type = object_type.ret_type
|
||||||
|
if isinstance(object_type, Instance):
|
||||||
|
# skip TypedDictType and others
|
||||||
|
object_type_info = object_type.type
|
||||||
|
|
||||||
|
defn_arg_names = self._get_defn_arg_names(fullname, object_type=object_type_info)
|
||||||
|
if defn_arg_names:
|
||||||
|
if num_formals < len(defn_arg_names):
|
||||||
|
# self/cls argument has been passed implicitly
|
||||||
|
defn_arg_names = defn_arg_names[1:]
|
||||||
|
formal_arg_names[:num_passed_positionals] = defn_arg_names[:num_passed_positionals]
|
||||||
|
|
||||||
|
if object_type is None:
|
||||||
|
# Apply function plugin
|
||||||
|
callback = self.plugin.get_function_hook(fullname)
|
||||||
|
assert callback is not None # Assume that caller ensures this
|
||||||
|
return callback(
|
||||||
|
FunctionContext(formal_arg_names, formal_arg_types,
|
||||||
|
inferred_ret_type, formal_arg_exprs,
|
||||||
|
context, self.chk))
|
||||||
|
else:
|
||||||
|
# Apply method plugin
|
||||||
|
method_callback = self.plugin.get_method_hook(fullname)
|
||||||
|
assert method_callback is not None # Assume that caller ensures this
|
||||||
|
return method_callback(
|
||||||
|
MethodContext(object_type, formal_arg_names, formal_arg_types,
|
||||||
|
inferred_ret_type, formal_arg_exprs,
|
||||||
|
context, self.chk))
|
||||||
|
|
||||||
|
from mypy.checkexpr import ExpressionChecker
|
||||||
|
ExpressionChecker.apply_function_plugin = apply_function_plugin
|
||||||
|
|
||||||
@@ -9,19 +9,28 @@ from mypy_django_plugin import helpers
|
|||||||
|
|
||||||
|
|
||||||
def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None:
|
def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None:
|
||||||
if 'objects' in ctx.cls.info.names:
|
# search over mro
|
||||||
return
|
objects_sym = ctx.cls.info.get('objects')
|
||||||
api = cast(SemanticAnalyzerPass2, ctx.api)
|
if objects_sym is not None:
|
||||||
|
return None
|
||||||
|
|
||||||
metaclass_node = ctx.cls.info.names.get('Meta')
|
# only direct Meta class
|
||||||
if metaclass_node is not None:
|
metaclass_sym = ctx.cls.info.names.get('Meta')
|
||||||
for stmt in metaclass_node.node.defn.defs.body:
|
# skip if abstract
|
||||||
|
if metaclass_sym is not None:
|
||||||
|
for stmt in metaclass_sym.node.defn.defs.body:
|
||||||
if (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1
|
if (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1
|
||||||
and stmt.lvalues[0].name == 'abstract'):
|
and stmt.lvalues[0].name == 'abstract'):
|
||||||
is_abstract = api.parse_bool(stmt.rvalue)
|
is_abstract = ctx.api.parse_bool(stmt.rvalue)
|
||||||
if is_abstract:
|
if is_abstract:
|
||||||
return
|
return None
|
||||||
|
|
||||||
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, args=[Instance(ctx.cls.info, [])])
|
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||||
new_objects_node = helpers.create_new_symtable_node('objects', MDEF, instance=typ)
|
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME,
|
||||||
ctx.cls.info.names['objects'] = new_objects_node
|
args=[Instance(ctx.cls.info, [])])
|
||||||
|
if not typ:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ctx.cls.info.names['objects'] = helpers.create_new_symtable_node('objects',
|
||||||
|
kind=MDEF,
|
||||||
|
instance=typ)
|
||||||
|
|||||||
@@ -1,14 +1,11 @@
|
|||||||
from mypy.plugin import FunctionContext
|
from mypy.plugin import FunctionContext
|
||||||
from mypy.types import Type
|
from mypy.types import Type
|
||||||
|
|
||||||
from mypy_django_plugin import helpers
|
|
||||||
|
|
||||||
|
|
||||||
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
|
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
|
||||||
signature = helpers.get_call_signature_or_none(ctx)
|
if 'base_field' not in ctx.arg_names:
|
||||||
if signature is None:
|
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
_, base_field_arg_type = signature['base_field']
|
base_field_arg_type = ctx.arg_types[ctx.arg_names.index('base_field')][0]
|
||||||
return ctx.api.named_generic_type(ctx.context.callee.fullname,
|
return ctx.api.named_generic_type(ctx.context.callee.fullname,
|
||||||
args=[base_field_arg_type.type.names['__get__'].type.ret_type])
|
args=[base_field_arg_type.type.names['__get__'].type.ret_type])
|
||||||
|
|||||||
@@ -1,23 +1,63 @@
|
|||||||
from typing import Optional, cast
|
import typing
|
||||||
|
from typing import Optional, cast, Tuple, Any
|
||||||
|
|
||||||
|
from django.apps.registry import Apps
|
||||||
|
from django.conf import Settings
|
||||||
|
from django.db import models
|
||||||
from mypy.checker import TypeChecker
|
from mypy.checker import TypeChecker
|
||||||
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, MemberExpr
|
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, StrExpr
|
||||||
from mypy.plugin import FunctionContext, ClassDefContext
|
from mypy.plugin import FunctionContext, ClassDefContext
|
||||||
from mypy.types import Type, CallableType, Instance, AnyType
|
from mypy.types import Type, CallableType, Instance, AnyType
|
||||||
|
|
||||||
from mypy_django_plugin import helpers
|
from mypy_django_plugin import helpers
|
||||||
|
|
||||||
|
|
||||||
def extract_to_value_type(ctx: FunctionContext) -> Optional[Instance]:
|
def get_instance_type_for_class(klass: typing.Type[models.Model],
|
||||||
signature = helpers.get_call_signature_or_none(ctx)
|
api: TypeChecker) -> Optional[Instance]:
|
||||||
if signature is None or 'to' not in signature:
|
model_qualname = helpers.get_obj_type_name(klass)
|
||||||
return None
|
module_name, _, class_name = model_qualname.rpartition('.')
|
||||||
|
module = api.modules.get(module_name)
|
||||||
|
if not module or class_name not in module.names:
|
||||||
|
return
|
||||||
|
|
||||||
arg, arg_type = signature['to']
|
sym = module.names[class_name]
|
||||||
if not isinstance(arg_type, CallableType):
|
return Instance(sym.node, [])
|
||||||
return None
|
|
||||||
|
|
||||||
return arg_type.ret_type
|
|
||||||
|
def extract_to_value_type(ctx: FunctionContext,
|
||||||
|
apps: Optional[Apps]) -> Tuple[Optional[Instance], bool]:
|
||||||
|
api = cast(TypeChecker, ctx.api)
|
||||||
|
|
||||||
|
if 'to' not in ctx.arg_names:
|
||||||
|
return None, False
|
||||||
|
arg = ctx.args[ctx.arg_names.index('to')][0]
|
||||||
|
arg_type = ctx.arg_types[ctx.arg_names.index('to')][0]
|
||||||
|
|
||||||
|
if isinstance(arg_type, CallableType):
|
||||||
|
return arg_type.ret_type, False
|
||||||
|
|
||||||
|
if apps:
|
||||||
|
if isinstance(arg, StrExpr):
|
||||||
|
arg_value = arg.value
|
||||||
|
if '.' not in arg_value:
|
||||||
|
return None, False
|
||||||
|
|
||||||
|
app_label, modelname = arg_value.lower().split('.')
|
||||||
|
try:
|
||||||
|
model_cls = apps.get_model(app_label, modelname)
|
||||||
|
except LookupError:
|
||||||
|
# no model class found
|
||||||
|
return None, False
|
||||||
|
try:
|
||||||
|
instance = get_instance_type_for_class(model_cls, api=api)
|
||||||
|
if not instance:
|
||||||
|
return None, False
|
||||||
|
return instance, True
|
||||||
|
|
||||||
|
except AssertionError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None, False
|
||||||
|
|
||||||
|
|
||||||
def extract_related_name_value(ctx: FunctionContext) -> str:
|
def extract_related_name_value(ctx: FunctionContext) -> str:
|
||||||
@@ -30,43 +70,56 @@ def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instanc
|
|||||||
instance=new_member_instance)
|
instance=new_member_instance)
|
||||||
|
|
||||||
|
|
||||||
def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
|
class ForeignKeyHook(object):
|
||||||
|
def __init__(self, settings: Settings, apps: Apps):
|
||||||
|
self.settings = settings
|
||||||
|
self.apps = apps
|
||||||
|
|
||||||
|
def __call__(self, ctx: FunctionContext) -> Type:
|
||||||
api = cast(TypeChecker, ctx.api)
|
api = cast(TypeChecker, ctx.api)
|
||||||
outer_class_info = api.tscope.classes[-1]
|
outer_class_info = api.tscope.classes[-1]
|
||||||
|
|
||||||
if 'related_name' not in ctx.context.arg_names:
|
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps)
|
||||||
return ctx.default_return_type
|
|
||||||
|
|
||||||
referred_to = extract_to_value_type(ctx)
|
|
||||||
if not referred_to:
|
if not referred_to:
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
if 'related_name' in ctx.context.arg_names:
|
||||||
related_name = extract_related_name_value(ctx)
|
related_name = extract_related_name_value(ctx)
|
||||||
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
|
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
|
||||||
args=[Instance(outer_class_info, [])])
|
args=[Instance(outer_class_info, [])])
|
||||||
if isinstance(referred_to, AnyType):
|
if isinstance(referred_to, AnyType):
|
||||||
# referred_to defined as string, which is unsupported for now
|
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
add_new_class_member(referred_to.type,
|
add_new_class_member(referred_to.type,
|
||||||
related_name, queryset_type)
|
related_name, queryset_type)
|
||||||
|
if is_string_based:
|
||||||
|
return referred_to
|
||||||
|
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
|
||||||
def set_related_name_instance_for_onetoonefield(ctx: FunctionContext) -> Type:
|
class OneToOneFieldHook(object):
|
||||||
|
def __init__(self, settings: Optional[Settings], apps: Optional[Apps]):
|
||||||
|
self.settings = settings
|
||||||
|
self.apps = apps
|
||||||
|
|
||||||
|
def __call__(self, ctx: FunctionContext) -> Type:
|
||||||
if 'related_name' not in ctx.context.arg_names:
|
if 'related_name' not in ctx.context.arg_names:
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
referred_to = extract_to_value_type(ctx)
|
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps)
|
||||||
if referred_to is None:
|
if referred_to is None:
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
if 'related_name' in ctx.context.arg_names:
|
||||||
related_name = extract_related_name_value(ctx)
|
related_name = extract_related_name_value(ctx)
|
||||||
outer_class_info = ctx.api.tscope.classes[-1]
|
outer_class_info = ctx.api.tscope.classes[-1]
|
||||||
|
|
||||||
api = cast(TypeChecker, ctx.api)
|
|
||||||
add_new_class_member(referred_to.type, related_name,
|
add_new_class_member(referred_to.type, related_name,
|
||||||
new_member_instance=api.named_type(outer_class_info.fullname()))
|
new_member_instance=Instance(outer_class_info, []))
|
||||||
|
|
||||||
|
if is_string_based:
|
||||||
|
return referred_to
|
||||||
|
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from typing import cast, Any
|
from typing import cast
|
||||||
|
|
||||||
from django.conf import Settings
|
from django.conf import Settings
|
||||||
from mypy.nodes import MDEF, TypeInfo, SymbolTable
|
from mypy.nodes import MDEF
|
||||||
from mypy.plugin import ClassDefContext
|
from mypy.plugin import ClassDefContext
|
||||||
from mypy.semanal import SemanticAnalyzerPass2
|
from mypy.semanal import SemanticAnalyzerPass2
|
||||||
from mypy.types import Instance, AnyType, TypeOfAny
|
from mypy.types import Instance, AnyType, TypeOfAny
|
||||||
@@ -9,24 +9,22 @@ from mypy.types import Instance, AnyType, TypeOfAny
|
|||||||
from mypy_django_plugin import helpers
|
from mypy_django_plugin import helpers
|
||||||
|
|
||||||
|
|
||||||
def get_obj_type_name(value: Any) -> str:
|
|
||||||
return type(value).__module__ + '.' + type(value).__qualname__
|
|
||||||
|
|
||||||
|
|
||||||
class DjangoConfSettingsInitializerHook(object):
|
class DjangoConfSettingsInitializerHook(object):
|
||||||
def __init__(self, settings: Settings):
|
def __init__(self, settings: Settings):
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
|
|
||||||
def __call__(self, ctx: ClassDefContext) -> None:
|
def __call__(self, ctx: ClassDefContext) -> None:
|
||||||
api = cast(SemanticAnalyzerPass2, ctx.api)
|
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||||
|
if self.settings:
|
||||||
for name, value in self.settings.__dict__.items():
|
for name, value in self.settings.__dict__.items():
|
||||||
if name.isupper():
|
if name.isupper():
|
||||||
if value is None:
|
if value is None:
|
||||||
|
# TODO: change to Optional[Any] later
|
||||||
ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF,
|
ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF,
|
||||||
instance=api.builtin_type('builtins.object'))
|
instance=api.builtin_type('builtins.object'))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
type_fullname = get_obj_type_name(value)
|
type_fullname = helpers.get_obj_type_name(type(value))
|
||||||
sym = api.lookup_fully_qualified_or_none(type_fullname)
|
sym = api.lookup_fully_qualified_or_none(type_fullname)
|
||||||
if sym is not None:
|
if sym is not None:
|
||||||
args = len(sym.node.type_vars) * [AnyType(TypeOfAny.from_omitted_generics)]
|
args = len(sym.node.type_vars) * [AnyType(TypeOfAny.from_omitted_generics)]
|
||||||
|
|||||||
@@ -5,3 +5,4 @@ python_files = test*.py
|
|||||||
addopts =
|
addopts =
|
||||||
--tb=native
|
--tb=native
|
||||||
-v
|
-v
|
||||||
|
-s
|
||||||
@@ -147,6 +147,12 @@ class DjangoDataDrivenTestCase(DataDrivenTestCase):
|
|||||||
self.old_cwd = os.getcwd()
|
self.old_cwd = os.getcwd()
|
||||||
|
|
||||||
self.tmpdir = tempfile.TemporaryDirectory(prefix='mypy-test-')
|
self.tmpdir = tempfile.TemporaryDirectory(prefix='mypy-test-')
|
||||||
|
tmpdir_root = os.path.join(self.tmpdir.name, 'tmp')
|
||||||
|
|
||||||
|
new_files = []
|
||||||
|
for path, contents in self.files:
|
||||||
|
new_files.append((path, contents.replace('<TMP>', tmpdir_root)))
|
||||||
|
self.files = new_files
|
||||||
|
|
||||||
os.chdir(self.tmpdir.name)
|
os.chdir(self.tmpdir.name)
|
||||||
os.mkdir(test_temp_dir)
|
os.mkdir(test_temp_dir)
|
||||||
@@ -179,7 +185,7 @@ class DjangoDataDrivenTestCase(DataDrivenTestCase):
|
|||||||
self.clean_up.append((True, d))
|
self.clean_up.append((True, d))
|
||||||
self.clean_up.append((False, path))
|
self.clean_up.append((False, path))
|
||||||
|
|
||||||
sys.path.insert(0, os.path.join(self.tmpdir.name, 'tmp'))
|
sys.path.insert(0, tmpdir_root)
|
||||||
|
|
||||||
def teardown(self):
|
def teardown(self):
|
||||||
if hasattr(self, 'old_environ'):
|
if hasattr(self, 'old_environ'):
|
||||||
|
|||||||
253
test/helpers.py
Normal file
253
test/helpers.py
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import List, Callable, Optional, Tuple
|
||||||
|
|
||||||
|
import pytest # type: ignore # no pytest in typeshed
|
||||||
|
|
||||||
|
skip = pytest.mark.skip
|
||||||
|
|
||||||
|
# AssertStringArraysEqual displays special line alignment helper messages if
|
||||||
|
# the first different line has at least this many characters,
|
||||||
|
MIN_LINE_LENGTH_FOR_ALIGNMENT = 5
|
||||||
|
|
||||||
|
|
||||||
|
class TypecheckAssertionError(AssertionError):
|
||||||
|
def __init__(self, error_message: str, lineno: int):
|
||||||
|
self.error_message = error_message
|
||||||
|
self.lineno = lineno
|
||||||
|
|
||||||
|
def first_line(self):
|
||||||
|
return self.__class__.__name__ + '(message="Invalid output")'
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.error_message
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_up(a: List[str]) -> List[str]:
|
||||||
|
"""Remove common directory prefix from all strings in a.
|
||||||
|
|
||||||
|
This uses a naive string replace; it seems to work well enough. Also
|
||||||
|
remove trailing carriage returns.
|
||||||
|
"""
|
||||||
|
res = []
|
||||||
|
for s in a:
|
||||||
|
prefix = os.sep
|
||||||
|
ss = s
|
||||||
|
for p in prefix, prefix.replace(os.sep, '/'):
|
||||||
|
if p != '/' and p != '//' and p != '\\' and p != '\\\\':
|
||||||
|
ss = ss.replace(p, '')
|
||||||
|
# Ignore spaces at end of line.
|
||||||
|
ss = re.sub(' +$', '', ss)
|
||||||
|
res.append(re.sub('\\r$', '', ss))
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def _num_skipped_prefix_lines(a1: List[str], a2: List[str]) -> int:
|
||||||
|
num_eq = 0
|
||||||
|
while num_eq < min(len(a1), len(a2)) and a1[num_eq] == a2[num_eq]:
|
||||||
|
num_eq += 1
|
||||||
|
return max(0, num_eq - 4)
|
||||||
|
|
||||||
|
|
||||||
|
def _num_skipped_suffix_lines(a1: List[str], a2: List[str]) -> int:
|
||||||
|
num_eq = 0
|
||||||
|
while (num_eq < min(len(a1), len(a2))
|
||||||
|
and a1[-num_eq - 1] == a2[-num_eq - 1]):
|
||||||
|
num_eq += 1
|
||||||
|
return max(0, num_eq - 4)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_aligned_message(s1: str, s2: str, error_message: str) -> str:
|
||||||
|
"""Align s1 and s2 so that the their first difference is highlighted.
|
||||||
|
|
||||||
|
For example, if s1 is 'foobar' and s2 is 'fobar', display the
|
||||||
|
following lines:
|
||||||
|
|
||||||
|
E: foobar
|
||||||
|
A: fobar
|
||||||
|
^
|
||||||
|
|
||||||
|
If s1 and s2 are long, only display a fragment of the strings around the
|
||||||
|
first difference. If s1 is very short, do nothing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Seeing what went wrong is trivial even without alignment if the expected
|
||||||
|
# string is very short. In this case do nothing to simplify output.
|
||||||
|
if len(s1) < 4:
|
||||||
|
return error_message
|
||||||
|
|
||||||
|
maxw = 72 # Maximum number of characters shown
|
||||||
|
|
||||||
|
error_message += 'Alignment of first line difference:\n'
|
||||||
|
# sys.stderr.write('Alignment of first line difference:\n')
|
||||||
|
|
||||||
|
trunc = False
|
||||||
|
while s1[:30] == s2[:30]:
|
||||||
|
s1 = s1[10:]
|
||||||
|
s2 = s2[10:]
|
||||||
|
trunc = True
|
||||||
|
|
||||||
|
if trunc:
|
||||||
|
s1 = '...' + s1
|
||||||
|
s2 = '...' + s2
|
||||||
|
|
||||||
|
max_len = max(len(s1), len(s2))
|
||||||
|
extra = ''
|
||||||
|
if max_len > maxw:
|
||||||
|
extra = '...'
|
||||||
|
|
||||||
|
# Write a chunk of both lines, aligned.
|
||||||
|
error_message += ' E: {}{}\n'.format(s1[:maxw], extra)
|
||||||
|
# sys.stderr.write(' E: {}{}\n'.format(s1[:maxw], extra))
|
||||||
|
error_message += ' A: {}{}\n'.format(s2[:maxw], extra)
|
||||||
|
# sys.stderr.write(' A: {}{}\n'.format(s2[:maxw], extra))
|
||||||
|
# Write an indicator character under the different columns.
|
||||||
|
error_message += ' '
|
||||||
|
# sys.stderr.write(' ')
|
||||||
|
for j in range(min(maxw, max(len(s1), len(s2)))):
|
||||||
|
if s1[j:j + 1] != s2[j:j + 1]:
|
||||||
|
error_message += '^'
|
||||||
|
# sys.stderr.write('^') # Difference
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
error_message += ' '
|
||||||
|
# sys.stderr.write(' ') # Equal
|
||||||
|
error_message += '\n'
|
||||||
|
return error_message
|
||||||
|
# sys.stderr.write('\n')
|
||||||
|
|
||||||
|
|
||||||
|
def assert_string_arrays_equal(expected: List[str], actual: List[str]) -> None:
|
||||||
|
"""Assert that two string arrays are equal.
|
||||||
|
|
||||||
|
Display any differences in a human-readable form.
|
||||||
|
"""
|
||||||
|
|
||||||
|
actual = _clean_up(actual)
|
||||||
|
error_message = ''
|
||||||
|
|
||||||
|
if actual != expected:
|
||||||
|
num_skip_start = _num_skipped_prefix_lines(expected, actual)
|
||||||
|
num_skip_end = _num_skipped_suffix_lines(expected, actual)
|
||||||
|
|
||||||
|
error_message += 'Expected:\n'
|
||||||
|
# sys.stderr.write('Expected:\n')
|
||||||
|
|
||||||
|
# If omit some lines at the beginning, indicate it by displaying a line
|
||||||
|
# with '...'.
|
||||||
|
if num_skip_start > 0:
|
||||||
|
error_message += ' ...\n'
|
||||||
|
# sys.stderr.write(' ...\n')
|
||||||
|
|
||||||
|
# Keep track of the first different line.
|
||||||
|
first_diff = -1
|
||||||
|
|
||||||
|
# Display only this many first characters of identical lines.
|
||||||
|
width = 75
|
||||||
|
|
||||||
|
for i in range(num_skip_start, len(expected) - num_skip_end):
|
||||||
|
if i >= len(actual) or expected[i] != actual[i]:
|
||||||
|
if first_diff < 0:
|
||||||
|
first_diff = i
|
||||||
|
error_message += ' {:<45} (diff)'.format(expected[i])
|
||||||
|
# sys.stderr.write(' {:<45} (diff)'.format(expected[i]))
|
||||||
|
else:
|
||||||
|
e = expected[i]
|
||||||
|
error_message += ' ' + e[:width]
|
||||||
|
# sys.stderr.write(' ' + e[:width])
|
||||||
|
if len(e) > width:
|
||||||
|
error_message += '...'
|
||||||
|
# sys.stderr.write('...')
|
||||||
|
error_message += '\n'
|
||||||
|
# sys.stderr.write('\n')
|
||||||
|
if num_skip_end > 0:
|
||||||
|
error_message += ' ...\n'
|
||||||
|
# sys.stderr.write(' ...\n')
|
||||||
|
|
||||||
|
error_message += 'Actual:\n'
|
||||||
|
# sys.stderr.write('Actual:\n')
|
||||||
|
|
||||||
|
if num_skip_start > 0:
|
||||||
|
error_message += ' ...\n'
|
||||||
|
# sys.stderr.write(' ...\n')
|
||||||
|
|
||||||
|
for j in range(num_skip_start, len(actual) - num_skip_end):
|
||||||
|
if j >= len(expected) or expected[j] != actual[j]:
|
||||||
|
error_message += ' {:<45} (diff)'.format(actual[j])
|
||||||
|
# sys.stderr.write(' {:<45} (diff)'.format(actual[j]))
|
||||||
|
else:
|
||||||
|
a = actual[j]
|
||||||
|
error_message += ' ' + a[:width]
|
||||||
|
# sys.stderr.write(' ' + a[:width])
|
||||||
|
if len(a) > width:
|
||||||
|
error_message += '...'
|
||||||
|
# sys.stderr.write('...')
|
||||||
|
error_message += '\n'
|
||||||
|
# sys.stderr.write('\n')
|
||||||
|
if actual == []:
|
||||||
|
error_message += ' (empty)\n'
|
||||||
|
# sys.stderr.write(' (empty)\n')
|
||||||
|
if num_skip_end > 0:
|
||||||
|
error_message += ' ...\n'
|
||||||
|
# sys.stderr.write(' ...\n')
|
||||||
|
|
||||||
|
error_message += '\n'
|
||||||
|
# sys.stderr.write('\n')
|
||||||
|
|
||||||
|
if first_diff >= 0 and first_diff < len(actual) and (
|
||||||
|
len(expected[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT
|
||||||
|
or len(actual[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT):
|
||||||
|
# Display message that helps visualize the differences between two
|
||||||
|
# long lines.
|
||||||
|
error_message = _add_aligned_message(expected[first_diff], actual[first_diff],
|
||||||
|
error_message)
|
||||||
|
|
||||||
|
first_failure = expected[first_diff]
|
||||||
|
if first_failure:
|
||||||
|
lineno = int(first_failure.split(' ')[0].strip(':').split(':')[1])
|
||||||
|
raise TypecheckAssertionError(error_message=f'Invalid output: \n{error_message}',
|
||||||
|
lineno=lineno)
|
||||||
|
|
||||||
|
|
||||||
|
def build_output_line(fname: str, lnum: int, severity: str, message: str, col=None) -> str:
|
||||||
|
if col is None:
|
||||||
|
return f'{fname}:{lnum + 1}: {severity}: {message}'
|
||||||
|
else:
|
||||||
|
return f'{fname}:{lnum + 1}:{col}: {severity}: {message}'
|
||||||
|
|
||||||
|
|
||||||
|
def expand_errors(input_lines: List[str], fname: str) -> List[str]:
|
||||||
|
"""Transform comments such as '# E: message' or
|
||||||
|
'# E:3: message' in input.
|
||||||
|
|
||||||
|
The result is lines like 'fnam:line: error: message'.
|
||||||
|
"""
|
||||||
|
output_lines = []
|
||||||
|
for lnum, line in enumerate(input_lines):
|
||||||
|
# The first in the split things isn't a comment
|
||||||
|
for possible_err_comment in line.split(' # ')[1:]:
|
||||||
|
m = re.search(
|
||||||
|
r'^([ENW]):((?P<col>\d+):)? (?P<message>.*)$',
|
||||||
|
possible_err_comment.strip())
|
||||||
|
if m:
|
||||||
|
if m.group(1) == 'E':
|
||||||
|
severity = 'error'
|
||||||
|
elif m.group(1) == 'N':
|
||||||
|
severity = 'note'
|
||||||
|
elif m.group(1) == 'W':
|
||||||
|
severity = 'warning'
|
||||||
|
col = m.group('col')
|
||||||
|
output_lines.append(build_output_line(fname, lnum, severity,
|
||||||
|
message=m.group("message"),
|
||||||
|
col=col))
|
||||||
|
return output_lines
|
||||||
|
|
||||||
|
|
||||||
|
def get_func_first_lnum(attr: Callable[..., None]) -> Optional[Tuple[int, List[str]]]:
|
||||||
|
lines, _ = inspect.getsourcelines(attr)
|
||||||
|
for lnum, line in enumerate(lines):
|
||||||
|
no_space_line = line.strip()
|
||||||
|
if f'def {attr.__name__}' in no_space_line:
|
||||||
|
return lnum, lines[lnum + 1:]
|
||||||
|
raise ValueError(f'No line "def {attr.__name__}" found')
|
||||||
299
test/pytest_plugin.py
Normal file
299
test/pytest_plugin.py
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
import dataclasses
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import textwrap
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterator, Any, Optional, cast, List, Type, Callable, Dict
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from _pytest._code.code import ReprFileLocation, ReprEntry, ExceptionInfo
|
||||||
|
from decorator import decorate
|
||||||
|
from mypy import api as mypy_api
|
||||||
|
|
||||||
|
from test import vistir
|
||||||
|
from test.helpers import assert_string_arrays_equal, TypecheckAssertionError, expand_errors, get_func_first_lnum
|
||||||
|
|
||||||
|
|
||||||
|
def reveal_type(obj: Any) -> None:
|
||||||
|
# noop method, just to get rid of "method is not resolved" errors
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def output(output_lines: str):
|
||||||
|
def decor(func: Callable[..., None]):
|
||||||
|
func.out = output_lines
|
||||||
|
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorate(func, wrapper)
|
||||||
|
|
||||||
|
return decor
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_that_defined_method(meth) -> Type['MypyTypecheckTestCase']:
|
||||||
|
if inspect.ismethod(meth):
|
||||||
|
for cls in inspect.getmro(meth.__self__.__class__):
|
||||||
|
if cls.__dict__.get(meth.__name__) is meth:
|
||||||
|
return cls
|
||||||
|
meth = meth.__func__ # fallback to __qualname__ parsing
|
||||||
|
if inspect.isfunction(meth):
|
||||||
|
cls = getattr(inspect.getmodule(meth),
|
||||||
|
meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0])
|
||||||
|
if issubclass(cls, MypyTypecheckTestCase):
|
||||||
|
return cls
|
||||||
|
return getattr(meth, '__objclass__', None) # handle special descriptor objects
|
||||||
|
|
||||||
|
|
||||||
|
def file(filename: str, make_parent_packages=False):
|
||||||
|
def decor(func: Callable[..., None]):
|
||||||
|
func.filename = filename
|
||||||
|
func.make_parent_packages = make_parent_packages
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decor
|
||||||
|
|
||||||
|
|
||||||
|
def env(**environ):
|
||||||
|
def decor(func: Callable[..., None]):
|
||||||
|
func.env = environ
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CreateFile:
|
||||||
|
sources: str
|
||||||
|
make_parent_packages: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class MypyTypecheckMeta(type):
|
||||||
|
def __new__(mcs, name, bases, attrs):
|
||||||
|
cls = super().__new__(mcs, name, bases, attrs)
|
||||||
|
cls.files: Dict[str, CreateFile] = {}
|
||||||
|
|
||||||
|
for name, attr in attrs.items():
|
||||||
|
if inspect.isfunction(attr):
|
||||||
|
filename = getattr(attr, 'filename', None)
|
||||||
|
if not filename:
|
||||||
|
continue
|
||||||
|
make_parent_packages = getattr(attr, 'make_parent_packages', False)
|
||||||
|
sources = textwrap.dedent(''.join(get_func_first_lnum(attr)[1]))
|
||||||
|
if sources.strip() == 'pass':
|
||||||
|
sources = ''
|
||||||
|
cls.files[filename] = CreateFile(sources, make_parent_packages)
|
||||||
|
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
class MypyTypecheckTestCase(metaclass=MypyTypecheckMeta):
|
||||||
|
files = None
|
||||||
|
|
||||||
|
def ini_file(self) -> str:
|
||||||
|
return """
|
||||||
|
[mypy]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_ini_file_contents(self) -> Optional[str]:
|
||||||
|
raw_ini_file = self.ini_file()
|
||||||
|
if not raw_ini_file:
|
||||||
|
return raw_ini_file
|
||||||
|
return raw_ini_file.strip() + '\n'
|
||||||
|
|
||||||
|
|
||||||
|
class TraceLastReprEntry(ReprEntry):
|
||||||
|
def toterminal(self, tw):
|
||||||
|
self.reprfileloc.toterminal(tw)
|
||||||
|
for line in self.lines:
|
||||||
|
red = line.startswith("E ")
|
||||||
|
tw.line(line, bold=True, red=red)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def fname_to_module(fpath: Path, root_path: Path) -> Optional[str]:
|
||||||
|
try:
|
||||||
|
relpath = fpath.relative_to(root_path).with_suffix('')
|
||||||
|
return str(relpath).replace(os.sep, '.')
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class MypyTypecheckItem(pytest.Item):
|
||||||
|
root_directory = '/run/testdata'
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
name: str,
|
||||||
|
parent: 'MypyTestsCollector',
|
||||||
|
klass: Type[MypyTypecheckTestCase],
|
||||||
|
source_code: str,
|
||||||
|
first_lineno: int,
|
||||||
|
ini_file_contents: Optional[str] = None,
|
||||||
|
expected_output_lines: Optional[List[str]] = None,
|
||||||
|
files: Optional[Dict[str, CreateFile]] = None,
|
||||||
|
custom_environment: Optional[Dict[str, Any]] = None):
|
||||||
|
super().__init__(name=name, parent=parent)
|
||||||
|
self.klass = klass
|
||||||
|
self.source_code = source_code
|
||||||
|
self.first_lineno = first_lineno
|
||||||
|
self.ini_file_contents = ini_file_contents
|
||||||
|
self.expected_output_lines = expected_output_lines
|
||||||
|
self.files = files
|
||||||
|
self.custom_environment = custom_environment
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def temp_directory(self) -> Path:
|
||||||
|
with tempfile.TemporaryDirectory(prefix='mypy-pytest-',
|
||||||
|
dir=self.root_directory) as tmpdir_name:
|
||||||
|
yield Path(self.root_directory) / tmpdir_name
|
||||||
|
|
||||||
|
def runtest(self):
|
||||||
|
with self.temp_directory() as tmpdir_path:
|
||||||
|
if not self.source_code:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.ini_file_contents:
|
||||||
|
mypy_ini_fpath = tmpdir_path / 'mypy.ini'
|
||||||
|
mypy_ini_fpath.write_text(self.ini_file_contents)
|
||||||
|
|
||||||
|
test_specific_modules = []
|
||||||
|
for fname, create_file in self.files.items():
|
||||||
|
fpath = tmpdir_path / fname
|
||||||
|
if create_file.make_parent_packages:
|
||||||
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
for parent in fpath.parents:
|
||||||
|
try:
|
||||||
|
parent.relative_to(tmpdir_path)
|
||||||
|
if parent != tmpdir_path:
|
||||||
|
parent_init_file = parent / '__init__.py'
|
||||||
|
parent_init_file.write_text('')
|
||||||
|
test_specific_modules.append(fname_to_module(parent,
|
||||||
|
root_path=tmpdir_path))
|
||||||
|
except ValueError:
|
||||||
|
break
|
||||||
|
|
||||||
|
fpath.write_text(create_file.sources)
|
||||||
|
test_specific_modules.append(fname_to_module(fpath,
|
||||||
|
root_path=tmpdir_path))
|
||||||
|
|
||||||
|
with vistir.temp_environ(), vistir.temp_path():
|
||||||
|
for key, val in (self.custom_environment or {}).items():
|
||||||
|
os.environ[key] = val
|
||||||
|
sys.path.insert(0, str(tmpdir_path))
|
||||||
|
|
||||||
|
mypy_cmd_options = self.prepare_mypy_cmd_options(config_file_path=mypy_ini_fpath)
|
||||||
|
main_fpath = tmpdir_path / 'main.py'
|
||||||
|
main_fpath.write_text(self.source_code)
|
||||||
|
mypy_cmd_options.append(str(main_fpath))
|
||||||
|
|
||||||
|
stdout, _, _ = mypy_api.run(mypy_cmd_options)
|
||||||
|
output_lines = []
|
||||||
|
for line in stdout.splitlines():
|
||||||
|
if ':' not in line:
|
||||||
|
continue
|
||||||
|
out_fpath, res_line = line.split(':', 1)
|
||||||
|
line = os.path.relpath(out_fpath, start=tmpdir_path) + ':' + res_line
|
||||||
|
output_lines.append(line.strip().replace('.py', ''))
|
||||||
|
|
||||||
|
for module in test_specific_modules:
|
||||||
|
if module in sys.modules:
|
||||||
|
del sys.modules[module]
|
||||||
|
raise ValueError
|
||||||
|
assert_string_arrays_equal(expected=self.expected_output_lines,
|
||||||
|
actual=output_lines)
|
||||||
|
|
||||||
|
def prepare_mypy_cmd_options(self, config_file_path: Path) -> List[str]:
|
||||||
|
mypy_cmd_options = [
|
||||||
|
'--show-traceback',
|
||||||
|
'--no-silence-site-packages'
|
||||||
|
]
|
||||||
|
python_version = '.'.join([str(part) for part in sys.version_info[:2]])
|
||||||
|
mypy_cmd_options.append(f'--python-version={python_version}')
|
||||||
|
if self.ini_file_contents:
|
||||||
|
mypy_cmd_options.append(f'--config-file={config_file_path}')
|
||||||
|
return mypy_cmd_options
|
||||||
|
|
||||||
|
def repr_failure(self, excinfo: ExceptionInfo) -> str:
|
||||||
|
if excinfo.errisinstance(SystemExit):
|
||||||
|
# We assume that before doing exit() (which raises SystemExit) we've printed
|
||||||
|
# enough context about what happened so that a stack trace is not useful.
|
||||||
|
# In particular, uncaught exceptions during semantic analysis or type checking
|
||||||
|
# call exit() and they already print out a stack trace.
|
||||||
|
return excinfo.exconly(tryshort=True)
|
||||||
|
elif excinfo.errisinstance(TypecheckAssertionError):
|
||||||
|
# with traceback removed
|
||||||
|
exception_repr = excinfo.getrepr(style='short')
|
||||||
|
exception_repr.reprcrash.message = ''
|
||||||
|
repr_file_location = ReprFileLocation(path=inspect.getfile(self.klass),
|
||||||
|
lineno=self.first_lineno + excinfo.value.lineno,
|
||||||
|
message='')
|
||||||
|
repr_tb_entry = TraceLastReprEntry(filelocrepr=repr_file_location,
|
||||||
|
lines=exception_repr.reprtraceback.reprentries[-1].lines[1:],
|
||||||
|
style='short',
|
||||||
|
reprlocals=None,
|
||||||
|
reprfuncargs=None)
|
||||||
|
exception_repr.reprtraceback.reprentries = [repr_tb_entry]
|
||||||
|
return exception_repr
|
||||||
|
else:
|
||||||
|
return super().repr_failure(excinfo, style='short')
|
||||||
|
|
||||||
|
def reportinfo(self):
|
||||||
|
return self.fspath, None, get_class_qualname(self.klass) + '::' + self.name
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_qualname(klass: type) -> str:
|
||||||
|
return klass.__module__ + '.' + klass.__name__
|
||||||
|
|
||||||
|
|
||||||
|
def extract_test_output(attr: Callable[..., None]) -> List[str]:
|
||||||
|
out_data: str = getattr(attr, 'out', None)
|
||||||
|
out_lines = []
|
||||||
|
if out_data:
|
||||||
|
for line in out_data.split('\n'):
|
||||||
|
line = line.strip()
|
||||||
|
out_lines.append(line)
|
||||||
|
return out_lines
|
||||||
|
|
||||||
|
|
||||||
|
class MypyTestsCollector(pytest.Class):
|
||||||
|
def get_ini_file_contents(self, contents: str) -> str:
|
||||||
|
return contents.strip() + '\n'
|
||||||
|
|
||||||
|
def collect(self) -> Iterator[pytest.Item]:
|
||||||
|
current_testcase = cast(MypyTypecheckTestCase, self.obj())
|
||||||
|
ini_file_contents = self.get_ini_file_contents(current_testcase.ini_file())
|
||||||
|
for attr_name in dir(current_testcase):
|
||||||
|
if attr_name.startswith('_test_'):
|
||||||
|
attr = getattr(self.obj, attr_name)
|
||||||
|
if inspect.isfunction(attr):
|
||||||
|
first_line_lnum, source_lines = get_func_first_lnum(attr)
|
||||||
|
func_first_line_in_file = inspect.getsourcelines(attr)[1] + first_line_lnum
|
||||||
|
|
||||||
|
output_from_decorator = extract_test_output(attr)
|
||||||
|
output_from_comments = expand_errors(source_lines, 'main')
|
||||||
|
custom_env = getattr(attr, 'env', None)
|
||||||
|
main_source_code = textwrap.dedent(''.join(source_lines))
|
||||||
|
yield MypyTypecheckItem(name=attr_name,
|
||||||
|
parent=self,
|
||||||
|
klass=current_testcase.__class__,
|
||||||
|
source_code=main_source_code,
|
||||||
|
first_lineno=func_first_line_in_file,
|
||||||
|
ini_file_contents=ini_file_contents,
|
||||||
|
expected_output_lines=output_from_comments
|
||||||
|
+ output_from_decorator,
|
||||||
|
files=current_testcase.__class__.files,
|
||||||
|
custom_environment=custom_env)
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_pycollect_makeitem(collector: Any, name: str, obj: Any) -> Optional[MypyTestsCollector]:
|
||||||
|
# Only classes derived from DataSuite contain test cases, not the DataSuite class itself
|
||||||
|
if (isinstance(obj, type)
|
||||||
|
and issubclass(obj, MypyTypecheckTestCase)
|
||||||
|
and obj is not MypyTypecheckTestCase):
|
||||||
|
# Non-None result means this obj is a test case.
|
||||||
|
# The collect method of the returned DataSuiteCollector instance will be called later,
|
||||||
|
# with self.obj being obj.
|
||||||
|
return MypyTestsCollector(name, parent=collector)
|
||||||
0
test/pytest_tests/__init__.py
Normal file
0
test/pytest_tests/__init__.py
Normal file
9
test/pytest_tests/base.py
Normal file
9
test/pytest_tests/base.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from test.pytest_plugin import MypyTypecheckTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDjangoPluginTestCase(MypyTypecheckTestCase):
|
||||||
|
def ini_file(self):
|
||||||
|
return """
|
||||||
|
[mypy]
|
||||||
|
plugins = mypy_django_plugin.main
|
||||||
|
"""
|
||||||
21
test/pytest_tests/test_model_fields.py
Normal file
21
test/pytest_tests/test_model_fields.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from test.pytest_plugin import reveal_type
|
||||||
|
from test.pytest_tests.base import BaseDjangoPluginTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestBasicModelFields(BaseDjangoPluginTestCase):
|
||||||
|
def test_model_field_classes_present_as_primitives(self):
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class User(models.Model):
|
||||||
|
id = models.AutoField(primary_key=True)
|
||||||
|
small_int = models.SmallIntegerField()
|
||||||
|
name = models.CharField(max_length=255)
|
||||||
|
slug = models.SlugField(max_length=255)
|
||||||
|
text = models.TextField()
|
||||||
|
|
||||||
|
user = User()
|
||||||
|
reveal_type(user.id) # E: Revealed type is 'builtins.int'
|
||||||
|
reveal_type(user.small_int) # E: Revealed type is 'builtins.int'
|
||||||
|
reveal_type(user.name) # E: Revealed type is 'builtins.str'
|
||||||
|
reveal_type(user.slug) # E: Revealed type is 'builtins.str'
|
||||||
|
reveal_type(user.text) # E: Revealed type is 'builtins.str'
|
||||||
@@ -1,16 +1,9 @@
|
|||||||
from test.pytest_plugin import MypyTypecheckTestCase, reveal_type
|
from test.pytest_plugin import reveal_type
|
||||||
|
from test.pytest_tests.base import BaseDjangoPluginTestCase
|
||||||
|
|
||||||
|
|
||||||
class BaseDjangoPluginTestCase(MypyTypecheckTestCase):
|
class TestForeignKey(BaseDjangoPluginTestCase):
|
||||||
def ini_file(self):
|
def test_foreign_key_field(self):
|
||||||
return """
|
|
||||||
[mypy]
|
|
||||||
plugins = mypy_django_plugin.main
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class MyTestCase(BaseDjangoPluginTestCase):
|
|
||||||
def check_foreign_key_field(self):
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
class Publisher(models.Model):
|
class Publisher(models.Model):
|
||||||
@@ -26,7 +19,7 @@ class MyTestCase(BaseDjangoPluginTestCase):
|
|||||||
publisher = Publisher()
|
publisher = Publisher()
|
||||||
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
|
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
|
||||||
|
|
||||||
def check_every_foreign_key_creates_field_name_with_appended_id(self):
|
def test_every_foreign_key_creates_field_name_with_appended_id(self):
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
class Publisher(models.Model):
|
class Publisher(models.Model):
|
||||||
@@ -39,7 +32,7 @@ class MyTestCase(BaseDjangoPluginTestCase):
|
|||||||
book = Book()
|
book = Book()
|
||||||
reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
|
reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
|
||||||
|
|
||||||
def check_foreign_key_different_order_of_params(self):
|
def test_foreign_key_different_order_of_params(self):
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
class Publisher(models.Model):
|
class Publisher(models.Model):
|
||||||
@@ -54,3 +47,43 @@ class MyTestCase(BaseDjangoPluginTestCase):
|
|||||||
|
|
||||||
publisher = Publisher()
|
publisher = Publisher()
|
||||||
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
|
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
|
||||||
|
|
||||||
|
|
||||||
|
class TestOneToOneField(BaseDjangoPluginTestCase):
|
||||||
|
def test_onetoone_field(self):
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class User(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Profile(models.Model):
|
||||||
|
user = models.OneToOneField(to=User, on_delete=models.CASCADE, related_name='profile')
|
||||||
|
|
||||||
|
profile = Profile()
|
||||||
|
reveal_type(profile.user) # E: Revealed type is 'main.User*'
|
||||||
|
|
||||||
|
user = User()
|
||||||
|
reveal_type(user.profile) # E: Revealed type is 'main.Profile'
|
||||||
|
|
||||||
|
def test_onetoone_field_with_underscore_id(self):
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class User(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Profile(models.Model):
|
||||||
|
user = models.OneToOneField(to=User, on_delete=models.CASCADE, related_name='profile')
|
||||||
|
|
||||||
|
profile = Profile()
|
||||||
|
reveal_type(profile.user_id) # E: Revealed type is 'builtins.int'
|
||||||
|
|
||||||
|
def test_parameter_to_keyword_may_be_absent(self):
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class User(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Profile(models.Model):
|
||||||
|
user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile')
|
||||||
|
|
||||||
|
reveal_type(User().profile) # E: Revealed type is 'main.Profile'
|
||||||
|
|||||||
28
test/pytest_tests/test_objects_queryset.py
Normal file
28
test/pytest_tests/test_objects_queryset.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from test.pytest_plugin import reveal_type, output
|
||||||
|
from test.pytest_tests.base import BaseDjangoPluginTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestObjectsQueryset(BaseDjangoPluginTestCase):
|
||||||
|
def test_every_model_has_objects_queryset_available(self):
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class User(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
reveal_type(User.objects) # E: Revealed type is 'django.db.models.query.QuerySet[main.User]'
|
||||||
|
|
||||||
|
@output("""
|
||||||
|
main:10: error: Revealed type is 'Any'
|
||||||
|
main:10: error: "Type[ModelMixin]" has no attribute "objects"
|
||||||
|
""")
|
||||||
|
def test_objects_get_returns_model_instance(self):
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class ModelMixin(models.Model):
|
||||||
|
class Meta:
|
||||||
|
abstract = True
|
||||||
|
|
||||||
|
class User(ModelMixin):
|
||||||
|
pass
|
||||||
|
|
||||||
|
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'
|
||||||
37
test/pytest_tests/test_parse_settings.py
Normal file
37
test/pytest_tests/test_parse_settings.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from test.pytest_plugin import reveal_type, file, env
|
||||||
|
from test.pytest_tests.base import BaseDjangoPluginTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseSettingsFromFile(BaseDjangoPluginTestCase):
|
||||||
|
@env(DJANGO_SETTINGS_MODULE='mysettings')
|
||||||
|
def test_case(self):
|
||||||
|
from django.conf import settings
|
||||||
|
|
||||||
|
reveal_type(settings.ROOT_DIR) # E: Revealed type is 'builtins.str'
|
||||||
|
reveal_type(settings.OBJ) # E: Revealed type is 'django.utils.functional.LazyObject'
|
||||||
|
reveal_type(settings.NUMBERS) # E: Revealed type is 'builtins.list[Any]'
|
||||||
|
reveal_type(settings.DICT) # E: Revealed type is 'builtins.dict[Any, Any]'
|
||||||
|
|
||||||
|
@file('mysettings.py')
|
||||||
|
def mysettings_py_file(self):
|
||||||
|
SECRET_KEY = 112233
|
||||||
|
ROOT_DIR = '/etc'
|
||||||
|
NUMBERS = ['one', 'two']
|
||||||
|
DICT = {} # type: ignore
|
||||||
|
|
||||||
|
from django.utils.functional import LazyObject
|
||||||
|
|
||||||
|
OBJ = LazyObject()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSettingInitializableToNone(BaseDjangoPluginTestCase):
|
||||||
|
@env(DJANGO_SETTINGS_MODULE='mysettings')
|
||||||
|
def test_case(self):
|
||||||
|
from django.conf import settings
|
||||||
|
|
||||||
|
reveal_type(settings.NONE_SETTING) # E: Revealed type is 'builtins.object'
|
||||||
|
|
||||||
|
@file('mysettings.py')
|
||||||
|
def mysettings_py_file(self):
|
||||||
|
SECRET_KEY = 112233
|
||||||
|
NONE_SETTING = None
|
||||||
27
test/pytest_tests/test_postgres_fields.py
Normal file
27
test/pytest_tests/test_postgres_fields.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from test.pytest_plugin import reveal_type
|
||||||
|
from test.pytest_tests.base import BaseDjangoPluginTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestArrayField(BaseDjangoPluginTestCase):
|
||||||
|
def test_descriptor_access(self):
|
||||||
|
from django.db import models
|
||||||
|
from django.contrib.postgres.fields import ArrayField
|
||||||
|
|
||||||
|
class User(models.Model):
|
||||||
|
array = ArrayField(base_field=models.Field())
|
||||||
|
|
||||||
|
user = User()
|
||||||
|
reveal_type(user.array) # E: Revealed type is 'builtins.list[Any]'
|
||||||
|
|
||||||
|
def test_base_field_parsed_into_generic_attribute(self):
|
||||||
|
from django.db import models
|
||||||
|
from django.contrib.postgres.fields import ArrayField
|
||||||
|
|
||||||
|
class User(models.Model):
|
||||||
|
members = ArrayField(base_field=models.IntegerField())
|
||||||
|
members_as_text = ArrayField(base_field=models.CharField(max_length=255))
|
||||||
|
|
||||||
|
user = User()
|
||||||
|
reveal_type(user.members) # E: Revealed type is 'builtins.list[builtins.int*]'
|
||||||
|
reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list[builtins.str*]'
|
||||||
|
|
||||||
74
test/pytest_tests/test_to_attr_as_string.py
Normal file
74
test/pytest_tests/test_to_attr_as_string.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from test.pytest_plugin import file, reveal_type, env
|
||||||
|
from test.pytest_tests.base import BaseDjangoPluginTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestForeignKey(BaseDjangoPluginTestCase):
|
||||||
|
@env(DJANGO_SETTINGS_MODULE='mysettings')
|
||||||
|
def _test_to_parameter_could_be_specified_as_string(self):
|
||||||
|
from apps.myapp.models import Publisher
|
||||||
|
|
||||||
|
publisher = Publisher()
|
||||||
|
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[apps.myapp2.models.Book]'
|
||||||
|
|
||||||
|
# @env(DJANGO_SETTINGS_MODULE='mysettings')
|
||||||
|
# def _test_creates_underscore_id_attr(self):
|
||||||
|
# from apps.myapp2.models import Book
|
||||||
|
#
|
||||||
|
# book = Book()
|
||||||
|
# reveal_type(book.publisher) # E: Revealed type is 'apps.myapp.models.Publisher'
|
||||||
|
# reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
|
||||||
|
|
||||||
|
@file('mysettings.py')
|
||||||
|
def mysettings(self):
|
||||||
|
SECRET_KEY = '112233'
|
||||||
|
ROOT_DIR = '<TMP>'
|
||||||
|
APPS_DIR = '<TMP>/apps'
|
||||||
|
|
||||||
|
INSTALLED_APPS = ('apps.myapp', 'apps.myapp2')
|
||||||
|
|
||||||
|
@file('apps/myapp/models.py', make_parent_packages=True)
|
||||||
|
def apps_myapp_models(self):
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class Publisher(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@file('apps/myapp2/models.py', make_parent_packages=True)
|
||||||
|
def apps_myapp2_models(self):
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class Book(models.Model):
|
||||||
|
publisher = models.ForeignKey(to='myapp.Publisher', on_delete=models.CASCADE,
|
||||||
|
related_name='books')
|
||||||
|
|
||||||
|
|
||||||
|
class TestOneToOneField(BaseDjangoPluginTestCase):
|
||||||
|
@env(DJANGO_SETTINGS_MODULE='mysettings')
|
||||||
|
def test_to_parameter_could_be_specified_as_string(self):
|
||||||
|
from apps.myapp.models import User
|
||||||
|
|
||||||
|
user = User()
|
||||||
|
reveal_type(user.profile) # E: Revealed type is 'apps.myapp2.models.Profile'
|
||||||
|
|
||||||
|
@file('mysettings.py')
|
||||||
|
def mysettings(self):
|
||||||
|
SECRET_KEY = '112233'
|
||||||
|
ROOT_DIR = '<TMP>'
|
||||||
|
APPS_DIR = '<TMP>/apps'
|
||||||
|
|
||||||
|
INSTALLED_APPS = ('apps.myapp', 'apps.myapp2')
|
||||||
|
|
||||||
|
@file('apps/myapp/models.py', make_parent_packages=True)
|
||||||
|
def apps_myapp_models(self):
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class User(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@file('apps/myapp2/models.py', make_parent_packages=True)
|
||||||
|
def apps_myapp2_models(self):
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class Profile(models.Model):
|
||||||
|
user = models.OneToOneField(to='myapp.User', on_delete=models.CASCADE,
|
||||||
|
related_name='profile')
|
||||||
@@ -14,11 +14,14 @@ MYPY_INI_PATH = ROOT_DIR / 'test' / 'plugins.ini'
|
|||||||
|
|
||||||
class DjangoTestSuite(DataSuite):
|
class DjangoTestSuite(DataSuite):
|
||||||
files = [
|
files = [
|
||||||
'check-objects-queryset.test',
|
# 'check-objects-queryset.test',
|
||||||
'check-model-fields.test',
|
# 'check-model-fields.test',
|
||||||
'check-postgres-fields.test',
|
# 'check-postgres-fields.test',
|
||||||
'check-model-relations.test',
|
# 'check-model-relations.test',
|
||||||
'check-parse-settings.test'
|
# 'check-parse-settings.test',
|
||||||
|
# 'check-to-attr-as-string-one-to-one-field.test',
|
||||||
|
'check-to-attr-as-string-foreign-key.test',
|
||||||
|
# 'check-foreign-key-as-string-creates-underscore-id-attr.test'
|
||||||
]
|
]
|
||||||
data_prefix = str(TEST_DATA_DIR)
|
data_prefix = str(TEST_DATA_DIR)
|
||||||
|
|
||||||
|
|||||||
43
test/vistir.py
Normal file
43
test/vistir.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# Borrowed from Pew.
|
||||||
|
# See https://github.com/berdario/pew/blob/master/pew/_utils.py#L82
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from decorator import contextmanager
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def temp_environ():
|
||||||
|
"""Allow the ability to set os.environ temporarily"""
|
||||||
|
environ = dict(os.environ)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
os.environ.clear()
|
||||||
|
os.environ.update(environ)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def temp_path():
|
||||||
|
"""A context manager which allows the ability to set sys.path temporarily"""
|
||||||
|
path = [p for p in sys.path]
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
sys.path = [p for p in path]
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def cd(path):
|
||||||
|
"""Context manager to temporarily change working directories"""
|
||||||
|
if not path:
|
||||||
|
return
|
||||||
|
prev_cwd = Path.cwd().as_posix()
|
||||||
|
if isinstance(path, Path):
|
||||||
|
path = path.as_posix()
|
||||||
|
os.chdir(str(path))
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
os.chdir(prev_cwd)
|
||||||
Reference in New Issue
Block a user