From dbd3ad356ef3bceae506fb71a29a48b6192213ba Mon Sep 17 00:00:00 2001 From: Neil Mitchell Date: Wed, 2 Jul 2025 11:24:32 +0100 Subject: [PATCH] Make Mapping.get(default) more constrained (#14360) --- stdlib/collections/__init__.pyi | 4 ++++ stdlib/importlib/metadata/__init__.pyi | 2 ++ stdlib/types.pyi | 4 +++- stdlib/typing.pyi | 4 +++- stdlib/weakref.pyi | 4 ++++ stubs/WebOb/webob/cookies.pyi | 2 ++ stubs/WebOb/webob/multidict.pyi | 4 +++- stubs/boltons/boltons/cacheutils.pyi | 2 ++ stubs/grpcio/grpc/aio/__init__.pyi | 4 +++- stubs/inifile/inifile.pyi | 4 +++- stubs/oauthlib/oauthlib/common.pyi | 5 ++++- stubs/protobuf/google/protobuf/internal/containers.pyi | 8 ++++++-- 12 files changed, 39 insertions(+), 8 deletions(-) diff --git a/stdlib/collections/__init__.pyi b/stdlib/collections/__init__.pyi index b9e4f84ec..bc33d91ca 100644 --- a/stdlib/collections/__init__.pyi +++ b/stdlib/collections/__init__.pyi @@ -108,6 +108,8 @@ class UserDict(MutableMapping[_KT, _VT]): @overload def get(self, key: _KT, default: None = None) -> _VT | None: ... @overload + def get(self, key: _KT, default: _VT) -> _VT: ... + @overload def get(self, key: _KT, default: _T) -> _VT | _T: ... class UserList(MutableSequence[_T]): @@ -452,6 +454,8 @@ class ChainMap(MutableMapping[_KT, _VT]): @overload def get(self, key: _KT, default: None = None) -> _VT | None: ... @overload + def get(self, key: _KT, default: _VT) -> _VT: ... + @overload def get(self, key: _KT, default: _T) -> _VT | _T: ... def __missing__(self, key: _KT) -> _VT: ... # undocumented def __bool__(self) -> bool: ... diff --git a/stdlib/importlib/metadata/__init__.pyi b/stdlib/importlib/metadata/__init__.pyi index 15d8b50b0..789878382 100644 --- a/stdlib/importlib/metadata/__init__.pyi +++ b/stdlib/importlib/metadata/__init__.pyi @@ -140,6 +140,8 @@ if sys.version_info >= (3, 10) and sys.version_info < (3, 12): @overload def get(self, name: _KT, default: None = None) -> _VT | None: ... @overload + def get(self, name: _KT, default: _VT) -> _VT: ... + @overload def get(self, name: _KT, default: _T) -> _VT | _T: ... def __iter__(self) -> Iterator[_KT]: ... def __contains__(self, *args: object) -> bool: ... diff --git a/stdlib/types.pyi b/stdlib/types.pyi index 582cb6534..44bd3eeb3 100644 --- a/stdlib/types.pyi +++ b/stdlib/types.pyi @@ -323,7 +323,9 @@ class MappingProxyType(Mapping[_KT, _VT_co]): @overload def get(self, key: _KT, /) -> _VT_co | None: ... @overload - def get(self, key: _KT, default: _VT_co | _T2, /) -> _VT_co | _T2: ... + def get(self, key: _KT, default: _VT_co, /) -> _VT_co: ... # type: ignore[misc] # pyright: ignore[reportGeneralTypeIssues] # Covariant type as parameter + @overload + def get(self, key: _KT, default: _T2, /) -> _VT_co | _T2: ... def __class_getitem__(cls, item: Any, /) -> GenericAlias: ... def __reversed__(self) -> Iterator[_KT]: ... def __or__(self, value: Mapping[_T1, _T2], /) -> dict[_KT | _T1, _VT_co | _T2]: ... diff --git a/stdlib/typing.pyi b/stdlib/typing.pyi index b1c93dfe7..d296c8d92 100644 --- a/stdlib/typing.pyi +++ b/stdlib/typing.pyi @@ -745,7 +745,9 @@ class Mapping(Collection[_KT], Generic[_KT, _VT_co]): @overload def get(self, key: _KT, /) -> _VT_co | None: ... @overload - def get(self, key: _KT, /, default: _VT_co | _T) -> _VT_co | _T: ... + def get(self, key: _KT, /, default: _VT_co) -> _VT_co: ... # type: ignore[misc] # pyright: ignore[reportGeneralTypeIssues] # Covariant type as parameter + @overload + def get(self, key: _KT, /, default: _T) -> _VT_co | _T: ... def items(self) -> ItemsView[_KT, _VT_co]: ... def keys(self) -> KeysView[_KT]: ... def values(self) -> ValuesView[_VT_co]: ... diff --git a/stdlib/weakref.pyi b/stdlib/weakref.pyi index 593eb4615..334fab7e7 100644 --- a/stdlib/weakref.pyi +++ b/stdlib/weakref.pyi @@ -99,6 +99,8 @@ class WeakValueDictionary(MutableMapping[_KT, _VT]): @overload def get(self, key: _KT, default: None = None) -> _VT | None: ... @overload + def get(self, key: _KT, default: _VT) -> _VT: ... + @overload def get(self, key: _KT, default: _T) -> _VT | _T: ... # These are incompatible with Mapping def keys(self) -> Iterator[_KT]: ... # type: ignore[override] @@ -149,6 +151,8 @@ class WeakKeyDictionary(MutableMapping[_KT, _VT]): @overload def get(self, key: _KT, default: None = None) -> _VT | None: ... @overload + def get(self, key: _KT, default: _VT) -> _VT: ... + @overload def get(self, key: _KT, default: _T) -> _VT | _T: ... # These are incompatible with Mapping def keys(self) -> Iterator[_KT]: ... # type: ignore[override] diff --git a/stubs/WebOb/webob/cookies.pyi b/stubs/WebOb/webob/cookies.pyi index 98914a2a7..8a9632102 100644 --- a/stubs/WebOb/webob/cookies.pyi +++ b/stubs/WebOb/webob/cookies.pyi @@ -37,6 +37,8 @@ class RequestCookies(MutableMapping[str, str]): @overload def get(self, name: str, default: None = None) -> str | None: ... @overload + def get(self, name: str, default: str) -> str: ... + @overload def get(self, name: str, default: _T) -> str | _T: ... def __delitem__(self, name: str) -> None: ... def keys(self) -> KeysView[str]: ... diff --git a/stubs/WebOb/webob/multidict.pyi b/stubs/WebOb/webob/multidict.pyi index 1b5f3a046..6215d0079 100644 --- a/stubs/WebOb/webob/multidict.pyi +++ b/stubs/WebOb/webob/multidict.pyi @@ -48,7 +48,9 @@ class MultiDict(MutableMapping[_KT, _VT]): def __setitem__(self, key: _KT, value: _VT) -> None: ... def add(self, key: _KT, value: _VT) -> None: ... @overload - def get(self, key: _KT) -> _VT | None: ... + def get(self, key: _KT, default: None = None) -> _VT | None: ... + @overload + def get(self, key: _KT, default: _VT) -> _VT: ... @overload def get(self, key: _KT, default: _T) -> _VT | _T: ... def getall(self, key: _KT) -> list[_VT]: ... diff --git a/stubs/boltons/boltons/cacheutils.pyi b/stubs/boltons/boltons/cacheutils.pyi index 751aaec3b..40e786e61 100644 --- a/stubs/boltons/boltons/cacheutils.pyi +++ b/stubs/boltons/boltons/cacheutils.pyi @@ -26,6 +26,8 @@ class LRI(dict[_KT, _VT]): @overload def get(self, key: _KT, default: None = None) -> _VT | None: ... @overload + def get(self, key: _KT, default: _VT) -> _VT: ... + @overload def get(self, key: _KT, default: _T) -> _T | _VT: ... def __delitem__(self, key: _KT) -> None: ... @overload diff --git a/stubs/grpcio/grpc/aio/__init__.pyi b/stubs/grpcio/grpc/aio/__init__.pyi index a27b8eeca..04d98877f 100644 --- a/stubs/grpcio/grpc/aio/__init__.pyi +++ b/stubs/grpcio/grpc/aio/__init__.pyi @@ -435,7 +435,9 @@ class Metadata(Mapping[_MetadataKey, _MetadataValue]): def delete_all(self, key: _MetadataKey) -> None: ... def __iter__(self) -> Iterator[_MetadataKey]: ... @overload - def get(self, key: _MetadataKey) -> _MetadataValue | None: ... + def get(self, key: _MetadataKey, default: None = None) -> _MetadataValue | None: ... + @overload + def get(self, key: _MetadataKey, default: _MetadataValue) -> _MetadataValue: ... @overload def get(self, key: _MetadataKey, default: _T) -> _MetadataValue | _T: ... def get_all(self, key: _MetadataKey) -> list[_MetadataValue]: ... diff --git a/stubs/inifile/inifile.pyi b/stubs/inifile/inifile.pyi index 8dd54936e..c64a3168d 100644 --- a/stubs/inifile/inifile.pyi +++ b/stubs/inifile/inifile.pyi @@ -69,7 +69,9 @@ class IniData(MutableMapping[str, str]): def to_dict(self) -> dict[str, str]: ... def __len__(self) -> int: ... @overload - def get(self, name: str) -> str | None: ... + def get(self, name: str, default: None = None) -> str | None: ... + @overload + def get(self, name: str, default: str) -> str: ... @overload def get(self, name: str, default: _T) -> str | _T: ... @overload diff --git a/stubs/oauthlib/oauthlib/common.pyi b/stubs/oauthlib/oauthlib/common.pyi index 9fd6b272e..15112d518 100644 --- a/stubs/oauthlib/oauthlib/common.pyi +++ b/stubs/oauthlib/oauthlib/common.pyi @@ -52,7 +52,10 @@ class CaseInsensitiveDict(dict[str, Incomplete]): def __contains__(self, k: object) -> bool: ... def __delitem__(self, k: str) -> None: ... def __getitem__(self, k: str): ... - def get(self, k: str, default=None) -> Incomplete | None: ... + @overload + def get(self, k: str, default: None = None) -> Incomplete | None: ... + @overload + def get(self, k: str, default): ... def __setitem__(self, k: str, v) -> None: ... def update(self, *args, **kwargs) -> None: ... diff --git a/stubs/protobuf/google/protobuf/internal/containers.pyi b/stubs/protobuf/google/protobuf/internal/containers.pyi index 752613716..1431c3b13 100644 --- a/stubs/protobuf/google/protobuf/internal/containers.pyi +++ b/stubs/protobuf/google/protobuf/internal/containers.pyi @@ -72,7 +72,9 @@ class ScalarMap(MutableMapping[_K, _ScalarV]): @overload def get(self, key: _K, default: None = None) -> _ScalarV | None: ... @overload - def get(self, key: _K, default: _ScalarV | _T) -> _ScalarV | _T: ... + def get(self, key: _K, default: _ScalarV) -> _ScalarV: ... + @overload + def get(self, key: _K, default: _T) -> _ScalarV | _T: ... def setdefault(self, key: _K, value: _ScalarV | None = None) -> _ScalarV: ... def MergeFrom(self, other: Self): ... def InvalidateIterators(self) -> None: ... @@ -95,7 +97,9 @@ class MessageMap(MutableMapping[_K, _MessageV]): @overload def get(self, key: _K, default: None = None) -> _MessageV | None: ... @overload - def get(self, key: _K, default: _MessageV | _T) -> _MessageV | _T: ... + def get(self, key: _K, default: _MessageV) -> _MessageV: ... + @overload + def get(self, key: _K, default: _T) -> _MessageV | _T: ... def get_or_create(self, key: _K) -> _MessageV: ... def setdefault(self, key: _K, value: _MessageV | None = None) -> _MessageV: ... def MergeFrom(self, other: Self): ...