latest changes

This commit is contained in:
Maxim Kurnikov
2018-11-26 23:58:34 +03:00
parent 348efcd371
commit f59cfe6371
34 changed files with 1558 additions and 132 deletions

1
.gitignore vendored
View File

@@ -4,3 +4,4 @@ out/
/test_sqlite.py
/django
.idea/
.mypy_cache/

View File

@@ -1,3 +1,3 @@
pytest_plugins = [
'test.data'
'test.pytest_plugin'
]

View File

@@ -0,0 +1,7 @@
from .config import (
AppConfig as AppConfig
)
from .registry import (
apps as apps
)

View File

@@ -0,0 +1,26 @@
from typing import Any, Iterator, Type
from django.db.models.base import Model
MODELS_MODULE_NAME: str
class AppConfig:
name: str = ...
module: Any = ...
apps: None = ...
label: str = ...
verbose_name: str = ...
path: str = ...
models_module: None = ...
models: None = ...
def __init__(self, app_name: str, app_module: None) -> None: ...
@classmethod
def create(cls, entry: str) -> AppConfig: ...
def get_model(
self, model_name: str, require_ready: bool = ...
) -> Type[Model]: ...
def get_models(
self, include_auto_created: bool = ..., include_swapped: bool = ...
) -> Iterator[Type[Model]]: ...
def import_models(self) -> None: ...
def ready(self) -> None: ...

View File

@@ -0,0 +1,62 @@
import collections
from typing import Any, Callable, List, Optional, Tuple, Type, Union, Iterable
from django.apps.config import AppConfig
from django.db.migrations.state import AppConfigStub
from django.db.models.base import Model
from .config import AppConfig
class Apps:
all_models: collections.defaultdict = ...
app_configs: collections.OrderedDict = ...
stored_app_configs: List[Any] = ...
apps_ready: bool = ...
loading: bool = ...
def __init__(
self,
installed_apps: Optional[
Union[List[AppConfigStub], List[str], Tuple]
] = ...,
) -> None: ...
models_ready: bool = ...
ready: bool = ...
def populate(
self, installed_apps: Union[List[AppConfigStub], List[str], Tuple] = ...
) -> None: ...
def check_apps_ready(self) -> None: ...
def check_models_ready(self) -> None: ...
def get_app_configs(self) -> Iterable[AppConfig]: ...
def get_app_config(self, app_label: str) -> AppConfig: ...
def get_models(
self, include_auto_created: bool = ..., include_swapped: bool = ...
) -> List[Type[Model]]: ...
def get_model(
self,
app_label: str,
model_name: Optional[str] = ...,
require_ready: bool = ...,
) -> Type[Model]: ...
def register_model(self, app_label: str, model: Type[Model]) -> None: ...
def is_installed(self, app_name: str) -> bool: ...
def get_containing_app_config(
self, object_name: str
) -> Optional[AppConfig]: ...
def get_registered_model(
self, app_label: str, model_name: str
) -> Type[Model]: ...
def get_swappable_settings_name(self, to_string: str) -> Optional[str]: ...
def set_available_apps(self, available: List[str]) -> None: ...
def unset_available_apps(self) -> None: ...
def set_installed_apps(
self, installed: Union[List[str], Tuple[str]]
) -> None: ...
def unset_installed_apps(self) -> None: ...
def clear_cache(self) -> None: ...
def lazy_model_operation(
self, function: Callable, *model_keys: Any
) -> None: ...
def do_pending_operations(self, model: Type[Model]) -> None: ...
apps: Apps

View File

@@ -1,2 +1,18 @@
from typing import Any
from .utils import (ProgrammingError as ProgrammingError,
IntegrityError as IntegrityError)
IntegrityError as IntegrityError,
OperationalError as OperationalError,
DatabaseError as DatabaseError,
DataError as DataError,
NotSupportedError as NotSupportedError)
connections: Any
router: Any
class DefaultConnectionProxy:
def __getattr__(self, item: str) -> Any: ...
def __setattr__(self, name: str, value: Any) -> None: ...
def __delattr__(self, name: str) -> None: ...
connection: Any

View File

@@ -3,14 +3,21 @@ from .base import Model as Model
from .fields import (AutoField as AutoField,
IntegerField as IntegerField,
SmallIntegerField as SmallIntegerField,
BigIntegerField as BigIntegerField,
CharField as CharField,
Field as Field,
SlugField as SlugField,
TextField as TextField,
BooleanField as BooleanField)
BooleanField as BooleanField,
FileField as FileField,
DateField as DateField,
DateTimeField as DateTimeField,
IPAddressField as IPAddressField,
GenericIPAddressField as GenericIPAddressField)
from .fields.related import (ForeignKey as ForeignKey,
OneToOneField as OneToOneField)
OneToOneField as OneToOneField,
ManyToManyField as ManyToManyField)
from .deletion import (CASCADE as CASCADE,
SET_DEFAULT as SET_DEFAULT,
@@ -24,4 +31,11 @@ from .query_utils import Q as Q
from .lookups import Lookup as Lookup
from .expressions import F as F
from .expressions import (F as F,
Subquery as Subquery,
Exists as Exists,
OrderBy as OrderBy,
OuterRef as OuterRef)
from .manager import (BaseManager as BaseManager,
Manager as Manager)

View File

@@ -3,6 +3,7 @@ from datetime import datetime, timedelta
from typing import (Any, Callable, Dict, Iterator, List, Optional, Set, Tuple,
Type, Union)
from django.db.models import QuerySet
from django.db.models.fields import Field
from django.db.models.lookups import Lookup
from django.db.models.sql.compiler import SQLCompiler
@@ -180,4 +181,50 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
class F(Combinable):
name: str
def __init__(self, name: str): ...
def resolve_expression(
self,
query: Any = ...,
allow_joins: bool = ...,
reuse: Optional[Set[str]] = ...,
summarize: bool = ...,
for_save: bool = ...,
) -> Expression: ...
class OuterRef(F): ...
class Subquery(Expression):
template: str = ...
queryset: QuerySet = ...
extra: Dict[Any, Any] = ...
def __init__(
self,
queryset: QuerySet,
output_field: Optional[Field] = ...,
**extra: Any
) -> None: ...
class Exists(Subquery):
extra: Dict[Any, Any]
template: str = ...
negated: bool = ...
def __init__(
self, *args: Any, negated: bool = ..., **kwargs: Any
) -> None: ...
def __invert__(self) -> Exists: ...
class OrderBy(BaseExpression):
template: str = ...
nulls_first: bool = ...
nulls_last: bool = ...
descending: bool = ...
expression: Expression = ...
def __init__(
self,
expression: Combinable,
descending: bool = ...,
nulls_first: bool = ...,
nulls_last: bool = ...,
) -> None: ...

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional
from django.db.models.query_utils import RegisterLookupMixin
@@ -7,6 +7,7 @@ class Field(RegisterLookupMixin):
def __init__(self,
primary_key: bool = False,
**kwargs): ...
def __get__(self, instance, owner) -> Any: ...
@@ -14,8 +15,10 @@ class IntegerField(Field):
def __get__(self, instance, owner) -> int: ...
class SmallIntegerField(IntegerField):
pass
class SmallIntegerField(IntegerField): ...
class BigIntegerField(IntegerField): ...
class AutoField(Field):
@@ -26,11 +29,11 @@ class CharField(Field):
def __init__(self,
max_length: int,
**kwargs): ...
def __get__(self, instance, owner) -> str: ...
class SlugField(CharField):
pass
class SlugField(CharField): ...
class TextField(Field):
@@ -39,3 +42,29 @@ class TextField(Field):
class BooleanField(Field):
def __get__(self, instance, owner) -> bool: ...
class FileField(Field): ...
class IPAddressField(Field): ...
class GenericIPAddressField(Field):
default_error_messages: Any = ...
unpack_ipv4: Any = ...
protocol: Any = ...
def __init__(
self,
verbose_name: Optional[Any] = ...,
name: Optional[Any] = ...,
protocol: str = ...,
unpack_ipv4: bool = ...,
*args: Any,
**kwargs: Any
) -> None: ...
class DateField(Field): ...
class DateTimeField(DateField): ...

View File

@@ -22,3 +22,12 @@ class OneToOneField(Field, Generic[_T]):
related_name: str = ...,
**kwargs): ...
def __get__(self, instance, owner) -> _T: ...
class ManyToManyField(Field, Generic[_T]):
def __init__(self,
to: Union[Type[_T], str],
on_delete: Any,
related_name: str = ...,
**kwargs): ...
def __get__(self, instance, owner) -> _T: ...

