diff --git a/stdlib/@tests/test_cases/urllib/check_parse.py b/stdlib/@tests/test_cases/urllib/check_parse.py new file mode 100644 index 000000000..f464f6341 --- /dev/null +++ b/stdlib/@tests/test_cases/urllib/check_parse.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from urllib.parse import quote, quote_plus, urlencode + +urlencode({"a": "b"}, quote_via=quote) +urlencode({b"a": b"b"}, quote_via=quote) +urlencode({"a": b"b"}, quote_via=quote) +urlencode({b"a": "b"}, quote_via=quote) +mixed_dict: dict[str | bytes, str | bytes] = {} +urlencode(mixed_dict, quote_via=quote) + +urlencode({"a": "b"}, quote_via=quote_plus) diff --git a/stdlib/urllib/parse.pyi b/stdlib/urllib/parse.pyi index f2fae0c3d..a5ed616d2 100644 --- a/stdlib/urllib/parse.pyi +++ b/stdlib/urllib/parse.pyi @@ -1,7 +1,7 @@ import sys -from collections.abc import Callable, Iterable, Mapping, Sequence +from collections.abc import Iterable, Mapping, Sequence from types import GenericAlias -from typing import Any, AnyStr, Generic, Literal, NamedTuple, TypeVar, overload +from typing import Any, AnyStr, Generic, Literal, NamedTuple, Protocol, overload, type_check_only from typing_extensions import TypeAlias __all__ = [ @@ -132,38 +132,32 @@ def urldefrag(url: str) -> DefragResult: ... @overload def urldefrag(url: bytes | bytearray | None) -> DefragResultBytes: ... -_Q = TypeVar("_Q", bound=str | Iterable[int]) +# The values are passed through `str()` (unless they are bytes), so anything is valid. _QueryType: TypeAlias = ( - Mapping[Any, Any] | Mapping[Any, Sequence[Any]] | Sequence[tuple[Any, Any]] | Sequence[tuple[Any, Sequence[Any]]] + Mapping[str, object] + | Mapping[bytes, object] + | Mapping[str | bytes, object] + | Mapping[str, Sequence[object]] + | Mapping[bytes, Sequence[object]] + | Mapping[str | bytes, Sequence[object]] + | Sequence[tuple[str | bytes, object]] + | Sequence[tuple[str | bytes, Sequence[object]]] ) -@overload +@type_check_only +class _QuoteVia(Protocol): + @overload + def __call__(self, string: str, safe: str | bytes, encoding: str, errors: str, /) -> str: ... + @overload + def __call__(self, string: bytes, safe: str | bytes, /) -> str: ... + def urlencode( query: _QueryType, doseq: bool = False, - safe: str = "", + safe: str | bytes = "", encoding: str | None = None, errors: str | None = None, - quote_via: Callable[[AnyStr, str, str, str], str] = ..., -) -> str: ... -@overload -def urlencode( - query: _QueryType, - doseq: bool, - safe: _Q, - encoding: str | None = None, - errors: str | None = None, - quote_via: Callable[[AnyStr, _Q, str, str], str] = ..., -) -> str: ... -@overload -def urlencode( - query: _QueryType, - doseq: bool = False, - *, - safe: _Q, - encoding: str | None = None, - errors: str | None = None, - quote_via: Callable[[AnyStr, _Q, str, str], str] = ..., + quote_via: _QuoteVia = ..., ) -> str: ... def urljoin(base: AnyStr, url: AnyStr | None, allow_fragments: bool = True) -> AnyStr: ... @overload