From 6eec191739f7f39752ae1e95a7d7cf18fca4c12b Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Fri, 8 Sep 2023 11:22:17 +0100 Subject: [PATCH] Improve the accuracy of `(default)dict.__(r)or__` (#10679) --- stdlib/builtins.pyi | 8 +-- stdlib/collections/__init__.pyi | 8 +-- test_cases/stdlib/builtins/check_dict-py39.py | 66 ++++++++++++++++++ .../collections/check_defaultdict-py39.py | 69 +++++++++++++++++++ 4 files changed, 143 insertions(+), 8 deletions(-) create mode 100644 test_cases/stdlib/builtins/check_dict-py39.py create mode 100644 test_cases/stdlib/collections/check_defaultdict-py39.py diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index 71cccee16..cf4f857c5 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -1120,13 +1120,13 @@ class dict(MutableMapping[_KT, _VT], Generic[_KT, _VT]): if sys.version_info >= (3, 9): def __class_getitem__(cls, __item: Any) -> GenericAlias: ... @overload - def __or__(self, __value: Mapping[_KT, _VT]) -> dict[_KT, _VT]: ... + def __or__(self, __value: dict[_KT, _VT]) -> dict[_KT, _VT]: ... @overload - def __or__(self, __value: Mapping[_T1, _T2]) -> dict[_KT | _T1, _VT | _T2]: ... + def __or__(self, __value: dict[_T1, _T2]) -> dict[_KT | _T1, _VT | _T2]: ... @overload - def __ror__(self, __value: Mapping[_KT, _VT]) -> dict[_KT, _VT]: ... + def __ror__(self, __value: dict[_KT, _VT]) -> dict[_KT, _VT]: ... @overload - def __ror__(self, __value: Mapping[_T1, _T2]) -> dict[_KT | _T1, _VT | _T2]: ... + def __ror__(self, __value: dict[_T1, _T2]) -> dict[_KT | _T1, _VT | _T2]: ... # dict.__ior__ should be kept roughly in line with MutableMapping.update() @overload # type: ignore[misc] def __ior__(self, __value: SupportsKeysAndGetItem[_KT, _VT]) -> Self: ... diff --git a/stdlib/collections/__init__.pyi b/stdlib/collections/__init__.pyi index 8ceecd1f3..9c4dbb72f 100644 --- a/stdlib/collections/__init__.pyi +++ b/stdlib/collections/__init__.pyi @@ -402,13 +402,13 @@ class defaultdict(dict[_KT, _VT], Generic[_KT, _VT]): def copy(self) -> Self: ... if sys.version_info >= (3, 9): @overload - def __or__(self, __value: Mapping[_KT, _VT]) -> Self: ... + def __or__(self, __value: dict[_KT, _VT]) -> Self: ... @overload - def __or__(self, __value: Mapping[_T1, _T2]) -> defaultdict[_KT | _T1, _VT | _T2]: ... + def __or__(self, __value: dict[_T1, _T2]) -> defaultdict[_KT | _T1, _VT | _T2]: ... @overload - def __ror__(self, __value: Mapping[_KT, _VT]) -> Self: ... + def __ror__(self, __value: dict[_KT, _VT]) -> Self: ... @overload - def __ror__(self, __value: Mapping[_T1, _T2]) -> defaultdict[_KT | _T1, _VT | _T2]: ... + def __ror__(self, __value: dict[_T1, _T2]) -> defaultdict[_KT | _T1, _VT | _T2]: ... # type: ignore[misc] class ChainMap(MutableMapping[_KT, _VT], Generic[_KT, _VT]): maps: list[MutableMapping[_KT, _VT]] diff --git a/test_cases/stdlib/builtins/check_dict-py39.py b/test_cases/stdlib/builtins/check_dict-py39.py new file mode 100644 index 000000000..29e8f4b56 --- /dev/null +++ b/test_cases/stdlib/builtins/check_dict-py39.py @@ -0,0 +1,66 @@ +""" +Tests for `dict.__(r)or__`. + +`dict.__or__` and `dict.__ror__` were only added in py39, +hence why these are in a separate file to the other test cases for `dict`. +""" +from __future__ import annotations + +import os +import sys +from typing import Mapping, TypeVar, Union +from typing_extensions import Self, assert_type + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + +if sys.version_info >= (3, 9): + + class CustomDictSubclass(dict[_KT, _VT]): + pass + + class CustomMappingWithDunderOr(Mapping[_KT, _VT]): + def __or__(self, other: Mapping[_KT, _VT]) -> dict[_KT, _VT]: + return {} + + def __ror__(self, other: Mapping[_KT, _VT]) -> dict[_KT, _VT]: + return {} + + def __ior__(self, other: Mapping[_KT, _VT]) -> Self: + return self + + def test_dict_dot_or( + a: dict[int, int], + b: CustomDictSubclass[int, int], + c: dict[str, str], + d: Mapping[int, int], + e: CustomMappingWithDunderOr[str, str], + ) -> None: + # dict.__(r)or__ always returns a dict, even if called on a subclass of dict: + assert_type(a | b, dict[int, int]) + assert_type(b | a, dict[int, int]) + + assert_type(a | c, dict[Union[int, str], Union[int, str]]) + + # arbitrary mappings are not accepted by `dict.__or__`; + # it has to be a subclass of `dict` + a | d # type: ignore + + # but Mappings such as `os._Environ` or `CustomMappingWithDunderOr`, + # which define `__ror__` methods that accept `dict`, are fine: + assert_type(a | os.environ, dict[Union[str, int], Union[str, int]]) + assert_type(os.environ | a, dict[Union[str, int], Union[str, int]]) + + assert_type(c | os.environ, dict[str, str]) + assert_type(c | e, dict[str, str]) + + assert_type(os.environ | c, dict[str, str]) + assert_type(e | c, dict[str, str]) + + e |= c + e |= a # type: ignore + + # TODO: this test passes mypy, but fails pyright for some reason: + # c |= e + + c |= a # type: ignore diff --git a/test_cases/stdlib/collections/check_defaultdict-py39.py b/test_cases/stdlib/collections/check_defaultdict-py39.py new file mode 100644 index 000000000..9fe5ec807 --- /dev/null +++ b/test_cases/stdlib/collections/check_defaultdict-py39.py @@ -0,0 +1,69 @@ +""" +Tests for `defaultdict.__or__` and `defaultdict.__ror__`. +These methods were only added in py39. +""" + +from __future__ import annotations + +import os +import sys +from collections import defaultdict +from typing import Mapping, TypeVar, Union +from typing_extensions import Self, assert_type + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +if sys.version_info >= (3, 9): + + class CustomDefaultDictSubclass(defaultdict[_KT, _VT]): + pass + + class CustomMappingWithDunderOr(Mapping[_KT, _VT]): + def __or__(self, other: Mapping[_KT, _VT]) -> dict[_KT, _VT]: + return {} + + def __ror__(self, other: Mapping[_KT, _VT]) -> dict[_KT, _VT]: + return {} + + def __ior__(self, other: Mapping[_KT, _VT]) -> Self: + return self + + def test_defaultdict_dot_or( + a: defaultdict[int, int], + b: CustomDefaultDictSubclass[int, int], + c: defaultdict[str, str], + d: Mapping[int, int], + e: CustomMappingWithDunderOr[str, str], + ) -> None: + assert_type(a | b, defaultdict[int, int]) + + # In contrast to `dict.__or__`, `defaultdict.__or__` returns `Self` if called on a subclass of `defaultdict`: + assert_type(b | a, CustomDefaultDictSubclass[int, int]) + + assert_type(a | c, defaultdict[Union[int, str], Union[int, str]]) + + # arbitrary mappings are not accepted by `defaultdict.__or__`; + # it has to be a subclass of `dict` + a | d # type: ignore + + # but Mappings such as `os._Environ` or `CustomMappingWithDunderOr`, + # which define `__ror__` methods that accept `dict`, are fine + # (`os._Environ.__(r)or__` always returns `dict`, even if a `defaultdict` is passed): + assert_type(a | os.environ, dict[Union[str, int], Union[str, int]]) + assert_type(os.environ | a, dict[Union[str, int], Union[str, int]]) + + assert_type(c | os.environ, dict[str, str]) + assert_type(c | e, dict[str, str]) + + assert_type(os.environ | c, dict[str, str]) + assert_type(e | c, dict[str, str]) + + e |= c + e |= a # type: ignore + + # TODO: this test passes mypy, but fails pyright for some reason: + # c |= e + + c |= a # type: ignore