View File

@@ -0,0 +1,140 @@
from collections import OrderedDict
from datetime import datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple, Type, Union, TypeVar, Set, Generic, Iterator
from unittest.mock import MagicMock
from django.db.models import Q
from django.db.models.base import Model
from django.db.models.query import QuerySet, RawQuerySet
_T = TypeVar('_T', bound=Model)
class BaseManager:
creation_counter: int = ...
auto_created: bool = ...
use_in_migrations: bool = ...
def __new__(cls: Type[BaseManager], *args: Any, **kwargs: Any) -> BaseManager: ...
model: Any = ...
name: Any = ...
def __init__(self) -> None: ...
def deconstruct(self) -> Tuple[bool, str, None, Tuple, Dict[str, int]]: ...
def check(self, **kwargs: Any) -> List[Any]: ...
@classmethod
def from_queryset(
cls, queryset_class: Any, class_name: Optional[Any] = ...
): ...
def contribute_to_class(self, model: Type[Model], name: str) -> None: ...
def db_manager(
self,
using: Optional[str] = ...,
hints: Optional[Dict[str, Model]] = ...,
) -> Manager: ...
@property
def db(self) -> str: ...
def get_queryset(self) -> QuerySet: ...
def all(self) -> QuerySet: ...
def __eq__(self, other: Optional[Any]) -> bool: ...
def __hash__(self): ...
class Manager(Generic[_T]):
def exists(self) -> bool: ...
def explain(
self, *, format: Optional[Any] = ..., **options: Any
) -> str: ...
def raw(
self,
raw_query: str,
params: Optional[
Union[
Dict[str, str],
List[datetime],
List[Decimal],
List[str],
Set[str],
Tuple[int],
]
] = ...,
translations: Optional[Dict[str, str]] = ...,
using: None = ...,
) -> RawQuerySet: ...
def values(self, *fields: Any, **expressions: Any) -> QuerySet: ...
def values_list(
self, *fields: Any, flat: bool = ..., named: bool = ...
) -> QuerySet: ...
def dates(
self, field_name: str, kind: str, order: str = ...
) -> QuerySet: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...
) -> QuerySet: ...
def none(self) -> QuerySet[_T]: ...
def all(self) -> QuerySet[_T]: ...
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def complex_filter(
self,
filter_obj: Union[
Dict[str, datetime], Dict[str, QuerySet], Q, MagicMock
],
) -> QuerySet[_T]: ...
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
def select_for_update(
self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...
) -> QuerySet: ...
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
def extra(
self,
select: Optional[
Union[Dict[str, int], Dict[str, str], OrderedDict]
] = ...,
where: Optional[List[str]] = ...,
params: Optional[Union[List[int], List[str]]] = ...,
tables: Optional[List[str]] = ...,
order_by: Optional[Union[List[str], Tuple[str]]] = ...,
select_params: Optional[Union[List[int], List[str], Tuple[int]]] = ...,
) -> QuerySet[_T]: ...
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
def aggregate(
self, *args: Any, **kwargs: Any
) -> Dict[str, Optional[Union[datetime, float]]]: ...
def count(self) -> int: ...
def get(
self, *args: Any, **kwargs: Any
) -> _T: ...
def create(self, **kwargs: Any) -> _T: ...
class ManagerDescriptor:
manager: Manager = ...
def __init__(self, manager: Manager) -> None: ...
def __get__(
self, instance: Optional[Model], cls: Type[Model] = ...
) -> Manager: ...
class EmptyManager(Manager):
creation_counter: int
name: None
model: Optional[Type[Model]] = ...
def __init__(self, model: Type[Model]) -> None: ...
def get_queryset(self) -> QuerySet: ...

View File

