From a1334a70b90ccfc2928c44022c30eb24412d98c1 Mon Sep 17 00:00:00 2001 From: Tim Martin Date: Wed, 20 Jan 2021 20:11:02 +0000 Subject: [PATCH] Stricter return type annotations for template.Library (#541) * Stricter return type annotations for template.Library * Add some unit tests for the template library decorators --- django-stubs/template/library.pyi | 37 +++++---- tests/typecheck/template/test_library.yml | 96 +++++++++++++++++++++++ 2 files changed, 118 insertions(+), 15 deletions(-) create mode 100644 tests/typecheck/template/test_library.yml diff --git a/django-stubs/template/library.pyi b/django-stubs/template/library.pyi index dca59e2..25fc7e6 100644 --- a/django-stubs/template/library.pyi +++ b/django-stubs/template/library.pyi @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, overload from django.template.base import FilterExpression, Parser, Origin, Token from django.template.context import Context @@ -8,31 +8,38 @@ from .base import Node, Template class InvalidTemplateLibrary(Exception): ... +_C = TypeVar("_C", bound=Callable[..., Any]) + class Library: filters: Dict[str, Callable] = ... tags: Dict[str, Callable] = ... def __init__(self) -> None: ... - def tag( - self, name: Optional[Union[Callable, str]] = ..., compile_function: Optional[Union[Callable, str]] = ... - ) -> Callable: ... - def tag_function(self, func: Callable) -> Callable: ... - def filter( - self, - name: Optional[Union[Callable, str]] = ..., - filter_func: Optional[Union[Callable, str]] = ..., - **flags: Any - ) -> Callable: ... - def filter_function(self, func: Callable, **flags: Any) -> Callable: ... + @overload + def tag(self, name: _C) -> _C: ... + @overload + def tag(self, name: str, compile_function: _C) -> _C: ... + @overload + def tag(self, name: Optional[str] = ..., compile_function: None = ...) -> Callable[[_C], _C]: ... + def tag_function(self, func: _C) -> _C: ... + @overload + def filter(self, name: _C, filter_func: None = ..., **flags: Any) -> _C: ... + @overload + def filter(self, name: Optional[str], filter_func: _C, **flags: Any) -> _C: ... + @overload + def filter(self, name: Optional[str] = ..., filter_func: None = ..., **flags: Any) -> Callable[[_C], _C]: ... + @overload + def simple_tag(self, func: _C) -> _C: ... + @overload def simple_tag( - self, func: Optional[Union[Callable, str]] = ..., takes_context: Optional[bool] = ..., name: Optional[str] = ... - ) -> Callable: ... + self, takes_context: Optional[bool] = ..., name: Optional[str] = ... + ) -> Callable[[_C], _C]: ... def inclusion_tag( self, filename: Union[Template, str], func: None = ..., takes_context: Optional[bool] = ..., name: Optional[str] = ..., - ) -> Callable: ... + ) -> Callable[[_C], _C]: ... class TagHelperNode(Node): func: Any = ... diff --git a/tests/typecheck/template/test_library.yml b/tests/typecheck/template/test_library.yml new file mode 100644 index 0000000..c0735ce --- /dev/null +++ b/tests/typecheck/template/test_library.yml @@ -0,0 +1,96 @@ +- case: register_filter_unnamed + main: | + from django import template + register = template.Library() + + @register.filter + def lower(value: str) -> str: + return value.lower() + + reveal_type(lower) # N: Revealed type is 'def (value: builtins.str) -> builtins.str' + +- case: register_filter_named + main: | + from django import template + register = template.Library() + + @register.filter(name="tolower") + def lower(value: str) -> str: + return value.lower() + + reveal_type(lower) # N: Revealed type is 'def (value: builtins.str) -> builtins.str' + +- case: register_simple_tag_no_args + main: | + import datetime + from django import template + register = template.Library() + + @register.simple_tag + def current_time(format_string: str) -> str: + return datetime.datetime.now().strftime(format_string) + + reveal_type(current_time) # N: Revealed type is 'def (format_string: builtins.str) -> builtins.str' + +- case: register_simple_tag_context + main: | + from django import template + from typing import Dict, Any + register = template.Library() + + @register.simple_tag(takes_context=True) + def current_time(context: Dict[str, Any], format_string: str) -> str: + timezone = context['timezone'] + return "test" + + reveal_type(current_time) # N: Revealed type is 'def (context: builtins.dict[builtins.str, Any], format_string: builtins.str) -> builtins.str' + +- case: register_simple_tag_named + main: | + from django import template + from typing import Dict, Any + register = template.Library() + + @register.simple_tag(name='minustwo') + def some_function(value: int) -> int: + return value - 2 + + reveal_type(some_function) # N: Revealed type is 'def (value: builtins.int) -> builtins.int' + +- case: register_tag_no_args + main: | + from django import template + from django.template.base import Parser, Token + from django.template.defaulttags import CycleNode + register = template.Library() + + @register.tag + def cycle(parser: Parser, token: Token) -> CycleNode: + return CycleNode([]) + + reveal_type(cycle) # N: Revealed type is 'def (parser: django.template.base.Parser, token: django.template.base.Token) -> django.template.defaulttags.CycleNode' + +- case: register_tag_named + main: | + from django import template + from django.template.base import Parser, Token + from django.template.defaulttags import CycleNode + register = template.Library() + + @register.tag(name="cycle") + def cycle_impl(parser: Parser, token: Token) -> CycleNode: + return CycleNode([]) + + reveal_type(cycle_impl) # N: Revealed type is 'def (parser: django.template.base.Parser, token: django.template.base.Token) -> django.template.defaulttags.CycleNode' + +- case: register_inclusion_tag + main: | + from django import template + from typing import List + register = template.Library() + + @register.inclusion_tag('results.html') + def format_results(results: List[str]) -> str: + return ', '.join(results) + + reveal_type(format_results) # N: Revealed type is 'def (results: builtins.list[builtins.str]) -> builtins.str'