Fix type annotations for django.utils.safestring (#179)

* Add/fix types for django.utils.safestring.mark_safe

Django code ref: 964dd4f4f2/django/utils/safestring.py (L71-L84)

* add generic annotations for mark_safe, remove SafeBytes as it is basically deprecated

Co-authored-by: Daniel Hahler <github@thequod.de>
This commit is contained in:
Maxim Kurnikov
2019-09-23 21:16:24 +03:00
parent afcd0d9293
commit 7407b93151
3 changed files with 45 additions and 15 deletions

View File

@@ -1,9 +1,7 @@
from html.parser import HTMLParser from html.parser import HTMLParser
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Any, Iterator, List, Optional, Tuple, Union
from django.db.models.base import Model from django.utils.safestring import SafeText
from django.db.models.fields.files import FieldFile
from django.utils.safestring import SafeText, mark_safe as mark_safe
TRAILING_PUNCTUATION_CHARS: str TRAILING_PUNCTUATION_CHARS: str
WRAPPING_PUNCTUATION: Any WRAPPING_PUNCTUATION: Any
@@ -13,15 +11,15 @@ word_split_re: Any
simple_url_re: Any simple_url_re: Any
simple_url_2_re: Any simple_url_2_re: Any
def escape(text: Optional[Union[Model, FieldFile, int, str]]) -> SafeText: ... def escape(text: Any) -> SafeText: ...
def escapejs(value: str) -> SafeText: ... def escapejs(value: Any) -> SafeText: ...
def json_script(value: Union[Dict[str, str], str], element_id: str) -> SafeText: ... def json_script(value: Any, element_id: str) -> SafeText: ...
def conditional_escape(text: Any) -> str: ... def conditional_escape(text: Any) -> str: ...
def format_html(format_string: str, *args: Any, **kwargs: Any) -> SafeText: ... def format_html(format_string: str, *args: Any, **kwargs: Any) -> SafeText: ...
def format_html_join( def format_html_join(
sep: str, format_string: str, args_generator: Union[Iterator[Any], List[Tuple[str]]] sep: str, format_string: str, args_generator: Union[Iterator[Any], List[Tuple[str]]]
) -> SafeText: ... ) -> SafeText: ...
def linebreaks(value: str, autoescape: bool = ...) -> str: ... def linebreaks(value: Any, autoescape: bool = ...) -> str: ...
class MLStripper(HTMLParser): class MLStripper(HTMLParser):
fed: Any = ... fed: Any = ...

View File

@@ -1,14 +1,26 @@
from typing import Any from typing import TypeVar, overload, Callable, Any
_SD = TypeVar("_SD", bound="SafeData")
class SafeData: class SafeData:
def __html__(self) -> SafeText: ... def __html__(self: _SD) -> _SD: ...
class SafeBytes(bytes, SafeData):
def __add__(self, rhs: Any): ...
class SafeText(str, SafeData): class SafeText(str, SafeData):
@overload
def __add__(self, rhs: SafeText) -> SafeText: ...
@overload
def __add__(self, rhs: str) -> str: ... def __add__(self, rhs: str) -> str: ...
@overload
def __iadd__(self, rhs: SafeText) -> SafeText: ...
@overload
def __iadd__(self, rhs: str) -> str: ...
SafeString = SafeText SafeString = SafeText
def mark_safe(s: Any) -> Any: ... _C = TypeVar("_C", bound=Callable)
@overload
def mark_safe(s: _SD) -> _SD: ...
@overload
def mark_safe(s: _C) -> _C: ...
@overload
def mark_safe(s: Any) -> SafeText: ...

View File

@@ -39,4 +39,24 @@
# Ensure that the method's type is preserved # Ensure that the method's type is preserved
reveal_type(ClassWithAtomicMethod().atomic_method1) # N: Revealed type is 'def (abc: builtins.int) -> builtins.str' reveal_type(ClassWithAtomicMethod().atomic_method1) # N: Revealed type is 'def (abc: builtins.int) -> builtins.str'
# Ensure that the method's type is preserved # Ensure that the method's type is preserved
reveal_type(ClassWithAtomicMethod().atomic_method3) # N: Revealed type is 'def (myparam: builtins.str) -> builtins.int' reveal_type(ClassWithAtomicMethod().atomic_method3) # N: Revealed type is 'def (myparam: builtins.str) -> builtins.int'
- case: mark_safe_decorator_and_function
main: |
from django.utils.safestring import mark_safe
s = 'hello'
reveal_type(mark_safe(s)) # N: Revealed type is 'django.utils.safestring.SafeText'
reveal_type(mark_safe(s) + mark_safe(s)) # N: Revealed type is 'django.utils.safestring.SafeText'
reveal_type(s + mark_safe(s)) # N: Revealed type is 'builtins.str'
s += mark_safe(s)
reveal_type(s) # N: Revealed type is 'builtins.str'
ms = mark_safe(s)
ms += mark_safe(s)
reveal_type(ms) # N: Revealed type is 'django.utils.safestring.SafeText'
@mark_safe
def func(s: str) -> str:
pass
reveal_type(func) # N: Revealed type is 'def (s: builtins.str) -> builtins.str'