@@ -7,3 +7,8 @@ from .conf import (include as include,
from .resolvers import (ResolverMatch as ResolverMatch,
get_ns_resolver as get_ns_resolver,
get_resolver as get_resolver)
# noinspection PyUnresolvedReferences
from .converters import (
register_converter as register_converter
)

View File

@@ -0,0 +1,30 @@
from typing import Any, Tuple, Union
BASE2_ALPHABET: str
BASE16_ALPHABET: str
BASE56_ALPHABET: str
BASE36_ALPHABET: str
BASE62_ALPHABET: str
BASE64_ALPHABET: Any
class BaseConverter:
decimal_digits: str = ...
sign: str = ...
digits: str = ...
def __init__(self, digits: str, sign: str = ...) -> None: ...
def encode(self, i: int) -> str: ...
def decode(self, s: str) -> int: ...
def convert(
self,
number: Union[int, str],
from_digits: str,
to_digits: str,
sign: str,
) -> Tuple[int, str]: ...
base2: Any
base16: Any
base36: Any
base56: Any
base62: Any
base64: Any

View File

@@ -1,15 +1,15 @@
from typing import Dict, Optional, NamedTuple
import typing
from typing import Dict, Optional, NamedTuple, Any
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Type
from mypy.nodes import SymbolTableNode, Var, Expression
from mypy.plugin import FunctionContext
from mypy.types import Instance, UnionType, NoneTyp
from mypy.types import Type, Instance, UnionType, NoneTyp
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
def create_new_symtable_node(name: str, kind: int, instance: Instance) -> SymbolTableNode:
@@ -26,12 +26,10 @@ Argument = NamedTuple('Argument', fields=[
def get_call_signature_or_none(ctx: FunctionContext) -> Optional[Dict[str, Argument]]:
arg_names = ctx.context.arg_names
result: Dict[str, Argument] = {}
positional_args_only = []
positional_arg_types_only = []
for arg, arg_name, arg_type in zip(ctx.args, arg_names, ctx.arg_types):
for arg, arg_name, arg_type in zip(ctx.args, ctx.arg_names, ctx.arg_types):
if arg_name is None:
positional_args_only.append(arg)
positional_arg_types_only.append(arg_type)
@@ -65,3 +63,7 @@ def make_required(typ: Type) -> Type:
return typ
items = [item for item in typ.items if not isinstance(item, NoneTyp)]
return UnionType.make_union(items)
def get_obj_type_name(typ: typing.Type) -> str:
return typ.__module__ + '.' + typ.__qualname__

View File

@@ -1,46 +1,79 @@
import os
from typing import Callable, Optional
from typing import Callable, Optional, List
from django.apps.registry import Apps
from django.conf import Settings
from mypy import build
from mypy.build import BuildManager
from mypy.options import Options
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
from mypy.types import Type
from mypy_django_plugin import helpers
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.plugins.objects_queryset import set_objects_queryset_to_model_class
from mypy_django_plugin.plugins.postgres_fields import determine_type_of_array_field
from mypy_django_plugin.plugins.related_fields import set_related_name_instance_for_onetoonefield, \
set_related_name_manager_for_foreign_key, set_fieldname_attrs_for_related_fields
from mypy_django_plugin.plugins.related_fields import OneToOneFieldHook, \
ForeignKeyHook, set_fieldname_attrs_for_related_fields
from mypy_django_plugin.plugins.setup_settings import DjangoConfSettingsInitializerHook
base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
def transform_model_class(ctx: ClassDefContext) -> None:
class TransformModelClassHook(object):
def __init__(self, settings: Settings, apps: Apps):
self.settings = settings
self.apps = apps
def __call__(self, ctx: ClassDefContext) -> None:
base_model_classes.add(ctx.cls.fullname)
set_fieldname_attrs_for_related_fields(ctx)
set_objects_queryset_to_model_class(ctx)
def always_return_none(manager: BuildManager):
return None
build.read_plugins_snapshot = always_return_none
class DjangoPlugin(Plugin):
def __init__(self,
options: Options) -> None:
super().__init__(options)
self.django_settings = None
self.apps = None
monkeypatch.replace_apply_function_plugin_method()
django_settings_module = os.environ.get('DJANGO_SETTINGS_MODULE')
if django_settings_module:
self.django_settings = Settings(django_settings_module)
# import django
# django.setup()
#
# from django.apps import apps
# self.apps = apps
#
# models_modules = []
# for app_config in self.apps.app_configs.values():
# models_modules.append(app_config.module.__name__ + '.' + 'models')
#
# monkeypatch.state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(django_settings_module,
# models_modules)
monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(django_settings_module)
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == helpers.FOREIGN_KEY_FULLNAME:
return set_related_name_manager_for_foreign_key
return ForeignKeyHook(settings=self.django_settings,
apps=self.apps)
if fullname == helpers.ONETOONE_FIELD_FULLNAME:
return set_related_name_instance_for_onetoonefield
return OneToOneFieldHook(settings=self.django_settings,
apps=self.apps)
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field
@@ -49,9 +82,11 @@ class DjangoPlugin(Plugin):
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in base_model_classes:
return transform_model_class
if fullname == 'django.conf._DjangoConfLazyObject':
return TransformModelClassHook(self.django_settings, self.apps)
if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS:
return DjangoConfSettingsInitializerHook(settings=self.django_settings)
return None

View File

@@ -0,0 +1,112 @@
from typing import Optional, List, Sequence
from mypy.build import BuildManager, Graph, State
from mypy.modulefinder import BuildSource
from mypy.nodes import Expression, Context
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import Type, CallableType, Instance
def state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(settings_module: str,
models_py_modules: List[str]):
from mypy.build import State
old_compute_dependencies = State.compute_dependencies
def patched_compute_dependencies(self: State):
old_compute_dependencies(self)
if self.id == settings_module:
self.dependencies.extend(models_py_modules)
State.compute_dependencies = patched_compute_dependencies
def load_graph_to_add_settings_file_as_a_source_seed(settings_module: str):
from mypy import build
old_load_graph = build.load_graph
def patched_load_graph(sources: List[BuildSource], manager: BuildManager,
old_graph: Optional[Graph] = None,
new_modules: Optional[List[State]] = None):
if all([source.module != settings_module for source in sources]):
sources.append(BuildSource(None, settings_module, None))
return old_load_graph(sources=sources, manager=manager,
old_graph=old_graph,
new_modules=new_modules)
build.load_graph = patched_load_graph
def replace_apply_function_plugin_method():
def apply_function_plugin(self,
arg_types: List[Type],
inferred_ret_type: Type,
arg_names: Optional[Sequence[Optional[str]]],
formal_to_actual: List[List[int]],
args: List[Expression],
num_formals: int,
fullname: str,
object_type: Optional[Type],
context: Context) -> Type:
"""Use special case logic to infer the return type of a specific named function/method.
Caller must ensure that a plugin hook exists. There are two different cases:
- If object_type is None, the caller must ensure that a function hook exists
for fullname.
- If object_type is not None, the caller must ensure that a method hook exists
for fullname.
Return the inferred return type.
"""
formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]]
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
formal_arg_names = [None for _ in range(num_formals)] # type: List[Optional[str]]
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_types[formal].append(arg_types[actual])
formal_arg_exprs[formal].append(args[actual])
if arg_names:
formal_arg_names[formal] = arg_names[actual]
num_passed_positionals = sum([1 if name is None else 0
for name in formal_arg_names])
if arg_names and num_passed_positionals > 0:
object_type_info = None
if object_type is not None:
if isinstance(object_type, CallableType):
# class object, convert to corresponding Instance
object_type = object_type.ret_type
if isinstance(object_type, Instance):
# skip TypedDictType and others
object_type_info = object_type.type
defn_arg_names = self._get_defn_arg_names(fullname, object_type=object_type_info)
if defn_arg_names:
if num_formals < len(defn_arg_names):
# self/cls argument has been passed implicitly
defn_arg_names = defn_arg_names[1:]
formal_arg_names[:num_passed_positionals] = defn_arg_names[:num_passed_positionals]
if object_type is None:
# Apply function plugin
callback = self.plugin.get_function_hook(fullname)
assert callback is not None # Assume that caller ensures this
return callback(
FunctionContext(formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
else:
# Apply method plugin
method_callback = self.plugin.get_method_hook(fullname)
assert method_callback is not None # Assume that caller ensures this
return method_callback(
MethodContext(object_type, formal_arg_names, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))
from mypy.checkexpr import ExpressionChecker
ExpressionChecker.apply_function_plugin = apply_function_plugin

View File

@@ -9,19 +9,28 @@ from mypy_django_plugin import helpers
def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None:
if 'objects' in ctx.cls.info.names:
return
api = cast(SemanticAnalyzerPass2, ctx.api)
# search over mro
objects_sym = ctx.cls.info.get('objects')
if objects_sym is not None:
return None
metaclass_node = ctx.cls.info.names.get('Meta')
if metaclass_node is not None:
for stmt in metaclass_node.node.defn.defs.body:
# only direct Meta class
metaclass_sym = ctx.cls.info.names.get('Meta')
# skip if abstract
if metaclass_sym is not None:
for stmt in metaclass_sym.node.defn.defs.body:
if (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1
and stmt.lvalues[0].name == 'abstract'):
is_abstract = api.parse_bool(stmt.rvalue)
is_abstract = ctx.api.parse_bool(stmt.rvalue)
if is_abstract:
return
return None
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, args=[Instance(ctx.cls.info, [])])
new_objects_node = helpers.create_new_symtable_node('objects', MDEF, instance=typ)
ctx.cls.info.names['objects'] = new_objects_node
api = cast(SemanticAnalyzerPass2, ctx.api)
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME,
args=[Instance(ctx.cls.info, [])])
if not typ:
return None
ctx.cls.info.names['objects'] = helpers.create_new_symtable_node('objects',
kind=MDEF,
instance=typ)

View File

@@ -1,14 +1,11 @@
from mypy.plugin import FunctionContext
from mypy.types import Type
from mypy_django_plugin import helpers
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
signature = helpers.get_call_signature_or_none(ctx)
if signature is None:
if 'base_field' not in ctx.arg_names:
return ctx.default_return_type
_, base_field_arg_type = signature['base_field']
base_field_arg_type = ctx.arg_types[ctx.arg_names.index('base_field')][0]
return ctx.api.named_generic_type(ctx.context.callee.fullname,
args=[base_field_arg_type.type.names['__get__'].type.ret_type])

View File

@@ -1,23 +1,63 @@
from typing import Optional, cast
import typing
from typing import Optional, cast, Tuple, Any
from django.apps.registry import Apps
from django.conf import Settings
from django.db import models
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, MemberExpr
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, StrExpr
from mypy.plugin import FunctionContext, ClassDefContext
from mypy.types import Type, CallableType, Instance, AnyType
from mypy_django_plugin import helpers
def extract_to_value_type(ctx: FunctionContext) -> Optional[Instance]:
signature = helpers.get_call_signature_or_none(ctx)
if signature is None or 'to' not in signature:
return None
def get_instance_type_for_class(klass: typing.Type[models.Model],
api: TypeChecker) -> Optional[Instance]:
model_qualname = helpers.get_obj_type_name(klass)
module_name, _, class_name = model_qualname.rpartition('.')
module = api.modules.get(module_name)
if not module or class_name not in module.names:
return
arg, arg_type = signature['to']
if not isinstance(arg_type, CallableType):
return None
sym = module.names[class_name]
return Instance(sym.node, [])
return arg_type.ret_type
def extract_to_value_type(ctx: FunctionContext,
apps: Optional[Apps]) -> Tuple[Optional[Instance], bool]:
api = cast(TypeChecker, ctx.api)
if 'to' not in ctx.arg_names:
return None, False
arg = ctx.args[ctx.arg_names.index('to')][0]
arg_type = ctx.arg_types[ctx.arg_names.index('to')][0]
if isinstance(arg_type, CallableType):
return arg_type.ret_type, False
if apps:
if isinstance(arg, StrExpr):
arg_value = arg.value
if '.' not in arg_value:
return None, False
app_label, modelname = arg_value.lower().split('.')
try:
model_cls = apps.get_model(app_label, modelname)
except LookupError:
# no model class found
return None, False
try:
instance = get_instance_type_for_class(model_cls, api=api)
if not instance:
return None, False
return instance, True
except AssertionError:
pass
return None, False
def extract_related_name_value(ctx: FunctionContext) -> str:
@@ -30,43 +70,56 @@ def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instanc
instance=new_member_instance)
def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
class ForeignKeyHook(object):
def __init__(self, settings: Settings, apps: Apps):
self.settings = settings
self.apps = apps
def __call__(self, ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
outer_class_info = api.tscope.classes[-1]
if 'related_name' not in ctx.context.arg_names:
return ctx.default_return_type
referred_to = extract_to_value_type(ctx)
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps)
if not referred_to:
return ctx.default_return_type
if 'related_name' in ctx.context.arg_names:
related_name = extract_related_name_value(ctx)
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
args=[Instance(outer_class_info, [])])
if isinstance(referred_to, AnyType):
# referred_to defined as string, which is unsupported for now
return ctx.default_return_type
add_new_class_member(referred_to.type,
related_name, queryset_type)
if is_string_based:
return referred_to
return ctx.default_return_type
def set_related_name_instance_for_onetoonefield(ctx: FunctionContext) -> Type:
class OneToOneFieldHook(object):
def __init__(self, settings: Optional[Settings], apps: Optional[Apps]):
self.settings = settings
self.apps = apps
def __call__(self, ctx: FunctionContext) -> Type:
if 'related_name' not in ctx.context.arg_names:
return ctx.default_return_type
referred_to = extract_to_value_type(ctx)
referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps)
if referred_to is None:
return ctx.default_return_type
if 'related_name' in ctx.context.arg_names:
related_name = extract_related_name_value(ctx)
outer_class_info = ctx.api.tscope.classes[-1]
api = cast(TypeChecker, ctx.api)
add_new_class_member(referred_to.type, related_name,
new_member_instance=api.named_type(outer_class_info.fullname()))
new_member_instance=Instance(outer_class_info, []))
if is_string_based:
return referred_to
return ctx.default_return_type

