diff --git a/django-stubs/contrib/auth/decorators.pyi b/django-stubs/contrib/auth/decorators.pyi index 954a6f9..3ecebe7 100644 --- a/django-stubs/contrib/auth/decorators.pyi +++ b/django-stubs/contrib/auth/decorators.pyi @@ -1,13 +1,21 @@ -from typing import Callable, List, Optional, Set, Union +from typing import Callable, List, Optional, Set, Union, TypeVar, overload from django.contrib.auth import REDIRECT_FIELD_NAME as REDIRECT_FIELD_NAME # noqa: F401 +from django.http.response import HttpResponseBase + +from django.contrib.auth.models import AbstractUser + +_VIEW = TypeVar("_VIEW", bound=Callable[..., HttpResponseBase]) def user_passes_test( - test_func: Callable, login_url: Optional[str] = ..., redirect_field_name: str = ... -) -> Callable: ... -def login_required( - function: Optional[Callable] = ..., redirect_field_name: str = ..., login_url: Optional[str] = ... -) -> Callable: ... + test_func: Callable[[AbstractUser], bool], login_url: Optional[str] = ..., redirect_field_name: str = ... +) -> Callable[[_VIEW], _VIEW]: ... + +# There are two ways of calling @login_required: @with(arguments) and @bare +@overload +def login_required(redirect_field_name: str = ..., login_url: Optional[str] = ...) -> Callable[[_VIEW], _VIEW]: ... +@overload +def login_required(function: _VIEW, redirect_field_name: str = ..., login_url: Optional[str] = ...) -> _VIEW: ... def permission_required( perm: Union[List[str], Set[str], str], login_url: None = ..., raise_exception: bool = ... -) -> Callable: ... +) -> Callable[[_VIEW], _VIEW]: ... diff --git a/django-stubs/db/transaction.pyi b/django-stubs/db/transaction.pyi index 7fd05a0..959479b 100644 --- a/django-stubs/db/transaction.pyi +++ b/django-stubs/db/transaction.pyi @@ -39,4 +39,11 @@ def atomic(using: _C) -> _C: ... # Decorator or context-manager with parameters @overload def atomic(using: Optional[str] = ..., savepoint: bool = ...) -> Atomic: ... -def non_atomic_requests(using: Callable = ...) -> Callable: ... + +# Bare decorator +@overload +def non_atomic_requests(using: _C) -> _C: ... + +# Decorator with arguments +@overload +def non_atomic_requests(using: Optional[str] = ...) -> Callable[[_C], _C]: ... diff --git a/django-stubs/test/utils.pyi b/django-stubs/test/utils.pyi index 18b192a..03d4da1 100644 --- a/django-stubs/test/utils.pyi +++ b/django-stubs/test/utils.pyi @@ -16,6 +16,7 @@ from typing import ( Type, Union, ContextManager, + TypeVar, ) from django.apps.registry import Apps @@ -29,6 +30,7 @@ from django.conf import LazySettings, Settings _TestClass = Type[SimpleTestCase] _DecoratedTest = Union[Callable, _TestClass] +_C = TypeVar("_C", bound=Callable) # Any callable TZ_SUPPORT: bool = ... @@ -56,7 +58,7 @@ class TestContextDecorator: def __enter__(self) -> Optional[Apps]: ... def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: ... def decorate_class(self, cls: _TestClass) -> _TestClass: ... - def decorate_callable(self, func: Callable) -> Callable: ... + def decorate_callable(self, func: _C) -> _C: ... def __call__(self, decorated: _DecoratedTest) -> Any: ... class override_settings(TestContextDecorator): @@ -146,7 +148,7 @@ def get_unique_databases_and_mirrors() -> Tuple[Dict[_Signature, _TestDatabase], def teardown_databases( old_config: Iterable[Tuple[Any, str, bool]], verbosity: int, parallel: int = ..., keepdb: bool = ... ) -> None: ... -def require_jinja2(test_func: Callable) -> Callable: ... +def require_jinja2(test_func: _C) -> _C: ... @contextmanager def register_lookup( field: Type[RegisterLookupMixin], *lookups: Type[Union[Lookup, Transform]], lookup_name: Optional[str] = ... diff --git a/django-stubs/utils/decorators.pyi b/django-stubs/utils/decorators.pyi index 9fb4fcb..3badf23 100644 --- a/django-stubs/utils/decorators.pyi +++ b/django-stubs/utils/decorators.pyi @@ -1,13 +1,16 @@ -from typing import Any, Callable, Iterable, Optional, Type, Union +from typing import Any, Callable, Iterable, Optional, Type, Union, TypeVar from django.utils.deprecation import MiddlewareMixin +from django.views.generic.base import View + +_T = TypeVar("_T", bound=Union[View, Callable]) # Any callable class classonlymethod(classmethod): ... -def method_decorator(decorator: Union[Callable, Iterable[Callable]], name: str = ...) -> Callable: ... +def method_decorator(decorator: Union[Callable, Iterable[Callable]], name: str = ...) -> Callable[[_T], _T]: ... def decorator_from_middleware_with_args(middleware_class: type) -> Callable: ... def decorator_from_middleware(middleware_class: type) -> Callable: ... -def available_attrs(fn: Any): ... +def available_attrs(fn: Callable): ... def make_middleware_decorator(middleware_class: Type[MiddlewareMixin]) -> Callable: ... class classproperty: diff --git a/scripts/enabled_test_modules.py b/scripts/enabled_test_modules.py index cb85749..d67c0cf 100644 --- a/scripts/enabled_test_modules.py +++ b/scripts/enabled_test_modules.py @@ -170,7 +170,8 @@ IGNORED_ERRORS = { 'Incompatible types in assignment (expression has type "Optional[Any]", variable has type "FloatModel")' ], 'decorators': [ - '"Type[object]" has no attribute "method"' + '"Type[object]" has no attribute "method"', + 'Value of type variable "_T" of function cannot be "descriptor_wrapper"' ], 'expressions_window': [ 'has incompatible type "str"' diff --git a/test-data/typecheck/contrib/auth/test_decorators.yml b/test-data/typecheck/contrib/auth/test_decorators.yml new file mode 100644 index 0000000..616e007 --- /dev/null +++ b/test-data/typecheck/contrib/auth/test_decorators.yml @@ -0,0 +1,43 @@ +- case: login_required_bare + main: | + from django.contrib.auth.decorators import login_required + @login_required + def view_func(request): ... + reveal_type(view_func) # N: Revealed type is 'def (request: Any) -> Any' +- case: login_required_fancy + main: | + from django.contrib.auth.decorators import login_required + from django.core.handlers.wsgi import WSGIRequest + from django.http import HttpResponse + @login_required(redirect_field_name='a', login_url='b') + def view_func(request: WSGIRequest, arg: str) -> HttpResponse: ... + reveal_type(view_func) # N: Revealed type is 'def (request: django.core.handlers.wsgi.WSGIRequest, arg: builtins.str) -> django.http.response.HttpResponse' +- case: login_required_weird + main: | + from django.contrib.auth.decorators import login_required + # This is non-conventional usage, but covered in Django tests, so we allow it. + def view_func(request): ... + wrapped_view = login_required(view_func, redirect_field_name='a', login_url='b') + reveal_type(wrapped_view) # N: Revealed type is 'def (request: Any) -> Any' +- case: login_required_incorrect_return + main: | + from django.contrib.auth.decorators import login_required + @login_required() # E: Value of type variable "_VIEW" of function cannot be "Callable[[Any], str]" + def view_func2(request) -> str: ... +- case: user_passes_test + main: | + from django.contrib.auth.decorators import user_passes_test + @user_passes_test(lambda u: u.username.startswith('super')) + def view_func(request): ... + reveal_type(view_func) # N: Revealed type is 'def (request: Any) -> Any' +- case: user_passes_test_bare_is_error + main: | + from django.http.response import HttpResponse + from django.contrib.auth.decorators import user_passes_test + @user_passes_test # E: Argument 1 to "user_passes_test" has incompatible type "Callable[[Any], HttpResponse]"; expected "Callable[[AbstractUser], bool]" + def view_func(request) -> HttpResponse: ... +- case: permission_required + main: | + from django.contrib.auth.decorators import permission_required + @permission_required('polls.can_vote') + def view_func(request): ... diff --git a/test-data/typecheck/db/test_transaction.yml b/test-data/typecheck/db/test_transaction.yml new file mode 100644 index 0000000..446dfc4 --- /dev/null +++ b/test-data/typecheck/db/test_transaction.yml @@ -0,0 +1,28 @@ +- case: atomic_bare + main: | + from django.db.transaction import atomic + @atomic + def func(x: int) -> list: ... + reveal_type(func) # N: Revealed type is 'def (x: builtins.int) -> builtins.list[Any]' +- case: atomic_args + main: | + from django.db.transaction import atomic + @atomic(using='bla', savepoint=False) + def func(x: int) -> list: ... + reveal_type(func) # N: Revealed type is 'def (x: builtins.int) -> builtins.list[Any]' +- case: non_atomic_requests_bare + main: | + from django.db.transaction import non_atomic_requests + @non_atomic_requests + def view_func(request): ... + reveal_type(view_func) # N: Revealed type is 'def (request: Any) -> Any' + +- case: non_atomic_requests_args + main: | + from django.http.request import HttpRequest + from django.http.response import HttpResponse + from django.db.transaction import non_atomic_requests + @non_atomic_requests + def view_func(request: HttpRequest, arg: str) -> HttpResponse: ... + reveal_type(view_func) # N: Revealed type is 'def (request: django.http.request.HttpRequest, arg: builtins.str) -> django.http.response.HttpResponse' + diff --git a/test-data/typecheck/utils/test_decorators.yml b/test-data/typecheck/utils/test_decorators.yml new file mode 100644 index 0000000..b347938 --- /dev/null +++ b/test-data/typecheck/utils/test_decorators.yml @@ -0,0 +1,20 @@ +- case: method_decorator_class + main: | + from django.views.generic.base import View + from django.utils.decorators import method_decorator + from django.contrib.auth.decorators import login_required + @method_decorator(login_required, name='dispatch') + class TestView(View): ... + reveal_type(TestView()) # N: Revealed type is 'main.TestView' +- case: method_decorator_function + main: | + from django.views.generic.base import View + from django.utils.decorators import method_decorator + from django.contrib.auth.decorators import login_required + from django.http.response import HttpResponse + from django.http.request import HttpRequest + class TestView(View): + @method_decorator(login_required) + def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: + return super().dispatch(request, *args, **kwargs) + reveal_type(dispatch) # N: Revealed type is 'def (self: main.TestView, request: django.http.request.HttpRequest, *args: Any, **kwargs: Any) -> django.http.response.HttpResponse'