From 240e3d800418388e79ea7c7836d33177ed6d2e54 Mon Sep 17 00:00:00 2001 From: Max Muoto Date: Sat, 20 Jul 2024 15:37:30 -0500 Subject: [PATCH] Correct `MappingProxyType` for 3.8+ (#12369) Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com> --- stdlib/@tests/stubtest_allowlists/py313.txt | 4 +++- stdlib/@tests/test_cases/check_types.py | 14 ++++++++++++++ stdlib/types.pyi | 4 ++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/stdlib/@tests/stubtest_allowlists/py313.txt b/stdlib/@tests/stubtest_allowlists/py313.txt index 7d6b8870f..73b0b35fa 100644 --- a/stdlib/@tests/stubtest_allowlists/py313.txt +++ b/stdlib/@tests/stubtest_allowlists/py313.txt @@ -50,7 +50,6 @@ tkinter.Misc.tk_busy_hold tkinter.Misc.tk_busy_status tkinter.Text.count tkinter.Wm.wm_attributes -types.MappingProxyType.get # ====================================== # Pre-existing errors from Python <=3.12 @@ -199,3 +198,6 @@ codecs.namereplace_errors codecs.replace_errors codecs.strict_errors codecs.xmlcharrefreplace_errors + +# To match `dict`, we lie about the runtime, but use overloads to match the correct behavior +types.MappingProxyType.get diff --git a/stdlib/@tests/test_cases/check_types.py b/stdlib/@tests/test_cases/check_types.py index 3654ca864..7dcf31923 100644 --- a/stdlib/@tests/test_cases/check_types.py +++ b/stdlib/@tests/test_cases/check_types.py @@ -1,6 +1,8 @@ import sys import types from collections import UserDict +from typing import Union +from typing_extensions import assert_type # test `types.SimpleNamespace` @@ -25,3 +27,15 @@ types.SimpleNamespace({1: 2}) # type: ignore types.SimpleNamespace([[1, 2]]) # type: ignore types.SimpleNamespace(UserDict({1: 2})) # type: ignore types.SimpleNamespace([[[], 2]]) # type: ignore + +# test: `types.MappingProxyType` +mp = types.MappingProxyType({1: 2, 3: 4}) +mp.get("x") # type: ignore +item = mp.get(1) +assert_type(item, Union[int, None]) +item_2 = mp.get(2, 0) +assert_type(item_2, int) +item_3 = mp.get(3, "default") +assert_type(item_3, Union[int, str]) +# Default isn't accepted as a keyword argument. +mp.get(4, default="default") # type: ignore diff --git a/stdlib/types.pyi b/stdlib/types.pyi index a569b55ef..1e3eacd9f 100644 --- a/stdlib/types.pyi +++ b/stdlib/types.pyi @@ -304,6 +304,10 @@ class MappingProxyType(Mapping[_KT, _VT_co]): def keys(self) -> KeysView[_KT]: ... def values(self) -> ValuesView[_VT_co]: ... def items(self) -> ItemsView[_KT, _VT_co]: ... + @overload + def get(self, key: _KT, /) -> _VT_co | None: ... # type: ignore[override] + @overload + def get(self, key: _KT, default: _VT_co | _T2, /) -> _VT_co | _T2: ... # type: ignore[override] if sys.version_info >= (3, 9): def __class_getitem__(cls, item: Any, /) -> GenericAlias: ... def __reversed__(self) -> Iterator[_KT]: ...