View File

@@ -1,7 +1,7 @@
from typing import cast, Any
from typing import cast
from django.conf import Settings
from mypy.nodes import MDEF, TypeInfo, SymbolTable
from mypy.nodes import MDEF
from mypy.plugin import ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Instance, AnyType, TypeOfAny
@@ -9,24 +9,22 @@ from mypy.types import Instance, AnyType, TypeOfAny
from mypy_django_plugin import helpers
def get_obj_type_name(value: Any) -> str:
return type(value).__module__ + '.' + type(value).__qualname__
class DjangoConfSettingsInitializerHook(object):
def __init__(self, settings: Settings):
self.settings = settings
def __call__(self, ctx: ClassDefContext) -> None:
api = cast(SemanticAnalyzerPass2, ctx.api)
if self.settings:
for name, value in self.settings.__dict__.items():
if name.isupper():
if value is None:
# TODO: change to Optional[Any] later
ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF,
instance=api.builtin_type('builtins.object'))
continue
type_fullname = get_obj_type_name(value)
type_fullname = helpers.get_obj_type_name(type(value))
sym = api.lookup_fully_qualified_or_none(type_fullname)
if sym is not None:
args = len(sym.node.type_vars) * [AnyType(TypeOfAny.from_omitted_generics)]

View File

@@ -5,3 +5,4 @@ python_files = test*.py
addopts =
--tb=native
-v
-s

View File

