From 7407b9315109b33ae8c6066c4be7ac1920919f0e Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Mon, 23 Sep 2019 21:16:24 +0300 Subject: [PATCH] Fix type annotations for django.utils.safestring (#179) * Add/fix types for django.utils.safestring.mark_safe Django code ref: https://github.com/django/django/blob/964dd4f4f208722d8993a35c1ff047d353cea1ea/django/utils/safestring.py#L71-L84 * add generic annotations for mark_safe, remove SafeBytes as it is basically deprecated Co-authored-by: Daniel Hahler --- django-stubs/utils/html.pyi | 14 ++++++-------- django-stubs/utils/safestring.pyi | 24 ++++++++++++++++++------ test-data/typecheck/test_helpers.yml | 22 +++++++++++++++++++++- 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/django-stubs/utils/html.pyi b/django-stubs/utils/html.pyi index 622327a..314dcaa 100644 --- a/django-stubs/utils/html.pyi +++ b/django-stubs/utils/html.pyi @@ -1,9 +1,7 @@ from html.parser import HTMLParser -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Iterator, List, Optional, Tuple, Union -from django.db.models.base import Model -from django.db.models.fields.files import FieldFile -from django.utils.safestring import SafeText, mark_safe as mark_safe +from django.utils.safestring import SafeText TRAILING_PUNCTUATION_CHARS: str WRAPPING_PUNCTUATION: Any @@ -13,15 +11,15 @@ word_split_re: Any simple_url_re: Any simple_url_2_re: Any -def escape(text: Optional[Union[Model, FieldFile, int, str]]) -> SafeText: ... -def escapejs(value: str) -> SafeText: ... -def json_script(value: Union[Dict[str, str], str], element_id: str) -> SafeText: ... +def escape(text: Any) -> SafeText: ... +def escapejs(value: Any) -> SafeText: ... +def json_script(value: Any, element_id: str) -> SafeText: ... def conditional_escape(text: Any) -> str: ... def format_html(format_string: str, *args: Any, **kwargs: Any) -> SafeText: ... def format_html_join( sep: str, format_string: str, args_generator: Union[Iterator[Any], List[Tuple[str]]] ) -> SafeText: ... -def linebreaks(value: str, autoescape: bool = ...) -> str: ... +def linebreaks(value: Any, autoescape: bool = ...) -> str: ... class MLStripper(HTMLParser): fed: Any = ... diff --git a/django-stubs/utils/safestring.pyi b/django-stubs/utils/safestring.pyi index 52db7bc..ac52284 100644 --- a/django-stubs/utils/safestring.pyi +++ b/django-stubs/utils/safestring.pyi @@ -1,14 +1,26 @@ -from typing import Any +from typing import TypeVar, overload, Callable, Any + +_SD = TypeVar("_SD", bound="SafeData") class SafeData: - def __html__(self) -> SafeText: ... - -class SafeBytes(bytes, SafeData): - def __add__(self, rhs: Any): ... + def __html__(self: _SD) -> _SD: ... class SafeText(str, SafeData): + @overload + def __add__(self, rhs: SafeText) -> SafeText: ... + @overload def __add__(self, rhs: str) -> str: ... + @overload + def __iadd__(self, rhs: SafeText) -> SafeText: ... + @overload + def __iadd__(self, rhs: str) -> str: ... SafeString = SafeText -def mark_safe(s: Any) -> Any: ... +_C = TypeVar("_C", bound=Callable) +@overload +def mark_safe(s: _SD) -> _SD: ... +@overload +def mark_safe(s: _C) -> _C: ... +@overload +def mark_safe(s: Any) -> SafeText: ... diff --git a/test-data/typecheck/test_helpers.yml b/test-data/typecheck/test_helpers.yml index a2ea0c6..f82ceea 100644 --- a/test-data/typecheck/test_helpers.yml +++ b/test-data/typecheck/test_helpers.yml @@ -39,4 +39,24 @@ # Ensure that the method's type is preserved reveal_type(ClassWithAtomicMethod().atomic_method1) # N: Revealed type is 'def (abc: builtins.int) -> builtins.str' # Ensure that the method's type is preserved - reveal_type(ClassWithAtomicMethod().atomic_method3) # N: Revealed type is 'def (myparam: builtins.str) -> builtins.int' \ No newline at end of file + reveal_type(ClassWithAtomicMethod().atomic_method3) # N: Revealed type is 'def (myparam: builtins.str) -> builtins.int' + + +- case: mark_safe_decorator_and_function + main: | + from django.utils.safestring import mark_safe + s = 'hello' + reveal_type(mark_safe(s)) # N: Revealed type is 'django.utils.safestring.SafeText' + reveal_type(mark_safe(s) + mark_safe(s)) # N: Revealed type is 'django.utils.safestring.SafeText' + reveal_type(s + mark_safe(s)) # N: Revealed type is 'builtins.str' + + s += mark_safe(s) + reveal_type(s) # N: Revealed type is 'builtins.str' + ms = mark_safe(s) + ms += mark_safe(s) + reveal_type(ms) # N: Revealed type is 'django.utils.safestring.SafeText' + + @mark_safe + def func(s: str) -> str: + pass + reveal_type(func) # N: Revealed type is 'def (s: builtins.str) -> builtins.str'