From f59cfe63713eff8043c1c9a29cf6ba435dcad9d6 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Mon, 26 Nov 2018 23:58:34 +0300 Subject: [PATCH] latest changes --- .gitignore | 3 +- conftest.py | 2 +- django-stubs/apps/__init__.pyi | 7 + django-stubs/apps/config.pyi | 26 ++ django-stubs/apps/registry.pyi | 62 ++++ django-stubs/db/__init__.pyi | 18 +- django-stubs/db/models/__init__.pyi | 20 +- django-stubs/db/models/expressions.pyi | 49 ++- django-stubs/db/models/fields/__init__.pyi | 39 ++- django-stubs/db/models/fields/related.pyi | 9 + django-stubs/db/models/manager.pyi | 140 ++++++++ django-stubs/urls/__init__.pyi | 7 +- django-stubs/utils/baseconv.pyi | 30 ++ mypy_django_plugin/helpers.py | 18 +- mypy_django_plugin/main.py | 59 +++- mypy_django_plugin/monkeypatch.py | 112 +++++++ .../plugins/objects_queryset.py | 31 +- mypy_django_plugin/plugins/postgres_fields.py | 7 +- mypy_django_plugin/plugins/related_fields.py | 141 ++++++--- mypy_django_plugin/plugins/setup_settings.py | 34 +- pytest.ini | 3 +- test/data.py | 8 +- test/helpers.py | 253 +++++++++++++++ test/pytest_plugin.py | 299 ++++++++++++++++++ test/pytest_tests/__init__.py | 0 test/pytest_tests/base.py | 9 + test/pytest_tests/test_model_fields.py | 21 ++ test/pytest_tests/test_model_relations.py | 61 +++- test/pytest_tests/test_objects_queryset.py | 28 ++ test/pytest_tests/test_parse_settings.py | 37 +++ test/pytest_tests/test_postgres_fields.py | 27 ++ test/pytest_tests/test_to_attr_as_string.py | 74 +++++ test/testdjango.py | 13 +- test/vistir.py | 43 +++ 34 files changed, 1558 insertions(+), 132 deletions(-) create mode 100644 django-stubs/apps/__init__.pyi create mode 100644 django-stubs/apps/config.pyi create mode 100644 django-stubs/apps/registry.pyi create mode 100644 django-stubs/db/models/manager.pyi create mode 100644 django-stubs/utils/baseconv.pyi create mode 100644 mypy_django_plugin/monkeypatch.py create mode 100644 test/helpers.py create mode 100644 test/pytest_plugin.py create mode 100644 test/pytest_tests/__init__.py create mode 100644 test/pytest_tests/base.py create mode 100644 test/pytest_tests/test_model_fields.py create mode 100644 test/pytest_tests/test_objects_queryset.py create mode 100644 test/pytest_tests/test_parse_settings.py create mode 100644 test/pytest_tests/test_postgres_fields.py create mode 100644 test/pytest_tests/test_to_attr_as_string.py create mode 100644 test/vistir.py diff --git a/.gitignore b/.gitignore index 25e61e2..bdfeb0e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ __pycache__/ out/ /test_sqlite.py /django -.idea/ \ No newline at end of file +.idea/ +.mypy_cache/ \ No newline at end of file diff --git a/conftest.py b/conftest.py index 8b7b61b..04cba73 100644 --- a/conftest.py +++ b/conftest.py @@ -1,3 +1,3 @@ pytest_plugins = [ - 'test.data' + 'test.pytest_plugin' ] \ No newline at end of file diff --git a/django-stubs/apps/__init__.pyi b/django-stubs/apps/__init__.pyi new file mode 100644 index 0000000..9c16065 --- /dev/null +++ b/django-stubs/apps/__init__.pyi @@ -0,0 +1,7 @@ +from .config import ( + AppConfig as AppConfig +) + +from .registry import ( + apps as apps +) diff --git a/django-stubs/apps/config.pyi b/django-stubs/apps/config.pyi new file mode 100644 index 0000000..277238f --- /dev/null +++ b/django-stubs/apps/config.pyi @@ -0,0 +1,26 @@ +from typing import Any, Iterator, Type + +from django.db.models.base import Model + +MODELS_MODULE_NAME: str + +class AppConfig: + name: str = ... + module: Any = ... + apps: None = ... + label: str = ... + verbose_name: str = ... + path: str = ... + models_module: None = ... + models: None = ... + def __init__(self, app_name: str, app_module: None) -> None: ... + @classmethod + def create(cls, entry: str) -> AppConfig: ... + def get_model( + self, model_name: str, require_ready: bool = ... + ) -> Type[Model]: ... + def get_models( + self, include_auto_created: bool = ..., include_swapped: bool = ... + ) -> Iterator[Type[Model]]: ... + def import_models(self) -> None: ... + def ready(self) -> None: ... diff --git a/django-stubs/apps/registry.pyi b/django-stubs/apps/registry.pyi new file mode 100644 index 0000000..28c1426 --- /dev/null +++ b/django-stubs/apps/registry.pyi @@ -0,0 +1,62 @@ +import collections +from typing import Any, Callable, List, Optional, Tuple, Type, Union, Iterable + +from django.apps.config import AppConfig +from django.db.migrations.state import AppConfigStub +from django.db.models.base import Model + +from .config import AppConfig + + +class Apps: + all_models: collections.defaultdict = ... + app_configs: collections.OrderedDict = ... + stored_app_configs: List[Any] = ... + apps_ready: bool = ... + loading: bool = ... + def __init__( + self, + installed_apps: Optional[ + Union[List[AppConfigStub], List[str], Tuple] + ] = ..., + ) -> None: ... + models_ready: bool = ... + ready: bool = ... + def populate( + self, installed_apps: Union[List[AppConfigStub], List[str], Tuple] = ... + ) -> None: ... + def check_apps_ready(self) -> None: ... + def check_models_ready(self) -> None: ... + def get_app_configs(self) -> Iterable[AppConfig]: ... + def get_app_config(self, app_label: str) -> AppConfig: ... + def get_models( + self, include_auto_created: bool = ..., include_swapped: bool = ... + ) -> List[Type[Model]]: ... + def get_model( + self, + app_label: str, + model_name: Optional[str] = ..., + require_ready: bool = ..., + ) -> Type[Model]: ... + def register_model(self, app_label: str, model: Type[Model]) -> None: ... + def is_installed(self, app_name: str) -> bool: ... + def get_containing_app_config( + self, object_name: str + ) -> Optional[AppConfig]: ... + def get_registered_model( + self, app_label: str, model_name: str + ) -> Type[Model]: ... + def get_swappable_settings_name(self, to_string: str) -> Optional[str]: ... + def set_available_apps(self, available: List[str]) -> None: ... + def unset_available_apps(self) -> None: ... + def set_installed_apps( + self, installed: Union[List[str], Tuple[str]] + ) -> None: ... + def unset_installed_apps(self) -> None: ... + def clear_cache(self) -> None: ... + def lazy_model_operation( + self, function: Callable, *model_keys: Any + ) -> None: ... + def do_pending_operations(self, model: Type[Model]) -> None: ... + +apps: Apps diff --git a/django-stubs/db/__init__.pyi b/django-stubs/db/__init__.pyi index 75807c9..904d895 100644 --- a/django-stubs/db/__init__.pyi +++ b/django-stubs/db/__init__.pyi @@ -1,2 +1,18 @@ +from typing import Any + from .utils import (ProgrammingError as ProgrammingError, - IntegrityError as IntegrityError) \ No newline at end of file + IntegrityError as IntegrityError, + OperationalError as OperationalError, + DatabaseError as DatabaseError, + DataError as DataError, + NotSupportedError as NotSupportedError) + +connections: Any +router: Any + +class DefaultConnectionProxy: + def __getattr__(self, item: str) -> Any: ... + def __setattr__(self, name: str, value: Any) -> None: ... + def __delattr__(self, name: str) -> None: ... + +connection: Any diff --git a/django-stubs/db/models/__init__.pyi b/django-stubs/db/models/__init__.pyi index 06b5e66..326b08f 100644 --- a/django-stubs/db/models/__init__.pyi +++ b/django-stubs/db/models/__init__.pyi @@ -3,14 +3,21 @@ from .base import Model as Model from .fields import (AutoField as AutoField, IntegerField as IntegerField, SmallIntegerField as SmallIntegerField, + BigIntegerField as BigIntegerField, CharField as CharField, Field as Field, SlugField as SlugField, TextField as TextField, - BooleanField as BooleanField) + BooleanField as BooleanField, + FileField as FileField, + DateField as DateField, + DateTimeField as DateTimeField, + IPAddressField as IPAddressField, + GenericIPAddressField as GenericIPAddressField) from .fields.related import (ForeignKey as ForeignKey, - OneToOneField as OneToOneField) + OneToOneField as OneToOneField, + ManyToManyField as ManyToManyField) from .deletion import (CASCADE as CASCADE, SET_DEFAULT as SET_DEFAULT, @@ -24,4 +31,11 @@ from .query_utils import Q as Q from .lookups import Lookup as Lookup -from .expressions import F as F +from .expressions import (F as F, + Subquery as Subquery, + Exists as Exists, + OrderBy as OrderBy, + OuterRef as OuterRef) + +from .manager import (BaseManager as BaseManager, + Manager as Manager) diff --git a/django-stubs/db/models/expressions.pyi b/django-stubs/db/models/expressions.pyi index 3b9c362..61ec416 100644 --- a/django-stubs/db/models/expressions.pyi +++ b/django-stubs/db/models/expressions.pyi @@ -3,6 +3,7 @@ from datetime import datetime, timedelta from typing import (Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union) +from django.db.models import QuerySet from django.db.models.fields import Field from django.db.models.lookups import Lookup from django.db.models.sql.compiler import SQLCompiler @@ -180,4 +181,50 @@ class CombinedExpression(SQLiteNumericMixin, Expression): class F(Combinable): - def __init__(self, name: str): ... \ No newline at end of file + name: str + def __init__(self, name: str): ... + def resolve_expression( + self, + query: Any = ..., + allow_joins: bool = ..., + reuse: Optional[Set[str]] = ..., + summarize: bool = ..., + for_save: bool = ..., + ) -> Expression: ... + + +class OuterRef(F): ... + +class Subquery(Expression): + template: str = ... + queryset: QuerySet = ... + extra: Dict[Any, Any] = ... + def __init__( + self, + queryset: QuerySet, + output_field: Optional[Field] = ..., + **extra: Any + ) -> None: ... + +class Exists(Subquery): + extra: Dict[Any, Any] + template: str = ... + negated: bool = ... + def __init__( + self, *args: Any, negated: bool = ..., **kwargs: Any + ) -> None: ... + def __invert__(self) -> Exists: ... + +class OrderBy(BaseExpression): + template: str = ... + nulls_first: bool = ... + nulls_last: bool = ... + descending: bool = ... + expression: Expression = ... + def __init__( + self, + expression: Combinable, + descending: bool = ..., + nulls_first: bool = ..., + nulls_last: bool = ..., + ) -> None: ... diff --git a/django-stubs/db/models/fields/__init__.pyi b/django-stubs/db/models/fields/__init__.pyi index 42374ff..9d9e7f5 100644 --- a/django-stubs/db/models/fields/__init__.pyi +++ b/django-stubs/db/models/fields/__init__.pyi @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from django.db.models.query_utils import RegisterLookupMixin @@ -7,6 +7,7 @@ class Field(RegisterLookupMixin): def __init__(self, primary_key: bool = False, **kwargs): ... + def __get__(self, instance, owner) -> Any: ... @@ -14,8 +15,10 @@ class IntegerField(Field): def __get__(self, instance, owner) -> int: ... -class SmallIntegerField(IntegerField): - pass +class SmallIntegerField(IntegerField): ... + + +class BigIntegerField(IntegerField): ... class AutoField(Field): @@ -26,11 +29,11 @@ class CharField(Field): def __init__(self, max_length: int, **kwargs): ... + def __get__(self, instance, owner) -> str: ... -class SlugField(CharField): - pass +class SlugField(CharField): ... class TextField(Field): @@ -39,3 +42,29 @@ class TextField(Field): class BooleanField(Field): def __get__(self, instance, owner) -> bool: ... + + +class FileField(Field): ... + + +class IPAddressField(Field): ... + + +class GenericIPAddressField(Field): + default_error_messages: Any = ... + unpack_ipv4: Any = ... + protocol: Any = ... + + def __init__( + self, + verbose_name: Optional[Any] = ..., + name: Optional[Any] = ..., + protocol: str = ..., + unpack_ipv4: bool = ..., + *args: Any, + **kwargs: Any + ) -> None: ... + +class DateField(Field): ... + +class DateTimeField(DateField): ... \ No newline at end of file diff --git a/django-stubs/db/models/fields/related.pyi b/django-stubs/db/models/fields/related.pyi index c6a7ac4..376e546 100644 --- a/django-stubs/db/models/fields/related.pyi +++ b/django-stubs/db/models/fields/related.pyi @@ -22,3 +22,12 @@ class OneToOneField(Field, Generic[_T]): related_name: str = ..., **kwargs): ... def __get__(self, instance, owner) -> _T: ... + + +class ManyToManyField(Field, Generic[_T]): + def __init__(self, + to: Union[Type[_T], str], + on_delete: Any, + related_name: str = ..., + **kwargs): ... + def __get__(self, instance, owner) -> _T: ... diff --git a/django-stubs/db/models/manager.pyi b/django-stubs/db/models/manager.pyi new file mode 100644 index 0000000..2e9f903 --- /dev/null +++ b/django-stubs/db/models/manager.pyi @@ -0,0 +1,140 @@ +from collections import OrderedDict +from datetime import datetime +from decimal import Decimal +from typing import Any, Dict, List, Optional, Tuple, Type, Union, TypeVar, Set, Generic, Iterator +from unittest.mock import MagicMock + +from django.db.models import Q +from django.db.models.base import Model +from django.db.models.query import QuerySet, RawQuerySet + +_T = TypeVar('_T', bound=Model) + + +class BaseManager: + creation_counter: int = ... + auto_created: bool = ... + use_in_migrations: bool = ... + def __new__(cls: Type[BaseManager], *args: Any, **kwargs: Any) -> BaseManager: ... + model: Any = ... + name: Any = ... + def __init__(self) -> None: ... + def deconstruct(self) -> Tuple[bool, str, None, Tuple, Dict[str, int]]: ... + def check(self, **kwargs: Any) -> List[Any]: ... + @classmethod + def from_queryset( + cls, queryset_class: Any, class_name: Optional[Any] = ... + ): ... + def contribute_to_class(self, model: Type[Model], name: str) -> None: ... + def db_manager( + self, + using: Optional[str] = ..., + hints: Optional[Dict[str, Model]] = ..., + ) -> Manager: ... + @property + def db(self) -> str: ... + def get_queryset(self) -> QuerySet: ... + def all(self) -> QuerySet: ... + def __eq__(self, other: Optional[Any]) -> bool: ... + def __hash__(self): ... + +class Manager(Generic[_T]): + def exists(self) -> bool: ... + def explain( + self, *, format: Optional[Any] = ..., **options: Any + ) -> str: ... + def raw( + self, + raw_query: str, + params: Optional[ + Union[ + Dict[str, str], + List[datetime], + List[Decimal], + List[str], + Set[str], + Tuple[int], + ] + ] = ..., + translations: Optional[Dict[str, str]] = ..., + using: None = ..., + ) -> RawQuerySet: ... + def values(self, *fields: Any, **expressions: Any) -> QuerySet: ... + def values_list( + self, *fields: Any, flat: bool = ..., named: bool = ... + ) -> QuerySet: ... + def dates( + self, field_name: str, kind: str, order: str = ... + ) -> QuerySet: ... + def datetimes( + self, field_name: str, kind: str, order: str = ..., tzinfo: None = ... + ) -> QuerySet: ... + def none(self) -> QuerySet[_T]: ... + def all(self) -> QuerySet[_T]: ... + def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ... + def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ... + def complex_filter( + self, + filter_obj: Union[ + Dict[str, datetime], Dict[str, QuerySet], Q, MagicMock + ], + ) -> QuerySet[_T]: ... + def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ... + def intersection(self, *other_qs: Any) -> QuerySet[_T]: ... + + def difference(self, *other_qs: Any) -> QuerySet[_T]: ... + + def select_for_update( + self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ... + ) -> QuerySet: ... + + def select_related(self, *fields: Any) -> QuerySet[_T]: ... + + def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ... + + def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ... + + def order_by(self, *field_names: Any) -> QuerySet[_T]: ... + + def distinct(self, *field_names: Any) -> QuerySet[_T]: ... + + def extra( + self, + select: Optional[ + Union[Dict[str, int], Dict[str, str], OrderedDict] + ] = ..., + where: Optional[List[str]] = ..., + params: Optional[Union[List[int], List[str]]] = ..., + tables: Optional[List[str]] = ..., + order_by: Optional[Union[List[str], Tuple[str]]] = ..., + select_params: Optional[Union[List[int], List[str], Tuple[int]]] = ..., + ) -> QuerySet[_T]: ... + + def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ... + + def aggregate( + self, *args: Any, **kwargs: Any + ) -> Dict[str, Optional[Union[datetime, float]]]: ... + + def count(self) -> int: ... + + def get( + self, *args: Any, **kwargs: Any + ) -> _T: ... + + def create(self, **kwargs: Any) -> _T: ... + + +class ManagerDescriptor: + manager: Manager = ... + def __init__(self, manager: Manager) -> None: ... + def __get__( + self, instance: Optional[Model], cls: Type[Model] = ... + ) -> Manager: ... + +class EmptyManager(Manager): + creation_counter: int + name: None + model: Optional[Type[Model]] = ... + def __init__(self, model: Type[Model]) -> None: ... + def get_queryset(self) -> QuerySet: ... diff --git a/django-stubs/urls/__init__.pyi b/django-stubs/urls/__init__.pyi index 51c5a1a..f95afa0 100644 --- a/django-stubs/urls/__init__.pyi +++ b/django-stubs/urls/__init__.pyi @@ -6,4 +6,9 @@ from .conf import (include as include, from .resolvers import (ResolverMatch as ResolverMatch, get_ns_resolver as get_ns_resolver, - get_resolver as get_resolver) \ No newline at end of file + get_resolver as get_resolver) + +# noinspection PyUnresolvedReferences +from .converters import ( + register_converter as register_converter +) \ No newline at end of file diff --git a/django-stubs/utils/baseconv.pyi b/django-stubs/utils/baseconv.pyi new file mode 100644 index 0000000..811f0ad --- /dev/null +++ b/django-stubs/utils/baseconv.pyi @@ -0,0 +1,30 @@ +from typing import Any, Tuple, Union + +BASE2_ALPHABET: str +BASE16_ALPHABET: str +BASE56_ALPHABET: str +BASE36_ALPHABET: str +BASE62_ALPHABET: str +BASE64_ALPHABET: Any + +class BaseConverter: + decimal_digits: str = ... + sign: str = ... + digits: str = ... + def __init__(self, digits: str, sign: str = ...) -> None: ... + def encode(self, i: int) -> str: ... + def decode(self, s: str) -> int: ... + def convert( + self, + number: Union[int, str], + from_digits: str, + to_digits: str, + sign: str, + ) -> Tuple[int, str]: ... + +base2: Any +base16: Any +base36: Any +base56: Any +base62: Any +base64: Any diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index 13402fe..5a1c618 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -1,15 +1,15 @@ -from typing import Dict, Optional, NamedTuple +import typing +from typing import Dict, Optional, NamedTuple, Any -from mypy.semanal import SemanticAnalyzerPass2 -from mypy.types import Type from mypy.nodes import SymbolTableNode, Var, Expression from mypy.plugin import FunctionContext -from mypy.types import Instance, UnionType, NoneTyp +from mypy.types import Type, Instance, UnionType, NoneTyp MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet' FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey' ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField' +DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject' def create_new_symtable_node(name: str, kind: int, instance: Instance) -> SymbolTableNode: @@ -26,12 +26,10 @@ Argument = NamedTuple('Argument', fields=[ def get_call_signature_or_none(ctx: FunctionContext) -> Optional[Dict[str, Argument]]: - arg_names = ctx.context.arg_names - result: Dict[str, Argument] = {} positional_args_only = [] positional_arg_types_only = [] - for arg, arg_name, arg_type in zip(ctx.args, arg_names, ctx.arg_types): + for arg, arg_name, arg_type in zip(ctx.args, ctx.arg_names, ctx.arg_types): if arg_name is None: positional_args_only.append(arg) positional_arg_types_only.append(arg_type) @@ -64,4 +62,8 @@ def make_required(typ: Type) -> Type: if not isinstance(typ, UnionType): return typ items = [item for item in typ.items if not isinstance(item, NoneTyp)] - return UnionType.make_union(items) \ No newline at end of file + return UnionType.make_union(items) + + +def get_obj_type_name(typ: typing.Type) -> str: + return typ.__module__ + '.' + typ.__qualname__ diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 2227cf9..4f4368a 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -1,27 +1,42 @@ import os -from typing import Callable, Optional +from typing import Callable, Optional, List +from django.apps.registry import Apps from django.conf import Settings +from mypy import build +from mypy.build import BuildManager from mypy.options import Options from mypy.plugin import Plugin, FunctionContext, ClassDefContext from mypy.types import Type -from mypy_django_plugin import helpers +from mypy_django_plugin import helpers, monkeypatch from mypy_django_plugin.plugins.objects_queryset import set_objects_queryset_to_model_class from mypy_django_plugin.plugins.postgres_fields import determine_type_of_array_field -from mypy_django_plugin.plugins.related_fields import set_related_name_instance_for_onetoonefield, \ - set_related_name_manager_for_foreign_key, set_fieldname_attrs_for_related_fields +from mypy_django_plugin.plugins.related_fields import OneToOneFieldHook, \ + ForeignKeyHook, set_fieldname_attrs_for_related_fields from mypy_django_plugin.plugins.setup_settings import DjangoConfSettingsInitializerHook base_model_classes = {helpers.MODEL_CLASS_FULLNAME} -def transform_model_class(ctx: ClassDefContext) -> None: - base_model_classes.add(ctx.cls.fullname) +class TransformModelClassHook(object): + def __init__(self, settings: Settings, apps: Apps): + self.settings = settings + self.apps = apps - set_fieldname_attrs_for_related_fields(ctx) - set_objects_queryset_to_model_class(ctx) + def __call__(self, ctx: ClassDefContext) -> None: + base_model_classes.add(ctx.cls.fullname) + + set_fieldname_attrs_for_related_fields(ctx) + set_objects_queryset_to_model_class(ctx) + + +def always_return_none(manager: BuildManager): + return None + + +build.read_plugins_snapshot = always_return_none class DjangoPlugin(Plugin): @@ -29,18 +44,36 @@ class DjangoPlugin(Plugin): options: Options) -> None: super().__init__(options) self.django_settings = None + self.apps = None + + monkeypatch.replace_apply_function_plugin_method() django_settings_module = os.environ.get('DJANGO_SETTINGS_MODULE') if django_settings_module: self.django_settings = Settings(django_settings_module) + # import django + # django.setup() + # + # from django.apps import apps + # self.apps = apps + # + # models_modules = [] + # for app_config in self.apps.app_configs.values(): + # models_modules.append(app_config.module.__name__ + '.' + 'models') + # + # monkeypatch.state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(django_settings_module, + # models_modules) + monkeypatch.load_graph_to_add_settings_file_as_a_source_seed(django_settings_module) def get_function_hook(self, fullname: str ) -> Optional[Callable[[FunctionContext], Type]]: if fullname == helpers.FOREIGN_KEY_FULLNAME: - return set_related_name_manager_for_foreign_key + return ForeignKeyHook(settings=self.django_settings, + apps=self.apps) if fullname == helpers.ONETOONE_FIELD_FULLNAME: - return set_related_name_instance_for_onetoonefield + return OneToOneFieldHook(settings=self.django_settings, + apps=self.apps) if fullname == 'django.contrib.postgres.fields.array.ArrayField': return determine_type_of_array_field @@ -49,9 +82,11 @@ class DjangoPlugin(Plugin): def get_base_class_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: if fullname in base_model_classes: - return transform_model_class - if fullname == 'django.conf._DjangoConfLazyObject': + return TransformModelClassHook(self.django_settings, self.apps) + + if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS: return DjangoConfSettingsInitializerHook(settings=self.django_settings) + return None diff --git a/mypy_django_plugin/monkeypatch.py b/mypy_django_plugin/monkeypatch.py new file mode 100644 index 0000000..e02b243 --- /dev/null +++ b/mypy_django_plugin/monkeypatch.py @@ -0,0 +1,112 @@ +from typing import Optional, List, Sequence + +from mypy.build import BuildManager, Graph, State +from mypy.modulefinder import BuildSource +from mypy.nodes import Expression, Context +from mypy.plugin import FunctionContext, MethodContext +from mypy.types import Type, CallableType, Instance + + +def state_compute_dependencies_to_parse_installed_apps_setting_in_settings_module(settings_module: str, + models_py_modules: List[str]): + from mypy.build import State + + old_compute_dependencies = State.compute_dependencies + + def patched_compute_dependencies(self: State): + old_compute_dependencies(self) + if self.id == settings_module: + self.dependencies.extend(models_py_modules) + + State.compute_dependencies = patched_compute_dependencies + + +def load_graph_to_add_settings_file_as_a_source_seed(settings_module: str): + from mypy import build + + old_load_graph = build.load_graph + + def patched_load_graph(sources: List[BuildSource], manager: BuildManager, + old_graph: Optional[Graph] = None, + new_modules: Optional[List[State]] = None): + if all([source.module != settings_module for source in sources]): + sources.append(BuildSource(None, settings_module, None)) + + return old_load_graph(sources=sources, manager=manager, + old_graph=old_graph, + new_modules=new_modules) + + build.load_graph = patched_load_graph + + +def replace_apply_function_plugin_method(): + def apply_function_plugin(self, + arg_types: List[Type], + inferred_ret_type: Type, + arg_names: Optional[Sequence[Optional[str]]], + formal_to_actual: List[List[int]], + args: List[Expression], + num_formals: int, + fullname: str, + object_type: Optional[Type], + context: Context) -> Type: + """Use special case logic to infer the return type of a specific named function/method. + + Caller must ensure that a plugin hook exists. There are two different cases: + + - If object_type is None, the caller must ensure that a function hook exists + for fullname. + - If object_type is not None, the caller must ensure that a method hook exists + for fullname. + + Return the inferred return type. + """ + formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]] + formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]] + formal_arg_names = [None for _ in range(num_formals)] # type: List[Optional[str]] + for formal, actuals in enumerate(formal_to_actual): + for actual in actuals: + formal_arg_types[formal].append(arg_types[actual]) + formal_arg_exprs[formal].append(args[actual]) + if arg_names: + formal_arg_names[formal] = arg_names[actual] + + num_passed_positionals = sum([1 if name is None else 0 + for name in formal_arg_names]) + if arg_names and num_passed_positionals > 0: + object_type_info = None + if object_type is not None: + if isinstance(object_type, CallableType): + # class object, convert to corresponding Instance + object_type = object_type.ret_type + if isinstance(object_type, Instance): + # skip TypedDictType and others + object_type_info = object_type.type + + defn_arg_names = self._get_defn_arg_names(fullname, object_type=object_type_info) + if defn_arg_names: + if num_formals < len(defn_arg_names): + # self/cls argument has been passed implicitly + defn_arg_names = defn_arg_names[1:] + formal_arg_names[:num_passed_positionals] = defn_arg_names[:num_passed_positionals] + + if object_type is None: + # Apply function plugin + callback = self.plugin.get_function_hook(fullname) + assert callback is not None # Assume that caller ensures this + return callback( + FunctionContext(formal_arg_names, formal_arg_types, + inferred_ret_type, formal_arg_exprs, + context, self.chk)) + else: + # Apply method plugin + method_callback = self.plugin.get_method_hook(fullname) + assert method_callback is not None # Assume that caller ensures this + return method_callback( + MethodContext(object_type, formal_arg_names, formal_arg_types, + inferred_ret_type, formal_arg_exprs, + context, self.chk)) + + from mypy.checkexpr import ExpressionChecker + ExpressionChecker.apply_function_plugin = apply_function_plugin + diff --git a/mypy_django_plugin/plugins/objects_queryset.py b/mypy_django_plugin/plugins/objects_queryset.py index 05b0722..f8cf49d 100644 --- a/mypy_django_plugin/plugins/objects_queryset.py +++ b/mypy_django_plugin/plugins/objects_queryset.py @@ -9,19 +9,28 @@ from mypy_django_plugin import helpers def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None: - if 'objects' in ctx.cls.info.names: - return - api = cast(SemanticAnalyzerPass2, ctx.api) + # search over mro + objects_sym = ctx.cls.info.get('objects') + if objects_sym is not None: + return None - metaclass_node = ctx.cls.info.names.get('Meta') - if metaclass_node is not None: - for stmt in metaclass_node.node.defn.defs.body: + # only direct Meta class + metaclass_sym = ctx.cls.info.names.get('Meta') + # skip if abstract + if metaclass_sym is not None: + for stmt in metaclass_sym.node.defn.defs.body: if (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1 and stmt.lvalues[0].name == 'abstract'): - is_abstract = api.parse_bool(stmt.rvalue) + is_abstract = ctx.api.parse_bool(stmt.rvalue) if is_abstract: - return + return None - typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, args=[Instance(ctx.cls.info, [])]) - new_objects_node = helpers.create_new_symtable_node('objects', MDEF, instance=typ) - ctx.cls.info.names['objects'] = new_objects_node + api = cast(SemanticAnalyzerPass2, ctx.api) + typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, + args=[Instance(ctx.cls.info, [])]) + if not typ: + return None + + ctx.cls.info.names['objects'] = helpers.create_new_symtable_node('objects', + kind=MDEF, + instance=typ) diff --git a/mypy_django_plugin/plugins/postgres_fields.py b/mypy_django_plugin/plugins/postgres_fields.py index e4e3983..25e4171 100644 --- a/mypy_django_plugin/plugins/postgres_fields.py +++ b/mypy_django_plugin/plugins/postgres_fields.py @@ -1,14 +1,11 @@ from mypy.plugin import FunctionContext from mypy.types import Type -from mypy_django_plugin import helpers - def determine_type_of_array_field(ctx: FunctionContext) -> Type: - signature = helpers.get_call_signature_or_none(ctx) - if signature is None: + if 'base_field' not in ctx.arg_names: return ctx.default_return_type - _, base_field_arg_type = signature['base_field'] + base_field_arg_type = ctx.arg_types[ctx.arg_names.index('base_field')][0] return ctx.api.named_generic_type(ctx.context.callee.fullname, args=[base_field_arg_type.type.names['__get__'].type.ret_type]) diff --git a/mypy_django_plugin/plugins/related_fields.py b/mypy_django_plugin/plugins/related_fields.py index 985cfd9..7abb97c 100644 --- a/mypy_django_plugin/plugins/related_fields.py +++ b/mypy_django_plugin/plugins/related_fields.py @@ -1,23 +1,63 @@ -from typing import Optional, cast +import typing +from typing import Optional, cast, Tuple, Any +from django.apps.registry import Apps +from django.conf import Settings +from django.db import models from mypy.checker import TypeChecker -from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, MemberExpr +from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, StrExpr from mypy.plugin import FunctionContext, ClassDefContext from mypy.types import Type, CallableType, Instance, AnyType from mypy_django_plugin import helpers -def extract_to_value_type(ctx: FunctionContext) -> Optional[Instance]: - signature = helpers.get_call_signature_or_none(ctx) - if signature is None or 'to' not in signature: - return None +def get_instance_type_for_class(klass: typing.Type[models.Model], + api: TypeChecker) -> Optional[Instance]: + model_qualname = helpers.get_obj_type_name(klass) + module_name, _, class_name = model_qualname.rpartition('.') + module = api.modules.get(module_name) + if not module or class_name not in module.names: + return - arg, arg_type = signature['to'] - if not isinstance(arg_type, CallableType): - return None + sym = module.names[class_name] + return Instance(sym.node, []) - return arg_type.ret_type + +def extract_to_value_type(ctx: FunctionContext, + apps: Optional[Apps]) -> Tuple[Optional[Instance], bool]: + api = cast(TypeChecker, ctx.api) + + if 'to' not in ctx.arg_names: + return None, False + arg = ctx.args[ctx.arg_names.index('to')][0] + arg_type = ctx.arg_types[ctx.arg_names.index('to')][0] + + if isinstance(arg_type, CallableType): + return arg_type.ret_type, False + + if apps: + if isinstance(arg, StrExpr): + arg_value = arg.value + if '.' not in arg_value: + return None, False + + app_label, modelname = arg_value.lower().split('.') + try: + model_cls = apps.get_model(app_label, modelname) + except LookupError: + # no model class found + return None, False + try: + instance = get_instance_type_for_class(model_cls, api=api) + if not instance: + return None, False + return instance, True + + except AssertionError: + pass + + return None, False def extract_related_name_value(ctx: FunctionContext) -> str: @@ -30,45 +70,58 @@ def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instanc instance=new_member_instance) -def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type: - api = cast(TypeChecker, ctx.api) - outer_class_info = api.tscope.classes[-1] +class ForeignKeyHook(object): + def __init__(self, settings: Settings, apps: Apps): + self.settings = settings + self.apps = apps + + def __call__(self, ctx: FunctionContext) -> Type: + api = cast(TypeChecker, ctx.api) + outer_class_info = api.tscope.classes[-1] + + referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps) + if not referred_to: + return ctx.default_return_type + + if 'related_name' in ctx.context.arg_names: + related_name = extract_related_name_value(ctx) + queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME, + args=[Instance(outer_class_info, [])]) + if isinstance(referred_to, AnyType): + return ctx.default_return_type + + add_new_class_member(referred_to.type, + related_name, queryset_type) + if is_string_based: + return referred_to - if 'related_name' not in ctx.context.arg_names: return ctx.default_return_type - referred_to = extract_to_value_type(ctx) - if not referred_to: + +class OneToOneFieldHook(object): + def __init__(self, settings: Optional[Settings], apps: Optional[Apps]): + self.settings = settings + self.apps = apps + + def __call__(self, ctx: FunctionContext) -> Type: + if 'related_name' not in ctx.context.arg_names: + return ctx.default_return_type + + referred_to, is_string_based = extract_to_value_type(ctx, apps=self.apps) + if referred_to is None: + return ctx.default_return_type + + if 'related_name' in ctx.context.arg_names: + related_name = extract_related_name_value(ctx) + outer_class_info = ctx.api.tscope.classes[-1] + add_new_class_member(referred_to.type, related_name, + new_member_instance=Instance(outer_class_info, [])) + + if is_string_based: + return referred_to + return ctx.default_return_type - related_name = extract_related_name_value(ctx) - queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME, - args=[Instance(outer_class_info, [])]) - if isinstance(referred_to, AnyType): - # referred_to defined as string, which is unsupported for now - return ctx.default_return_type - - add_new_class_member(referred_to.type, - related_name, queryset_type) - return ctx.default_return_type - - -def set_related_name_instance_for_onetoonefield(ctx: FunctionContext) -> Type: - if 'related_name' not in ctx.context.arg_names: - return ctx.default_return_type - - referred_to = extract_to_value_type(ctx) - if referred_to is None: - return ctx.default_return_type - - related_name = extract_related_name_value(ctx) - outer_class_info = ctx.api.tscope.classes[-1] - - api = cast(TypeChecker, ctx.api) - add_new_class_member(referred_to.type, related_name, - new_member_instance=api.named_type(outer_class_info.fullname())) - return ctx.default_return_type - def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None: api = ctx.api diff --git a/mypy_django_plugin/plugins/setup_settings.py b/mypy_django_plugin/plugins/setup_settings.py index 208a26e..c866b61 100644 --- a/mypy_django_plugin/plugins/setup_settings.py +++ b/mypy_django_plugin/plugins/setup_settings.py @@ -1,7 +1,7 @@ -from typing import cast, Any +from typing import cast from django.conf import Settings -from mypy.nodes import MDEF, TypeInfo, SymbolTable +from mypy.nodes import MDEF from mypy.plugin import ClassDefContext from mypy.semanal import SemanticAnalyzerPass2 from mypy.types import Instance, AnyType, TypeOfAny @@ -9,26 +9,24 @@ from mypy.types import Instance, AnyType, TypeOfAny from mypy_django_plugin import helpers -def get_obj_type_name(value: Any) -> str: - return type(value).__module__ + '.' + type(value).__qualname__ - - class DjangoConfSettingsInitializerHook(object): def __init__(self, settings: Settings): self.settings = settings def __call__(self, ctx: ClassDefContext) -> None: api = cast(SemanticAnalyzerPass2, ctx.api) - for name, value in self.settings.__dict__.items(): - if name.isupper(): - if value is None: - ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF, - instance=api.builtin_type('builtins.object')) - continue + if self.settings: + for name, value in self.settings.__dict__.items(): + if name.isupper(): + if value is None: + # TODO: change to Optional[Any] later + ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF, + instance=api.builtin_type('builtins.object')) + continue - type_fullname = get_obj_type_name(value) - sym = api.lookup_fully_qualified_or_none(type_fullname) - if sym is not None: - args = len(sym.node.type_vars) * [AnyType(TypeOfAny.from_omitted_generics)] - ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF, - instance=Instance(sym.node, args)) + type_fullname = helpers.get_obj_type_name(type(value)) + sym = api.lookup_fully_qualified_or_none(type_fullname) + if sym is not None: + args = len(sym.node.type_vars) * [AnyType(TypeOfAny.from_omitted_generics)] + ctx.cls.info.names[name] = helpers.create_new_symtable_node(name, MDEF, + instance=Instance(sym.node, args)) diff --git a/pytest.ini b/pytest.ini index 74df75d..2772864 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,4 +4,5 @@ testpaths = test python_files = test*.py addopts = --tb=native - -v \ No newline at end of file + -v + -s \ No newline at end of file diff --git a/test/data.py b/test/data.py index b4ba267..2e5abf0 100644 --- a/test/data.py +++ b/test/data.py @@ -147,6 +147,12 @@ class DjangoDataDrivenTestCase(DataDrivenTestCase): self.old_cwd = os.getcwd() self.tmpdir = tempfile.TemporaryDirectory(prefix='mypy-test-') + tmpdir_root = os.path.join(self.tmpdir.name, 'tmp') + + new_files = [] + for path, contents in self.files: + new_files.append((path, contents.replace('', tmpdir_root))) + self.files = new_files os.chdir(self.tmpdir.name) os.mkdir(test_temp_dir) @@ -179,7 +185,7 @@ class DjangoDataDrivenTestCase(DataDrivenTestCase): self.clean_up.append((True, d)) self.clean_up.append((False, path)) - sys.path.insert(0, os.path.join(self.tmpdir.name, 'tmp')) + sys.path.insert(0, tmpdir_root) def teardown(self): if hasattr(self, 'old_environ'): diff --git a/test/helpers.py b/test/helpers.py new file mode 100644 index 0000000..b4a2916 --- /dev/null +++ b/test/helpers.py @@ -0,0 +1,253 @@ +import inspect +import os +import re +from typing import List, Callable, Optional, Tuple + +import pytest # type: ignore # no pytest in typeshed + +skip = pytest.mark.skip + +# AssertStringArraysEqual displays special line alignment helper messages if +# the first different line has at least this many characters, +MIN_LINE_LENGTH_FOR_ALIGNMENT = 5 + + +class TypecheckAssertionError(AssertionError): + def __init__(self, error_message: str, lineno: int): + self.error_message = error_message + self.lineno = lineno + + def first_line(self): + return self.__class__.__name__ + '(message="Invalid output")' + + def __str__(self): + return self.error_message + + +def _clean_up(a: List[str]) -> List[str]: + """Remove common directory prefix from all strings in a. + + This uses a naive string replace; it seems to work well enough. Also + remove trailing carriage returns. + """ + res = [] + for s in a: + prefix = os.sep + ss = s + for p in prefix, prefix.replace(os.sep, '/'): + if p != '/' and p != '//' and p != '\\' and p != '\\\\': + ss = ss.replace(p, '') + # Ignore spaces at end of line. + ss = re.sub(' +$', '', ss) + res.append(re.sub('\\r$', '', ss)) + return res + + +def _num_skipped_prefix_lines(a1: List[str], a2: List[str]) -> int: + num_eq = 0 + while num_eq < min(len(a1), len(a2)) and a1[num_eq] == a2[num_eq]: + num_eq += 1 + return max(0, num_eq - 4) + + +def _num_skipped_suffix_lines(a1: List[str], a2: List[str]) -> int: + num_eq = 0 + while (num_eq < min(len(a1), len(a2)) + and a1[-num_eq - 1] == a2[-num_eq - 1]): + num_eq += 1 + return max(0, num_eq - 4) + + +def _add_aligned_message(s1: str, s2: str, error_message: str) -> str: + """Align s1 and s2 so that the their first difference is highlighted. + + For example, if s1 is 'foobar' and s2 is 'fobar', display the + following lines: + + E: foobar + A: fobar + ^ + + If s1 and s2 are long, only display a fragment of the strings around the + first difference. If s1 is very short, do nothing. + """ + + # Seeing what went wrong is trivial even without alignment if the expected + # string is very short. In this case do nothing to simplify output. + if len(s1) < 4: + return error_message + + maxw = 72 # Maximum number of characters shown + + error_message += 'Alignment of first line difference:\n' + # sys.stderr.write('Alignment of first line difference:\n') + + trunc = False + while s1[:30] == s2[:30]: + s1 = s1[10:] + s2 = s2[10:] + trunc = True + + if trunc: + s1 = '...' + s1 + s2 = '...' + s2 + + max_len = max(len(s1), len(s2)) + extra = '' + if max_len > maxw: + extra = '...' + + # Write a chunk of both lines, aligned. + error_message += ' E: {}{}\n'.format(s1[:maxw], extra) + # sys.stderr.write(' E: {}{}\n'.format(s1[:maxw], extra)) + error_message += ' A: {}{}\n'.format(s2[:maxw], extra) + # sys.stderr.write(' A: {}{}\n'.format(s2[:maxw], extra)) + # Write an indicator character under the different columns. + error_message += ' ' + # sys.stderr.write(' ') + for j in range(min(maxw, max(len(s1), len(s2)))): + if s1[j:j + 1] != s2[j:j + 1]: + error_message += '^' + # sys.stderr.write('^') # Difference + break + else: + error_message += ' ' + # sys.stderr.write(' ') # Equal + error_message += '\n' + return error_message + # sys.stderr.write('\n') + + +def assert_string_arrays_equal(expected: List[str], actual: List[str]) -> None: + """Assert that two string arrays are equal. + + Display any differences in a human-readable form. + """ + + actual = _clean_up(actual) + error_message = '' + + if actual != expected: + num_skip_start = _num_skipped_prefix_lines(expected, actual) + num_skip_end = _num_skipped_suffix_lines(expected, actual) + + error_message += 'Expected:\n' + # sys.stderr.write('Expected:\n') + + # If omit some lines at the beginning, indicate it by displaying a line + # with '...'. + if num_skip_start > 0: + error_message += ' ...\n' + # sys.stderr.write(' ...\n') + + # Keep track of the first different line. + first_diff = -1 + + # Display only this many first characters of identical lines. + width = 75 + + for i in range(num_skip_start, len(expected) - num_skip_end): + if i >= len(actual) or expected[i] != actual[i]: + if first_diff < 0: + first_diff = i + error_message += ' {:<45} (diff)'.format(expected[i]) + # sys.stderr.write(' {:<45} (diff)'.format(expected[i])) + else: + e = expected[i] + error_message += ' ' + e[:width] + # sys.stderr.write(' ' + e[:width]) + if len(e) > width: + error_message += '...' + # sys.stderr.write('...') + error_message += '\n' + # sys.stderr.write('\n') + if num_skip_end > 0: + error_message += ' ...\n' + # sys.stderr.write(' ...\n') + + error_message += 'Actual:\n' + # sys.stderr.write('Actual:\n') + + if num_skip_start > 0: + error_message += ' ...\n' + # sys.stderr.write(' ...\n') + + for j in range(num_skip_start, len(actual) - num_skip_end): + if j >= len(expected) or expected[j] != actual[j]: + error_message += ' {:<45} (diff)'.format(actual[j]) + # sys.stderr.write(' {:<45} (diff)'.format(actual[j])) + else: + a = actual[j] + error_message += ' ' + a[:width] + # sys.stderr.write(' ' + a[:width]) + if len(a) > width: + error_message += '...' + # sys.stderr.write('...') + error_message += '\n' + # sys.stderr.write('\n') + if actual == []: + error_message += ' (empty)\n' + # sys.stderr.write(' (empty)\n') + if num_skip_end > 0: + error_message += ' ...\n' + # sys.stderr.write(' ...\n') + + error_message += '\n' + # sys.stderr.write('\n') + + if first_diff >= 0 and first_diff < len(actual) and ( + len(expected[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT + or len(actual[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT): + # Display message that helps visualize the differences between two + # long lines. + error_message = _add_aligned_message(expected[first_diff], actual[first_diff], + error_message) + + first_failure = expected[first_diff] + if first_failure: + lineno = int(first_failure.split(' ')[0].strip(':').split(':')[1]) + raise TypecheckAssertionError(error_message=f'Invalid output: \n{error_message}', + lineno=lineno) + + +def build_output_line(fname: str, lnum: int, severity: str, message: str, col=None) -> str: + if col is None: + return f'{fname}:{lnum + 1}: {severity}: {message}' + else: + return f'{fname}:{lnum + 1}:{col}: {severity}: {message}' + + +def expand_errors(input_lines: List[str], fname: str) -> List[str]: + """Transform comments such as '# E: message' or + '# E:3: message' in input. + + The result is lines like 'fnam:line: error: message'. + """ + output_lines = [] + for lnum, line in enumerate(input_lines): + # The first in the split things isn't a comment + for possible_err_comment in line.split(' # ')[1:]: + m = re.search( + r'^([ENW]):((?P\d+):)? (?P.*)$', + possible_err_comment.strip()) + if m: + if m.group(1) == 'E': + severity = 'error' + elif m.group(1) == 'N': + severity = 'note' + elif m.group(1) == 'W': + severity = 'warning' + col = m.group('col') + output_lines.append(build_output_line(fname, lnum, severity, + message=m.group("message"), + col=col)) + return output_lines + + +def get_func_first_lnum(attr: Callable[..., None]) -> Optional[Tuple[int, List[str]]]: + lines, _ = inspect.getsourcelines(attr) + for lnum, line in enumerate(lines): + no_space_line = line.strip() + if f'def {attr.__name__}' in no_space_line: + return lnum, lines[lnum + 1:] + raise ValueError(f'No line "def {attr.__name__}" found') diff --git a/test/pytest_plugin.py b/test/pytest_plugin.py new file mode 100644 index 0000000..26bd4d3 --- /dev/null +++ b/test/pytest_plugin.py @@ -0,0 +1,299 @@ +import dataclasses +import inspect +import os +import sys +import tempfile +import textwrap +from contextlib import contextmanager +from pathlib import Path +from typing import Iterator, Any, Optional, cast, List, Type, Callable, Dict + +import pytest +from _pytest._code.code import ReprFileLocation, ReprEntry, ExceptionInfo +from decorator import decorate +from mypy import api as mypy_api + +from test import vistir +from test.helpers import assert_string_arrays_equal, TypecheckAssertionError, expand_errors, get_func_first_lnum + + +def reveal_type(obj: Any) -> None: + # noop method, just to get rid of "method is not resolved" errors + pass + + +def output(output_lines: str): + def decor(func: Callable[..., None]): + func.out = output_lines + + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return decorate(func, wrapper) + + return decor + + +def get_class_that_defined_method(meth) -> Type['MypyTypecheckTestCase']: + if inspect.ismethod(meth): + for cls in inspect.getmro(meth.__self__.__class__): + if cls.__dict__.get(meth.__name__) is meth: + return cls + meth = meth.__func__ # fallback to __qualname__ parsing + if inspect.isfunction(meth): + cls = getattr(inspect.getmodule(meth), + meth.__qualname__.split('.', 1)[0].rsplit('.', 1)[0]) + if issubclass(cls, MypyTypecheckTestCase): + return cls + return getattr(meth, '__objclass__', None) # handle special descriptor objects + + +def file(filename: str, make_parent_packages=False): + def decor(func: Callable[..., None]): + func.filename = filename + func.make_parent_packages = make_parent_packages + return func + + return decor + + +def env(**environ): + def decor(func: Callable[..., None]): + func.env = environ + return func + + return decor + + +@dataclasses.dataclass +class CreateFile: + sources: str + make_parent_packages: bool = False + + +class MypyTypecheckMeta(type): + def __new__(mcs, name, bases, attrs): + cls = super().__new__(mcs, name, bases, attrs) + cls.files: Dict[str, CreateFile] = {} + + for name, attr in attrs.items(): + if inspect.isfunction(attr): + filename = getattr(attr, 'filename', None) + if not filename: + continue + make_parent_packages = getattr(attr, 'make_parent_packages', False) + sources = textwrap.dedent(''.join(get_func_first_lnum(attr)[1])) + if sources.strip() == 'pass': + sources = '' + cls.files[filename] = CreateFile(sources, make_parent_packages) + + return cls + + +class MypyTypecheckTestCase(metaclass=MypyTypecheckMeta): + files = None + + def ini_file(self) -> str: + return """ +[mypy] + """ + + def _get_ini_file_contents(self) -> Optional[str]: + raw_ini_file = self.ini_file() + if not raw_ini_file: + return raw_ini_file + return raw_ini_file.strip() + '\n' + + +class TraceLastReprEntry(ReprEntry): + def toterminal(self, tw): + self.reprfileloc.toterminal(tw) + for line in self.lines: + red = line.startswith("E ") + tw.line(line, bold=True, red=red) + return + + +def fname_to_module(fpath: Path, root_path: Path) -> Optional[str]: + try: + relpath = fpath.relative_to(root_path).with_suffix('') + return str(relpath).replace(os.sep, '.') + except ValueError: + return None + + +class MypyTypecheckItem(pytest.Item): + root_directory = '/run/testdata' + + def __init__(self, + name: str, + parent: 'MypyTestsCollector', + klass: Type[MypyTypecheckTestCase], + source_code: str, + first_lineno: int, + ini_file_contents: Optional[str] = None, + expected_output_lines: Optional[List[str]] = None, + files: Optional[Dict[str, CreateFile]] = None, + custom_environment: Optional[Dict[str, Any]] = None): + super().__init__(name=name, parent=parent) + self.klass = klass + self.source_code = source_code + self.first_lineno = first_lineno + self.ini_file_contents = ini_file_contents + self.expected_output_lines = expected_output_lines + self.files = files + self.custom_environment = custom_environment + + @contextmanager + def temp_directory(self) -> Path: + with tempfile.TemporaryDirectory(prefix='mypy-pytest-', + dir=self.root_directory) as tmpdir_name: + yield Path(self.root_directory) / tmpdir_name + + def runtest(self): + with self.temp_directory() as tmpdir_path: + if not self.source_code: + return + + if self.ini_file_contents: + mypy_ini_fpath = tmpdir_path / 'mypy.ini' + mypy_ini_fpath.write_text(self.ini_file_contents) + + test_specific_modules = [] + for fname, create_file in self.files.items(): + fpath = tmpdir_path / fname + if create_file.make_parent_packages: + fpath.parent.mkdir(parents=True, exist_ok=True) + for parent in fpath.parents: + try: + parent.relative_to(tmpdir_path) + if parent != tmpdir_path: + parent_init_file = parent / '__init__.py' + parent_init_file.write_text('') + test_specific_modules.append(fname_to_module(parent, + root_path=tmpdir_path)) + except ValueError: + break + + fpath.write_text(create_file.sources) + test_specific_modules.append(fname_to_module(fpath, + root_path=tmpdir_path)) + + with vistir.temp_environ(), vistir.temp_path(): + for key, val in (self.custom_environment or {}).items(): + os.environ[key] = val + sys.path.insert(0, str(tmpdir_path)) + + mypy_cmd_options = self.prepare_mypy_cmd_options(config_file_path=mypy_ini_fpath) + main_fpath = tmpdir_path / 'main.py' + main_fpath.write_text(self.source_code) + mypy_cmd_options.append(str(main_fpath)) + + stdout, _, _ = mypy_api.run(mypy_cmd_options) + output_lines = [] + for line in stdout.splitlines(): + if ':' not in line: + continue + out_fpath, res_line = line.split(':', 1) + line = os.path.relpath(out_fpath, start=tmpdir_path) + ':' + res_line + output_lines.append(line.strip().replace('.py', '')) + + for module in test_specific_modules: + if module in sys.modules: + del sys.modules[module] + raise ValueError + assert_string_arrays_equal(expected=self.expected_output_lines, + actual=output_lines) + + def prepare_mypy_cmd_options(self, config_file_path: Path) -> List[str]: + mypy_cmd_options = [ + '--show-traceback', + '--no-silence-site-packages' + ] + python_version = '.'.join([str(part) for part in sys.version_info[:2]]) + mypy_cmd_options.append(f'--python-version={python_version}') + if self.ini_file_contents: + mypy_cmd_options.append(f'--config-file={config_file_path}') + return mypy_cmd_options + + def repr_failure(self, excinfo: ExceptionInfo) -> str: + if excinfo.errisinstance(SystemExit): + # We assume that before doing exit() (which raises SystemExit) we've printed + # enough context about what happened so that a stack trace is not useful. + # In particular, uncaught exceptions during semantic analysis or type checking + # call exit() and they already print out a stack trace. + return excinfo.exconly(tryshort=True) + elif excinfo.errisinstance(TypecheckAssertionError): + # with traceback removed + exception_repr = excinfo.getrepr(style='short') + exception_repr.reprcrash.message = '' + repr_file_location = ReprFileLocation(path=inspect.getfile(self.klass), + lineno=self.first_lineno + excinfo.value.lineno, + message='') + repr_tb_entry = TraceLastReprEntry(filelocrepr=repr_file_location, + lines=exception_repr.reprtraceback.reprentries[-1].lines[1:], + style='short', + reprlocals=None, + reprfuncargs=None) + exception_repr.reprtraceback.reprentries = [repr_tb_entry] + return exception_repr + else: + return super().repr_failure(excinfo, style='short') + + def reportinfo(self): + return self.fspath, None, get_class_qualname(self.klass) + '::' + self.name + + +def get_class_qualname(klass: type) -> str: + return klass.__module__ + '.' + klass.__name__ + + +def extract_test_output(attr: Callable[..., None]) -> List[str]: + out_data: str = getattr(attr, 'out', None) + out_lines = [] + if out_data: + for line in out_data.split('\n'): + line = line.strip() + out_lines.append(line) + return out_lines + + +class MypyTestsCollector(pytest.Class): + def get_ini_file_contents(self, contents: str) -> str: + return contents.strip() + '\n' + + def collect(self) -> Iterator[pytest.Item]: + current_testcase = cast(MypyTypecheckTestCase, self.obj()) + ini_file_contents = self.get_ini_file_contents(current_testcase.ini_file()) + for attr_name in dir(current_testcase): + if attr_name.startswith('_test_'): + attr = getattr(self.obj, attr_name) + if inspect.isfunction(attr): + first_line_lnum, source_lines = get_func_first_lnum(attr) + func_first_line_in_file = inspect.getsourcelines(attr)[1] + first_line_lnum + + output_from_decorator = extract_test_output(attr) + output_from_comments = expand_errors(source_lines, 'main') + custom_env = getattr(attr, 'env', None) + main_source_code = textwrap.dedent(''.join(source_lines)) + yield MypyTypecheckItem(name=attr_name, + parent=self, + klass=current_testcase.__class__, + source_code=main_source_code, + first_lineno=func_first_line_in_file, + ini_file_contents=ini_file_contents, + expected_output_lines=output_from_comments + + output_from_decorator, + files=current_testcase.__class__.files, + custom_environment=custom_env) + + +def pytest_pycollect_makeitem(collector: Any, name: str, obj: Any) -> Optional[MypyTestsCollector]: + # Only classes derived from DataSuite contain test cases, not the DataSuite class itself + if (isinstance(obj, type) + and issubclass(obj, MypyTypecheckTestCase) + and obj is not MypyTypecheckTestCase): + # Non-None result means this obj is a test case. + # The collect method of the returned DataSuiteCollector instance will be called later, + # with self.obj being obj. + return MypyTestsCollector(name, parent=collector) diff --git a/test/pytest_tests/__init__.py b/test/pytest_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/pytest_tests/base.py b/test/pytest_tests/base.py new file mode 100644 index 0000000..f7e3020 --- /dev/null +++ b/test/pytest_tests/base.py @@ -0,0 +1,9 @@ +from test.pytest_plugin import MypyTypecheckTestCase + + +class BaseDjangoPluginTestCase(MypyTypecheckTestCase): + def ini_file(self): + return """ +[mypy] +plugins = mypy_django_plugin.main + """ diff --git a/test/pytest_tests/test_model_fields.py b/test/pytest_tests/test_model_fields.py new file mode 100644 index 0000000..6d8f622 --- /dev/null +++ b/test/pytest_tests/test_model_fields.py @@ -0,0 +1,21 @@ +from test.pytest_plugin import reveal_type +from test.pytest_tests.base import BaseDjangoPluginTestCase + + +class TestBasicModelFields(BaseDjangoPluginTestCase): + def test_model_field_classes_present_as_primitives(self): + from django.db import models + + class User(models.Model): + id = models.AutoField(primary_key=True) + small_int = models.SmallIntegerField() + name = models.CharField(max_length=255) + slug = models.SlugField(max_length=255) + text = models.TextField() + + user = User() + reveal_type(user.id) # E: Revealed type is 'builtins.int' + reveal_type(user.small_int) # E: Revealed type is 'builtins.int' + reveal_type(user.name) # E: Revealed type is 'builtins.str' + reveal_type(user.slug) # E: Revealed type is 'builtins.str' + reveal_type(user.text) # E: Revealed type is 'builtins.str' diff --git a/test/pytest_tests/test_model_relations.py b/test/pytest_tests/test_model_relations.py index f087532..c718fce 100644 --- a/test/pytest_tests/test_model_relations.py +++ b/test/pytest_tests/test_model_relations.py @@ -1,16 +1,9 @@ -from test.pytest_plugin import MypyTypecheckTestCase, reveal_type +from test.pytest_plugin import reveal_type +from test.pytest_tests.base import BaseDjangoPluginTestCase -class BaseDjangoPluginTestCase(MypyTypecheckTestCase): - def ini_file(self): - return """ -[mypy] -plugins = mypy_django_plugin.main - """ - - -class MyTestCase(BaseDjangoPluginTestCase): - def check_foreign_key_field(self): +class TestForeignKey(BaseDjangoPluginTestCase): + def test_foreign_key_field(self): from django.db import models class Publisher(models.Model): @@ -26,7 +19,7 @@ class MyTestCase(BaseDjangoPluginTestCase): publisher = Publisher() reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]' - def check_every_foreign_key_creates_field_name_with_appended_id(self): + def test_every_foreign_key_creates_field_name_with_appended_id(self): from django.db import models class Publisher(models.Model): @@ -39,7 +32,7 @@ class MyTestCase(BaseDjangoPluginTestCase): book = Book() reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int' - def check_foreign_key_different_order_of_params(self): + def test_foreign_key_different_order_of_params(self): from django.db import models class Publisher(models.Model): @@ -47,10 +40,50 @@ class MyTestCase(BaseDjangoPluginTestCase): class Book(models.Model): publisher = models.ForeignKey(on_delete=models.CASCADE, to=Publisher, - related_name='books') + related_name='books') book = Book() reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*' publisher = Publisher() reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]' + + +class TestOneToOneField(BaseDjangoPluginTestCase): + def test_onetoone_field(self): + from django.db import models + + class User(models.Model): + pass + + class Profile(models.Model): + user = models.OneToOneField(to=User, on_delete=models.CASCADE, related_name='profile') + + profile = Profile() + reveal_type(profile.user) # E: Revealed type is 'main.User*' + + user = User() + reveal_type(user.profile) # E: Revealed type is 'main.Profile' + + def test_onetoone_field_with_underscore_id(self): + from django.db import models + + class User(models.Model): + pass + + class Profile(models.Model): + user = models.OneToOneField(to=User, on_delete=models.CASCADE, related_name='profile') + + profile = Profile() + reveal_type(profile.user_id) # E: Revealed type is 'builtins.int' + + def test_parameter_to_keyword_may_be_absent(self): + from django.db import models + + class User(models.Model): + pass + + class Profile(models.Model): + user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile') + + reveal_type(User().profile) # E: Revealed type is 'main.Profile' diff --git a/test/pytest_tests/test_objects_queryset.py b/test/pytest_tests/test_objects_queryset.py new file mode 100644 index 0000000..1f92348 --- /dev/null +++ b/test/pytest_tests/test_objects_queryset.py @@ -0,0 +1,28 @@ +from test.pytest_plugin import reveal_type, output +from test.pytest_tests.base import BaseDjangoPluginTestCase + + +class TestObjectsQueryset(BaseDjangoPluginTestCase): + def test_every_model_has_objects_queryset_available(self): + from django.db import models + + class User(models.Model): + pass + + reveal_type(User.objects) # E: Revealed type is 'django.db.models.query.QuerySet[main.User]' + + @output(""" +main:10: error: Revealed type is 'Any' +main:10: error: "Type[ModelMixin]" has no attribute "objects" + """) + def test_objects_get_returns_model_instance(self): + from django.db import models + + class ModelMixin(models.Model): + class Meta: + abstract = True + + class User(ModelMixin): + pass + + reveal_type(User.objects.get()) # E: Revealed type is 'main.User*' diff --git a/test/pytest_tests/test_parse_settings.py b/test/pytest_tests/test_parse_settings.py new file mode 100644 index 0000000..3436681 --- /dev/null +++ b/test/pytest_tests/test_parse_settings.py @@ -0,0 +1,37 @@ +from test.pytest_plugin import reveal_type, file, env +from test.pytest_tests.base import BaseDjangoPluginTestCase + + +class TestParseSettingsFromFile(BaseDjangoPluginTestCase): + @env(DJANGO_SETTINGS_MODULE='mysettings') + def test_case(self): + from django.conf import settings + + reveal_type(settings.ROOT_DIR) # E: Revealed type is 'builtins.str' + reveal_type(settings.OBJ) # E: Revealed type is 'django.utils.functional.LazyObject' + reveal_type(settings.NUMBERS) # E: Revealed type is 'builtins.list[Any]' + reveal_type(settings.DICT) # E: Revealed type is 'builtins.dict[Any, Any]' + + @file('mysettings.py') + def mysettings_py_file(self): + SECRET_KEY = 112233 + ROOT_DIR = '/etc' + NUMBERS = ['one', 'two'] + DICT = {} # type: ignore + + from django.utils.functional import LazyObject + + OBJ = LazyObject() + + +class TestSettingInitializableToNone(BaseDjangoPluginTestCase): + @env(DJANGO_SETTINGS_MODULE='mysettings') + def test_case(self): + from django.conf import settings + + reveal_type(settings.NONE_SETTING) # E: Revealed type is 'builtins.object' + + @file('mysettings.py') + def mysettings_py_file(self): + SECRET_KEY = 112233 + NONE_SETTING = None diff --git a/test/pytest_tests/test_postgres_fields.py b/test/pytest_tests/test_postgres_fields.py new file mode 100644 index 0000000..bc3c1dc --- /dev/null +++ b/test/pytest_tests/test_postgres_fields.py @@ -0,0 +1,27 @@ +from test.pytest_plugin import reveal_type +from test.pytest_tests.base import BaseDjangoPluginTestCase + + +class TestArrayField(BaseDjangoPluginTestCase): + def test_descriptor_access(self): + from django.db import models + from django.contrib.postgres.fields import ArrayField + + class User(models.Model): + array = ArrayField(base_field=models.Field()) + + user = User() + reveal_type(user.array) # E: Revealed type is 'builtins.list[Any]' + + def test_base_field_parsed_into_generic_attribute(self): + from django.db import models + from django.contrib.postgres.fields import ArrayField + + class User(models.Model): + members = ArrayField(base_field=models.IntegerField()) + members_as_text = ArrayField(base_field=models.CharField(max_length=255)) + + user = User() + reveal_type(user.members) # E: Revealed type is 'builtins.list[builtins.int*]' + reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list[builtins.str*]' + diff --git a/test/pytest_tests/test_to_attr_as_string.py b/test/pytest_tests/test_to_attr_as_string.py new file mode 100644 index 0000000..20cac2d --- /dev/null +++ b/test/pytest_tests/test_to_attr_as_string.py @@ -0,0 +1,74 @@ +from test.pytest_plugin import file, reveal_type, env +from test.pytest_tests.base import BaseDjangoPluginTestCase + + +class TestForeignKey(BaseDjangoPluginTestCase): + @env(DJANGO_SETTINGS_MODULE='mysettings') + def _test_to_parameter_could_be_specified_as_string(self): + from apps.myapp.models import Publisher + + publisher = Publisher() + reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[apps.myapp2.models.Book]' + + # @env(DJANGO_SETTINGS_MODULE='mysettings') + # def _test_creates_underscore_id_attr(self): + # from apps.myapp2.models import Book + # + # book = Book() + # reveal_type(book.publisher) # E: Revealed type is 'apps.myapp.models.Publisher' + # reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int' + + @file('mysettings.py') + def mysettings(self): + SECRET_KEY = '112233' + ROOT_DIR = '' + APPS_DIR = '/apps' + + INSTALLED_APPS = ('apps.myapp', 'apps.myapp2') + + @file('apps/myapp/models.py', make_parent_packages=True) + def apps_myapp_models(self): + from django.db import models + + class Publisher(models.Model): + pass + + @file('apps/myapp2/models.py', make_parent_packages=True) + def apps_myapp2_models(self): + from django.db import models + + class Book(models.Model): + publisher = models.ForeignKey(to='myapp.Publisher', on_delete=models.CASCADE, + related_name='books') + + +class TestOneToOneField(BaseDjangoPluginTestCase): + @env(DJANGO_SETTINGS_MODULE='mysettings') + def test_to_parameter_could_be_specified_as_string(self): + from apps.myapp.models import User + + user = User() + reveal_type(user.profile) # E: Revealed type is 'apps.myapp2.models.Profile' + + @file('mysettings.py') + def mysettings(self): + SECRET_KEY = '112233' + ROOT_DIR = '' + APPS_DIR = '/apps' + + INSTALLED_APPS = ('apps.myapp', 'apps.myapp2') + + @file('apps/myapp/models.py', make_parent_packages=True) + def apps_myapp_models(self): + from django.db import models + + class User(models.Model): + pass + + @file('apps/myapp2/models.py', make_parent_packages=True) + def apps_myapp2_models(self): + from django.db import models + + class Profile(models.Model): + user = models.OneToOneField(to='myapp.User', on_delete=models.CASCADE, + related_name='profile') diff --git a/test/testdjango.py b/test/testdjango.py index cd67f82..32e4419 100644 --- a/test/testdjango.py +++ b/test/testdjango.py @@ -14,11 +14,14 @@ MYPY_INI_PATH = ROOT_DIR / 'test' / 'plugins.ini' class DjangoTestSuite(DataSuite): files = [ - 'check-objects-queryset.test', - 'check-model-fields.test', - 'check-postgres-fields.test', - 'check-model-relations.test', - 'check-parse-settings.test' + # 'check-objects-queryset.test', + # 'check-model-fields.test', + # 'check-postgres-fields.test', + # 'check-model-relations.test', + # 'check-parse-settings.test', + # 'check-to-attr-as-string-one-to-one-field.test', + 'check-to-attr-as-string-foreign-key.test', + # 'check-foreign-key-as-string-creates-underscore-id-attr.test' ] data_prefix = str(TEST_DATA_DIR) diff --git a/test/vistir.py b/test/vistir.py new file mode 100644 index 0000000..4023716 --- /dev/null +++ b/test/vistir.py @@ -0,0 +1,43 @@ +# Borrowed from Pew. +# See https://github.com/berdario/pew/blob/master/pew/_utils.py#L82 +import os +import sys +from pathlib import Path + +from decorator import contextmanager + + +@contextmanager +def temp_environ(): + """Allow the ability to set os.environ temporarily""" + environ = dict(os.environ) + try: + yield + finally: + os.environ.clear() + os.environ.update(environ) + + +@contextmanager +def temp_path(): + """A context manager which allows the ability to set sys.path temporarily""" + path = [p for p in sys.path] + try: + yield + finally: + sys.path = [p for p in path] + + +@contextmanager +def cd(path): + """Context manager to temporarily change working directories""" + if not path: + return + prev_cwd = Path.cwd().as_posix() + if isinstance(path, Path): + path = path.as_posix() + os.chdir(str(path)) + try: + yield + finally: + os.chdir(prev_cwd)