@@ -147,6 +147,12 @@ class DjangoDataDrivenTestCase(DataDrivenTestCase):
self.old_cwd = os.getcwd()
self.tmpdir = tempfile.TemporaryDirectory(prefix='mypy-test-')
tmpdir_root = os.path.join(self.tmpdir.name, 'tmp')
new_files = []
for path, contents in self.files:
new_files.append((path, contents.replace('<TMP>', tmpdir_root)))
self.files = new_files
os.chdir(self.tmpdir.name)
os.mkdir(test_temp_dir)
@@ -179,7 +185,7 @@ class DjangoDataDrivenTestCase(DataDrivenTestCase):
self.clean_up.append((True, d))
self.clean_up.append((False, path))
sys.path.insert(0, os.path.join(self.tmpdir.name, 'tmp'))
sys.path.insert(0, tmpdir_root)
def teardown(self):
if hasattr(self, 'old_environ'):

253
test/helpers.py Normal file
View File

@@ -0,0 +1,253 @@
import inspect
import os
import re
from typing import List, Callable, Optional, Tuple
import pytest # type: ignore # no pytest in typeshed
skip = pytest.mark.skip
# AssertStringArraysEqual displays special line alignment helper messages if
# the first different line has at least this many characters,
MIN_LINE_LENGTH_FOR_ALIGNMENT = 5
class TypecheckAssertionError(AssertionError):
def __init__(self, error_message: str, lineno: int):
self.error_message = error_message
self.lineno = lineno
def first_line(self):
return self.__class__.__name__ + '(message="Invalid output")'
def __str__(self):
return self.error_message
def _clean_up(a: List[str]) -> List[str]:
"""Remove common directory prefix from all strings in a.
This uses a naive string replace; it seems to work well enough. Also
remove trailing carriage returns.
"""
res = []
for s in a:
prefix = os.sep
ss = s
for p in prefix, prefix.replace(os.sep, '/'):
if p != '/' and p != '//' and p != '\\' and p != '\\\\':
ss = ss.replace(p, '')
# Ignore spaces at end of line.
ss = re.sub(' +$', '', ss)
res.append(re.sub('\\r$', '', ss))
return res
def _num_skipped_prefix_lines(a1: List[str], a2: List[str]) -> int:
num_eq = 0
while num_eq < min(len(a1), len(a2)) and a1[num_eq] == a2[num_eq]:
num_eq += 1
return max(0, num_eq - 4)
def _num_skipped_suffix_lines(a1: List[str], a2: List[str]) -> int:
num_eq = 0
while (num_eq < min(len(a1), len(a2))
and a1[-num_eq - 1] == a2[-num_eq - 1]):
num_eq += 1
return max(0, num_eq - 4)
def _add_aligned_message(s1: str, s2: str, error_message: str) -> str:
"""Align s1 and s2 so that the their first difference is highlighted.
For example, if s1 is 'foobar' and s2 is 'fobar', display the
following lines:
E: foobar
A: fobar
^
If s1 and s2 are long, only display a fragment of the strings around the
first difference. If s1 is very short, do nothing.
"""
# Seeing what went wrong is trivial even without alignment if the expected
# string is very short. In this case do nothing to simplify output.
if len(s1) < 4:
return error_message
maxw = 72 # Maximum number of characters shown
error_message += 'Alignment of first line difference:\n'
# sys.stderr.write('Alignment of first line difference:\n')
trunc = False
while s1[:30] == s2[:30]:
s1 = s1[10:]
s2 = s2[10:]
trunc = True
if trunc:
s1 = '...' + s1
s2 = '...' + s2
max_len = max(len(s1), len(s2))
extra = ''
if max_len > maxw:
extra = '...'
# Write a chunk of both lines, aligned.
error_message += ' E: {}{}\n'.format(s1[:maxw], extra)
# sys.stderr.write(' E: {}{}\n'.format(s1[:maxw], extra))
error_message += ' A: {}{}\n'.format(s2[:maxw], extra)
# sys.stderr.write(' A: {}{}\n'.format(s2[:maxw], extra))
# Write an indicator character under the different columns.
error_message += ' '
# sys.stderr.write(' ')
for j in range(min(maxw, max(len(s1), len(s2)))):
if s1[j:j + 1] != s2[j:j + 1]:
error_message += '^'
# sys.stderr.write('^') # Difference
break
else:
error_message += ' '
# sys.stderr.write(' ') # Equal
error_message += '\n'
return error_message
# sys.stderr.write('\n')
def assert_string_arrays_equal(expected: List[str], actual: List[str]) -> None:
"""Assert that two string arrays are equal.
Display any differences in a human-readable form.
"""
actual = _clean_up(actual)
error_message = ''
if actual != expected:
num_skip_start = _num_skipped_prefix_lines(expected, actual)
num_skip_end = _num_skipped_suffix_lines(expected, actual)
error_message += 'Expected:\n'
# sys.stderr.write('Expected:\n')
# If omit some lines at the beginning, indicate it by displaying a line
# with '...'.
if num_skip_start > 0:
error_message += ' ...\n'
# sys.stderr.write(' ...\n')
# Keep track of the first different line.
first_diff = -1
# Display only this many first characters of identical lines.
width = 75
for i in range(num_skip_start, len(expected) - num_skip_end):
if i >= len(actual) or expected[i] != actual[i]:
if first_diff < 0:
first_diff = i
error_message += ' {:<45} (diff)'.format(expected[i])
# sys.stderr.write(' {:<45} (diff)'.format(expected[i]))
else:
e = expected[i]
error_message += ' ' + e[:width]
# sys.stderr.write(' ' + e[:width])
if len(e) > width:
error_message += '...'
# sys.stderr.write('...')
error_message += '\n'
# sys.stderr.write('\n')
if num_skip_end > 0:
error_message += ' ...\n'
# sys.stderr.write(' ...\n')
error_message += 'Actual:\n'
# sys.stderr.write('Actual:\n')
if num_skip_start > 0:
error_message += ' ...\n'
# sys.stderr.write(' ...\n')
for j in range(num_skip_start, len(actual) - num_skip_end):
if j >= len(expected) or expected[j] != actual[j]:
error_message += ' {:<45} (diff)'.format(actual[j])
# sys.stderr.write(' {:<45} (diff)'.format(actual[j]))
else:
a = actual[j]
error_message += ' ' + a[:width]
# sys.stderr.write(' ' + a[:width])
if len(a) > width:
error_message += '...'
# sys.stderr.write('...')
error_message += '\n'
# sys.stderr.write('\n')
if actual == []:
error_message += ' (empty)\n'
# sys.stderr.write(' (empty)\n')
if num_skip_end > 0:
error_message += ' ...\n'
# sys.stderr.write(' ...\n')
error_message += '\n'
# sys.stderr.write('\n')
if first_diff >= 0 and first_diff < len(actual) and (
len(expected[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT
or len(actual[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT):
# Display message that helps visualize the differences between two
# long lines.
error_message = _add_aligned_message(expected[first_diff], actual[first_diff],
error_message)
first_failure = expected[first_diff]
if first_failure:
lineno = int(first_failure.split(' ')[0].strip(':').split(':')[1])
raise TypecheckAssertionError(error_message=f'Invalid output: \n{error_message}',
lineno=lineno)
def build_output_line(fname: str, lnum: int, severity: str, message: str, col=None) -> str:
if col is None:
return f'{fname}:{lnum + 1}: {severity}: {message}'
else:
return f'{fname}:{lnum + 1}:{col}: {severity}: {message}'
def expand_errors(input_lines: List[str], fname: str) -> List[str]:
"""Transform comments such as '# E: message' or
'# E:3: message' in input.
The result is lines like 'fnam:line: error: message'.
"""
output_lines = []
for lnum, line in enumerate(input_lines):
# The first in the split things isn't a comment
for possible_err_comment in line.split(' # ')[1:]:
m = re.search(
r'^([ENW]):((?P<col>\d+):)? (?P<message>.*)$',
possible_err_comment.strip())
if m:
if m.group(1) == 'E':
severity = 'error'
elif m.group(1) == 'N':
severity = 'note'
elif m.group(1) == 'W':
severity = 'warning'
col = m.group('col')
output_lines.append(build_output_line(fname, lnum, severity,
message=m.group("message"),
col=col))
return output_lines
def get_func_first_lnum(attr: Callable[..., None]) -> Optional[Tuple[int, List[str]]]:
lines, _ = inspect.getsourcelines(attr)
for lnum, line in enumerate(lines):
no_space_line = line.strip()
if f'def {attr.__name__}' in no_space_line:
return lnum, lines[lnum + 1:]
raise ValueError(f'No line "def {attr.__name__}" found')

299
test/pytest_plugin.py Normal file
View File

@@ -0,0 +1,299 @@
import dataclasses
import inspect
import os
import sys
import tempfile
import textwrap
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator, Any, Optional, cast, List, Type, Callable, Dict
import pytest
from _pytest._code.code import ReprFileLocation, ReprEntry, ExceptionInfo
from decorator import decorate
from mypy import api as mypy_api
from test import vistir
from test.helpers import assert_string_arrays_equal, TypecheckAssertionError, expand_errors, get_func_first_lnum
def reveal_type(obj: Any) -> None:
# noop method, just to get rid of "method is not resolved" errors
pass
def output(output_lines: str):
def decor(func: Callable[..., None]):
func.out = output_lines
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return decorate(func, wrapper)
return decor
def get_class_that_defined_method(meth) -> Type['MypyTypecheckTestCase']:
if inspect.ismethod(meth):
for cls in inspect.getmro(meth.__self__.__class__):
if cls.__dict__.get(meth.__name__) is meth:
return cls
meth = meth.__func__ # fallback to __qualname__ parsing
if inspect.isfunction(meth):
cls = getattr(inspect.getmodule(meth),
meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0])
if issubclass(cls, MypyTypecheckTestCase):
return cls
return getattr(meth, '__objclass__', None) # handle special descriptor objects
def file(filename: str, make_parent_packages=False):
def decor(func: Callable[..., None]):
func.filename = filename
func.make_parent_packages = make_parent_packages
return func
return decor
def env(**environ):
def decor(func: Callable[..., None]):
func.env = environ
return func
return decor
@dataclasses.dataclass
class CreateFile:
sources: str
make_parent_packages: bool = False
class MypyTypecheckMeta(type):
def __new__(mcs, name, bases, attrs):
cls = super().__new__(mcs, name, bases, attrs)
cls.files: Dict[str, CreateFile] = {}
for name, attr in attrs.items():
if inspect.isfunction(attr):
filename = getattr(attr, 'filename', None)
if not filename:
continue
make_parent_packages = getattr(attr, 'make_parent_packages', False)
sources = textwrap.dedent(''.join(get_func_first_lnum(attr)[1]))
if sources.strip() == 'pass':
sources = ''
cls.files[filename] = CreateFile(sources, make_parent_packages)
return cls
class MypyTypecheckTestCase(metaclass=MypyTypecheckMeta):
files = None
def ini_file(self) -> str:
return """
[mypy]
"""
def _get_ini_file_contents(self) -> Optional[str]:
raw_ini_file = self.ini_file()
if not raw_ini_file:
return raw_ini_file
return raw_ini_file.strip() + '\n'
class TraceLastReprEntry(ReprEntry):
def toterminal(self, tw):
self.reprfileloc.toterminal(tw)
for line in self.lines:
red = line.startswith("E ")
tw.line(line, bold=True, red=red)
return
def fname_to_module(fpath: Path, root_path: Path) -> Optional[str]:
try:
relpath = fpath.relative_to(root_path).with_suffix('')
return str(relpath).replace(os.sep, '.')
except ValueError:
return None
class MypyTypecheckItem(pytest.Item):
root_directory = '/run/testdata'
def __init__(self,
name: str,
parent: 'MypyTestsCollector',
klass: Type[MypyTypecheckTestCase],
source_code: str,
first_lineno: int,
ini_file_contents: Optional[str] = None,
expected_output_lines: Optional[List[str]] = None,
files: Optional[Dict[str, CreateFile]] = None,
custom_environment: Optional[Dict[str, Any]] = None):
super().__init__(name=name, parent=parent)
self.klass = klass
self.source_code = source_code
self.first_lineno = first_lineno
self.ini_file_contents = ini_file_contents
self.expected_output_lines = expected_output_lines
self.files = files
self.custom_environment = custom_environment
@contextmanager
def temp_directory(self) -> Path:
with tempfile.TemporaryDirectory(prefix='mypy-pytest-',
dir=self.root_directory) as tmpdir_name:
yield Path(self.root_directory) / tmpdir_name
def runtest(self):
with self.temp_directory() as tmpdir_path:
if not self.source_code:
return
if self.ini_file_contents:
mypy_ini_fpath = tmpdir_path / 'mypy.ini'
mypy_ini_fpath.write_text(self.ini_file_contents)
test_specific_modules = []
for fname, create_file in self.files.items():
fpath = tmpdir_path / fname
if create_file.make_parent_packages:
fpath.parent.mkdir(parents=True, exist_ok=True)
for parent in fpath.parents:
try:
parent.relative_to(tmpdir_path)
if parent != tmpdir_path:
parent_init_file = parent / '__init__.py'
parent_init_file.write_text('')
test_specific_modules.append(fname_to_module(parent,
root_path=tmpdir_path))
except ValueError:
break
fpath.write_text(create_file.sources)
test_specific_modules.append(fname_to_module(fpath,
root_path=tmpdir_path))
with vistir.temp_environ(), vistir.temp_path():
for key, val in (self.custom_environment or {}).items():
os.environ[key] = val
sys.path.insert(0, str(tmpdir_path))
mypy_cmd_options = self.prepare_mypy_cmd_options(config_file_path=mypy_ini_fpath)
main_fpath = tmpdir_path / 'main.py'
main_fpath.write_text(self.source_code)
mypy_cmd_options.append(str(main_fpath))
stdout, _, _ = mypy_api.run(mypy_cmd_options)
output_lines = []
for line in stdout.splitlines():
if ':' not in line:
continue
out_fpath, res_line = line.split(':', 1)
line = os.path.relpath(out_fpath, start=tmpdir_path) + ':' + res_line
output_lines.append(line.strip().replace('.py', ''))
for module in test_specific_modules:
if module in sys.modules:
del sys.modules[module]
raise ValueError
assert_string_arrays_equal(expected=self.expected_output_lines,
actual=output_lines)
def prepare_mypy_cmd_options(self, config_file_path: Path) -> List[str]:
mypy_cmd_options = [
'--show-traceback',
'--no-silence-site-packages'
]
python_version = '.'.join([str(part) for part in sys.version_info[:2]])
mypy_cmd_options.append(f'--python-version={python_version}')
if self.ini_file_contents:
mypy_cmd_options.append(f'--config-file={config_file_path}')
return mypy_cmd_options
def repr_failure(self, excinfo: ExceptionInfo) -> str:
if excinfo.errisinstance(SystemExit):
# We assume that before doing exit() (which raises SystemExit) we've printed
# enough context about what happened so that a stack trace is not useful.
# In particular, uncaught exceptions during semantic analysis or type checking
# call exit() and they already print out a stack trace.
return excinfo.exconly(tryshort=True)
elif excinfo.errisinstance(TypecheckAssertionError):
# with traceback removed
exception_repr = excinfo.getrepr(style='short')
exception_repr.reprcrash.message = ''
repr_file_location = ReprFileLocation(path=inspect.getfile(self.klass),
lineno=self.first_lineno + excinfo.value.lineno,
message='')
repr_tb_entry = TraceLastReprEntry(filelocrepr=repr_file_location,
lines=exception_repr.reprtraceback.reprentries[-1].lines[1:],
style='short',
reprlocals=None,
reprfuncargs=None)
exception_repr.reprtraceback.reprentries = [repr_tb_entry]
return exception_repr
else:
return super().repr_failure(excinfo, style='short')
def reportinfo(self):
return self.fspath, None, get_class_qualname(self.klass) + '::' + self.name
def get_class_qualname(klass: type) -> str:
return klass.__module__ + '.' + klass.__name__
def extract_test_output(attr: Callable[..., None]) -> List[str]:
out_data: str = getattr(attr, 'out', None)
out_lines = []
if out_data:
for line in out_data.split('\n'):
line = line.strip()
out_lines.append(line)
return out_lines
class MypyTestsCollector(pytest.Class):
def get_ini_file_contents(self, contents: str) -> str:
return contents.strip() + '\n'
def collect(self) -> Iterator[pytest.Item]:
current_testcase = cast(MypyTypecheckTestCase, self.obj())
ini_file_contents = self.get_ini_file_contents(current_testcase.ini_file())
for attr_name in dir(current_testcase):
if attr_name.startswith('_test_'):
attr = getattr(self.obj, attr_name)
if inspect.isfunction(attr):
first_line_lnum, source_lines = get_func_first_lnum(attr)
func_first_line_in_file = inspect.getsourcelines(attr)[1] + first_line_lnum
output_from_decorator = extract_test_output(attr)
output_from_comments = expand_errors(source_lines, 'main')
custom_env = getattr(attr, 'env', None)
main_source_code = textwrap.dedent(''.join(source_lines))
yield MypyTypecheckItem(name=attr_name,
parent=self,
klass=current_testcase.__class__,
source_code=main_source_code,
first_lineno=func_first_line_in_file,
ini_file_contents=ini_file_contents,
expected_output_lines=output_from_comments
+ output_from_decorator,
files=current_testcase.__class__.files,
custom_environment=custom_env)
def pytest_pycollect_makeitem(collector: Any, name: str, obj: Any) -> Optional[MypyTestsCollector]:
# Only classes derived from DataSuite contain test cases, not the DataSuite class itself
if (isinstance(obj, type)
and issubclass(obj, MypyTypecheckTestCase)
and obj is not MypyTypecheckTestCase):
# Non-None result means this obj is a test case.
# The collect method of the returned DataSuiteCollector instance will be called later,
# with self.obj being obj.
return MypyTestsCollector(name, parent=collector)

View File

View File

@@ -0,0 +1,9 @@
from test.pytest_plugin import MypyTypecheckTestCase
class BaseDjangoPluginTestCase(MypyTypecheckTestCase):
def ini_file(self):
return """
[mypy]
plugins = mypy_django_plugin.main
"""

View File

@@ -0,0 +1,21 @@
from test.pytest_plugin import reveal_type
from test.pytest_tests.base import BaseDjangoPluginTestCase
class TestBasicModelFields(BaseDjangoPluginTestCase):
def test_model_field_classes_present_as_primitives(self):
from django.db import models
class User(models.Model):
id = models.AutoField(primary_key=True)
small_int = models.SmallIntegerField()
name = models.CharField(max_length=255)
slug = models.SlugField(max_length=255)
text = models.TextField()
user = User()
reveal_type(user.id) # E: Revealed type is 'builtins.int'
reveal_type(user.small_int) # E: Revealed type is 'builtins.int'
reveal_type(user.name) # E: Revealed type is 'builtins.str'
reveal_type(user.slug) # E: Revealed type is 'builtins.str'
reveal_type(user.text) # E: Revealed type is 'builtins.str'

View File

@@ -1,16 +1,9 @@
from test.pytest_plugin import MypyTypecheckTestCase, reveal_type
from test.pytest_plugin import reveal_type
from test.pytest_tests.base import BaseDjangoPluginTestCase
class BaseDjangoPluginTestCase(MypyTypecheckTestCase):
def ini_file(self):
return """
[mypy]
plugins = mypy_django_plugin.main
"""
class MyTestCase(BaseDjangoPluginTestCase):
def check_foreign_key_field(self):
class TestForeignKey(BaseDjangoPluginTestCase):
def test_foreign_key_field(self):
from django.db import models
class Publisher(models.Model):
@@ -26,7 +19,7 @@ class MyTestCase(BaseDjangoPluginTestCase):
publisher = Publisher()
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
def check_every_foreign_key_creates_field_name_with_appended_id(self):
def test_every_foreign_key_creates_field_name_with_appended_id(self):
from django.db import models
class Publisher(models.Model):
@@ -39,7 +32,7 @@ class MyTestCase(BaseDjangoPluginTestCase):
book = Book()
reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
def check_foreign_key_different_order_of_params(self):
def test_foreign_key_different_order_of_params(self):
from django.db import models
class Publisher(models.Model):
@@ -54,3 +47,43 @@ class MyTestCase(BaseDjangoPluginTestCase):
publisher = Publisher()
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
class TestOneToOneField(BaseDjangoPluginTestCase):
def test_onetoone_field(self):
from django.db import models
class User(models.Model):
pass
class Profile(models.Model):
user = models.OneToOneField(to=User, on_delete=models.CASCADE, related_name='profile')
profile = Profile()
reveal_type(profile.user) # E: Revealed type is 'main.User*'
user = User()
reveal_type(user.profile) # E: Revealed type is 'main.Profile'
def test_onetoone_field_with_underscore_id(self):
from django.db import models
class User(models.Model):
pass
class Profile(models.Model):
user = models.OneToOneField(to=User, on_delete=models.CASCADE, related_name='profile')
profile = Profile()
reveal_type(profile.user_id) # E: Revealed type is 'builtins.int'
def test_parameter_to_keyword_may_be_absent(self):
from django.db import models
class User(models.Model):
pass
class Profile(models.Model):
user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile')
reveal_type(User().profile) # E: Revealed type is 'main.Profile'

View File

@@ -0,0 +1,28 @@
from test.pytest_plugin import reveal_type, output
from test.pytest_tests.base import BaseDjangoPluginTestCase
class TestObjectsQueryset(BaseDjangoPluginTestCase):
def test_every_model_has_objects_queryset_available(self):
from django.db import models
class User(models.Model):
pass
reveal_type(User.objects) # E: Revealed type is 'django.db.models.query.QuerySet[main.User]'
@output("""
main:10: error: Revealed type is 'Any'
main:10: error: "Type[ModelMixin]" has no attribute "objects"
""")
def test_objects_get_returns_model_instance(self):
from django.db import models
class ModelMixin(models.Model):
class Meta:
abstract = True
class User(ModelMixin):
pass
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'

View File

@@ -0,0 +1,37 @@
from test.pytest_plugin import reveal_type, file, env
from test.pytest_tests.base import BaseDjangoPluginTestCase
class TestParseSettingsFromFile(BaseDjangoPluginTestCase):
@env(DJANGO_SETTINGS_MODULE='mysettings')
def test_case(self):
from django.conf import settings
reveal_type(settings.ROOT_DIR) # E: Revealed type is 'builtins.str'
reveal_type(settings.OBJ) # E: Revealed type is 'django.utils.functional.LazyObject'
reveal_type(settings.NUMBERS) # E: Revealed type is 'builtins.list[Any]'
reveal_type(settings.DICT) # E: Revealed type is 'builtins.dict[Any, Any]'
@file('mysettings.py')
def mysettings_py_file(self):
SECRET_KEY = 112233
ROOT_DIR = '/etc'
NUMBERS = ['one', 'two']
DICT = {} # type: ignore
from django.utils.functional import LazyObject
OBJ = LazyObject()
class TestSettingInitializableToNone(BaseDjangoPluginTestCase):
@env(DJANGO_SETTINGS_MODULE='mysettings')
def test_case(self):
from django.conf import settings
reveal_type(settings.NONE_SETTING) # E: Revealed type is 'builtins.object'
@file('mysettings.py')
def mysettings_py_file(self):
SECRET_KEY = 112233
NONE_SETTING = None

View File

@@ -0,0 +1,27 @@
from test.pytest_plugin import reveal_type
from test.pytest_tests.base import BaseDjangoPluginTestCase
class TestArrayField(BaseDjangoPluginTestCase):
def test_descriptor_access(self):
from django.db import models
from django.contrib.postgres.fields import ArrayField
class User(models.Model):
array = ArrayField(base_field=models.Field())
user = User()
reveal_type(user.array) # E: Revealed type is 'builtins.list[Any]'
def test_base_field_parsed_into_generic_attribute(self):
from django.db import models
from django.contrib.postgres.fields import ArrayField
class User(models.Model):
members = ArrayField(base_field=models.IntegerField())
members_as_text = ArrayField(base_field=models.CharField(max_length=255))
user = User()
reveal_type(user.members) # E: Revealed type is 'builtins.list[builtins.int*]'
reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list[builtins.str*]'

View File

@@ -0,0 +1,74 @@
from test.pytest_plugin import file, reveal_type, env
from test.pytest_tests.base import BaseDjangoPluginTestCase
class TestForeignKey(BaseDjangoPluginTestCase):
@env(DJANGO_SETTINGS_MODULE='mysettings')
def _test_to_parameter_could_be_specified_as_string(self):
from apps.myapp.models import Publisher
publisher = Publisher()
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[apps.myapp2.models.Book]'
# @env(DJANGO_SETTINGS_MODULE='mysettings')
# def _test_creates_underscore_id_attr(self):
# from apps.myapp2.models import Book
#
# book = Book()
# reveal_type(book.publisher) # E: Revealed type is 'apps.myapp.models.Publisher'
# reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
@file('mysettings.py')
def mysettings(self):
SECRET_KEY = '112233'
ROOT_DIR = '<TMP>'
APPS_DIR = '<TMP>/apps'
INSTALLED_APPS = ('apps.myapp', 'apps.myapp2')
@file('apps/myapp/models.py', make_parent_packages=True)
def apps_myapp_models(self):
from django.db import models
class Publisher(models.Model):
pass
@file('apps/myapp2/models.py', make_parent_packages=True)
def apps_myapp2_models(self):
from django.db import models
class Book(models.Model):
publisher = models.ForeignKey(to='myapp.Publisher', on_delete=models.CASCADE,
related_name='books')
class TestOneToOneField(BaseDjangoPluginTestCase):
@env(DJANGO_SETTINGS_MODULE='mysettings')
def test_to_parameter_could_be_specified_as_string(self):
from apps.myapp.models import User
user = User()
reveal_type(user.profile) # E: Revealed type is 'apps.myapp2.models.Profile'
@file('mysettings.py')
def mysettings(self):
SECRET_KEY = '112233'
ROOT_DIR = '<TMP>'
APPS_DIR = '<TMP>/apps'
INSTALLED_APPS = ('apps.myapp', 'apps.myapp2')
@file('apps/myapp/models.py', make_parent_packages=True)
def apps_myapp_models(self):
from django.db import models
class User(models.Model):
pass
@file('apps/myapp2/models.py', make_parent_packages=True)
def apps_myapp2_models(self):
from django.db import models
class Profile(models.Model):
user = models.OneToOneField(to='myapp.User', on_delete=models.CASCADE,
related_name='profile')

View File

@@ -14,11 +14,14 @@ MYPY_INI_PATH = ROOT_DIR / 'test' / 'plugins.ini'
class DjangoTestSuite(DataSuite):
files = [
'check-objects-queryset.test',
'check-model-fields.test',
'check-postgres-fields.test',
'check-model-relations.test',
'check-parse-settings.test'
# 'check-objects-queryset.test',
# 'check-model-fields.test',
# 'check-postgres-fields.test',
# 'check-model-relations.test',
# 'check-parse-settings.test',
# 'check-to-attr-as-string-one-to-one-field.test',
'check-to-attr-as-string-foreign-key.test',
# 'check-foreign-key-as-string-creates-underscore-id-attr.test'
]
data_prefix = str(TEST_DATA_DIR)

43
test/vistir.py Normal file
View File

@@ -0,0 +1,43 @@
# Borrowed from Pew.
# See https://github.com/berdario/pew/blob/master/pew/_utils.py#L82
import os
import sys
from pathlib import Path
from decorator import contextmanager
@contextmanager
def temp_environ():
"""Allow the ability to set os.environ temporarily"""
environ = dict(os.environ)
try:
yield
finally:
os.environ.clear()
os.environ.update(environ)
@contextmanager
def temp_path():
"""A context manager which allows the ability to set sys.path temporarily"""
path = [p for p in sys.path]
try:
yield
finally:
sys.path = [p for p in path]
@contextmanager
def cd(path):
"""Context manager to temporarily change working directories"""
if not path:
return
prev_cwd = Path.cwd().as_posix()
if isinstance(path, Path):
path = path.as_posix()
os.chdir(str(path))
try:
yield
finally:
os.chdir(prev_cwd)