From 03b8c60a024459c5be1c6ad072fe6d44a0d3aef9 Mon Sep 17 00:00:00 2001 From: Akuli Date: Sat, 22 Apr 2023 18:28:34 +0300 Subject: [PATCH] Support `dict(foo.split() for foo in bar)` with bytes (#10072) --- stdlib/builtins.pyi | 4 +++- stdlib/collections/__init__.pyi | 2 ++ stdlib/multiprocessing/managers.pyi | 2 ++ test_cases/stdlib/builtins/check_dict.py | 4 ++++ 4 files changed, 11 insertions(+), 1 deletion(-) diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index 2c21cd95d..d49610b25 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -1032,10 +1032,12 @@ class dict(MutableMapping[_KT, _VT], Generic[_KT, _VT]): def __init__(self, __iterable: Iterable[tuple[_KT, _VT]]) -> None: ... @overload def __init__(self: dict[str, _VT], __iterable: Iterable[tuple[str, _VT]], **kwargs: _VT) -> None: ... - # Next overload is for dict(string.split(sep) for string in iterable) + # Next two overloads are for dict(string.split(sep) for string in iterable) # Cannot be Iterable[Sequence[_T]] or otherwise dict(["foo", "bar", "baz"]) is not an error @overload def __init__(self: dict[str, str], __iterable: Iterable[list[str]]) -> None: ... + @overload + def __init__(self: dict[bytes, bytes], __iterable: Iterable[list[bytes]]) -> None: ... def __new__(cls, *args: Any, **kwargs: Any) -> Self: ... def copy(self) -> dict[_KT, _VT]: ... def keys(self) -> dict_keys[_KT, _VT]: ... diff --git a/stdlib/collections/__init__.pyi b/stdlib/collections/__init__.pyi index 1a4042114..d5ca17c74 100644 --- a/stdlib/collections/__init__.pyi +++ b/stdlib/collections/__init__.pyi @@ -62,6 +62,8 @@ class UserDict(MutableMapping[_KT, _VT], Generic[_KT, _VT]): def __init__(self: UserDict[str, _VT], __iterable: Iterable[tuple[str, _VT]], **kwargs: _VT) -> None: ... @overload def __init__(self: UserDict[str, str], __iterable: Iterable[list[str]]) -> None: ... + @overload + def __init__(self: UserDict[bytes, bytes], __iterable: Iterable[list[bytes]]) -> None: ... def __len__(self) -> int: ... def __getitem__(self, key: _KT) -> _VT: ... def __setitem__(self, key: _KT, item: _VT) -> None: ... diff --git a/stdlib/multiprocessing/managers.pyi b/stdlib/multiprocessing/managers.pyi index 4ac602374..27a903fb9 100644 --- a/stdlib/multiprocessing/managers.pyi +++ b/stdlib/multiprocessing/managers.pyi @@ -197,6 +197,8 @@ class SyncManager(BaseManager): @overload def dict(self, __iterable: Iterable[list[str]]) -> DictProxy[str, str]: ... @overload + def dict(self, __iterable: Iterable[list[bytes]]) -> DictProxy[bytes, bytes]: ... + @overload def list(self, __sequence: Sequence[_T]) -> ListProxy[_T]: ... @overload def list(self) -> ListProxy[Any]: ... diff --git a/test_cases/stdlib/builtins/check_dict.py b/test_cases/stdlib/builtins/check_dict.py index 731662e63..aa920d045 100644 --- a/test_cases/stdlib/builtins/check_dict.py +++ b/test_cases/stdlib/builtins/check_dict.py @@ -50,5 +50,9 @@ i2: Iterable[tuple[str, int]] = [("a", 1), ("b", 2)] assert_type(dict(i2, arg=1), Dict[str, int]) i3: Iterable[str] = ["a.b"] +i4: Iterable[bytes] = [b"a.b"] assert_type(dict(string.split(".") for string in i3), Dict[str, str]) +assert_type(dict(string.split(b".") for string in i4), Dict[bytes, bytes]) + dict(["foo", "bar", "baz"]) # type: ignore +dict([b"foo", b"bar", b"baz"]) # type: ignore