Move test_cases to stdlib/@tests/test_cases (#11865)

This commit is contained in:
Sebastian Rittau
2024-05-10 04:27:09 +02:00
committed by GitHub
parent ea61ca5a30
commit 392ae934fc
49 changed files with 37 additions and 37 deletions


@@ -0,0 +1,25 @@
from __future__ import annotations
from asyncio import iscoroutinefunction
from collections.abc import Awaitable, Callable, Coroutine
from typing import Any
from typing_extensions import assert_type
def test_iscoroutinefunction(
x: Callable[[str, int], Coroutine[str, int, bytes]],
y: Callable[[str, int], Awaitable[bytes]],
z: Callable[[str, int], str | Awaitable[bytes]],
xx: object,
) -> None:
if iscoroutinefunction(x):
assert_type(x, Callable[[str, int], Coroutine[str, int, bytes]])
if iscoroutinefunction(y):
assert_type(y, Callable[[str, int], Coroutine[Any, Any, bytes]])
if iscoroutinefunction(z):
assert_type(z, Callable[[str, int], Coroutine[Any, Any, Any]])
if iscoroutinefunction(xx):
assert_type(xx, Callable[..., Coroutine[Any, Any, Any]])


@@ -0,0 +1,38 @@
from __future__ import annotations
import asyncio
from typing import Awaitable, List, Tuple, Union
from typing_extensions import assert_type
async def coro1() -> int:
return 42
async def coro2() -> str:
return "spam"
async def test_gather(awaitable1: Awaitable[int], awaitable2: Awaitable[str]) -> None:
a = await asyncio.gather(awaitable1)
assert_type(a, Tuple[int])
b = await asyncio.gather(awaitable1, awaitable2, return_exceptions=True)
assert_type(b, Tuple[Union[int, BaseException], Union[str, BaseException]])
c = await asyncio.gather(awaitable1, awaitable2, awaitable1, awaitable1, awaitable1, awaitable1)
assert_type(c, Tuple[int, str, int, int, int, int])
d = await asyncio.gather(awaitable1, awaitable1, awaitable1, awaitable1, awaitable1, awaitable1, awaitable1)
assert_type(d, List[int])
awaitables_list: list[Awaitable[int]] = [awaitable1]
e = await asyncio.gather(*awaitables_list)
assert_type(e, List[int])
# this case isn't reliable between type checkers; no one would ever call it with no args anyway

# f = await asyncio.gather()
# assert_type(f, list[Any])
asyncio.run(test_gather(coro1(), coro2()))
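# A hedged extra sketch: with more than five awaitables and return_exceptions=True,
# the result should also fall back to a list, mirroring the seven-argument case above.
async def test_gather_fallback(awaitable1: Awaitable[int]) -> None:
    g = await asyncio.gather(
        awaitable1, awaitable1, awaitable1, awaitable1, awaitable1, awaitable1, awaitable1, return_exceptions=True
    )
    assert_type(g, List[Union[int, BaseException]])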


@@ -0,0 +1,28 @@
from __future__ import annotations
import asyncio
class Waiter:
def __init__(self) -> None:
self.tasks: list[asyncio.Task[object]] = []
def add(self, t: asyncio.Task[object]) -> None:
self.tasks.append(t)
async def join(self) -> None:
await asyncio.wait(self.tasks)
async def foo() -> int:
return 42
async def main() -> None:
# asyncio.Task is covariant in its type argument, which is unusual since its parent class
# asyncio.Future is invariant in its type argument. This is only sound because asyncio.Task
# is not actually Liskov substitutable for asyncio.Future: it does not implement set_result.
w = Waiter()
t: asyncio.Task[int] = asyncio.create_task(foo())
w.add(t)
await w.join()
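# A minimal covariance sketch: a Task[int] is accepted where a Task[object]
# is expected, which is exactly what Waiter.add relies on in main().
async def illustrate_covariance() -> None:
    task_int: asyncio.Task[int] = asyncio.create_task(foo())
    task_obj: asyncio.Task[object] = task_int  # accepted because Task is covariant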


@@ -0,0 +1,67 @@
"""
Tests for `dict.__(r)or__`.
`dict.__or__` and `dict.__ror__` were only added in py39,
which is why these are in a separate file from the other test cases for `dict`.
"""
from __future__ import annotations
import os
import sys
from typing import Mapping, TypeVar, Union
from typing_extensions import Self, assert_type
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
if sys.version_info >= (3, 9):
class CustomDictSubclass(dict[_KT, _VT]):
pass
class CustomMappingWithDunderOr(Mapping[_KT, _VT]):
def __or__(self, other: Mapping[_KT, _VT]) -> dict[_KT, _VT]:
return {}
def __ror__(self, other: Mapping[_KT, _VT]) -> dict[_KT, _VT]:
return {}
def __ior__(self, other: Mapping[_KT, _VT]) -> Self:
return self
def test_dict_dot_or(
a: dict[int, int],
b: CustomDictSubclass[int, int],
c: dict[str, str],
d: Mapping[int, int],
e: CustomMappingWithDunderOr[str, str],
) -> None:
# dict.__(r)or__ always returns a dict, even if called on a subclass of dict:
assert_type(a | b, dict[int, int])
assert_type(b | a, dict[int, int])
assert_type(a | c, dict[Union[int, str], Union[int, str]])
# arbitrary mappings are not accepted by `dict.__or__`;
# it has to be a subclass of `dict`
a | d # type: ignore
# but Mappings such as `os._Environ` or `CustomMappingWithDunderOr`,
# which define `__ror__` methods that accept `dict`, are fine:
assert_type(a | os.environ, dict[Union[str, int], Union[str, int]])
assert_type(os.environ | a, dict[Union[str, int], Union[str, int]])
assert_type(c | os.environ, dict[str, str])
assert_type(c | e, dict[str, str])
assert_type(os.environ | c, dict[str, str])
assert_type(e | c, dict[str, str])
e |= c
e |= a # type: ignore
# TODO: this test passes mypy, but fails pyright for some reason:
# c |= e
c |= a # type: ignore


@@ -0,0 +1,58 @@
from __future__ import annotations
from typing import Dict, Generic, Iterable, TypeVar
from typing_extensions import assert_type
# These follow the order of the `__init__` overloads:
# mypy and pyright have different opinions about this one:
# mypy raises: 'Need type annotation for "bad"'
# pyright is fine with it.
# bad = dict()
good: dict[str, str] = dict()
assert_type(good, Dict[str, str])
assert_type(dict(arg=1), Dict[str, int])
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
class KeysAndGetItem(Generic[_KT, _VT]):
data: dict[_KT, _VT]
def __init__(self, data: dict[_KT, _VT]) -> None:
self.data = data
def keys(self) -> Iterable[_KT]:
return self.data.keys()
def __getitem__(self, __k: _KT) -> _VT:
return self.data[__k]
kt1: KeysAndGetItem[int, str] = KeysAndGetItem({0: ""})
assert_type(dict(kt1), Dict[int, str])
dict(kt1, arg="a") # type: ignore
kt2: KeysAndGetItem[str, int] = KeysAndGetItem({"": 0})
assert_type(dict(kt2, arg=1), Dict[str, int])
def test_iterable_tuple_overload(x: Iterable[tuple[int, str]]) -> dict[int, str]:
return dict(x)
i1: Iterable[tuple[int, str]] = [(1, "a"), (2, "b")]
test_iterable_tuple_overload(i1)
dict(i1, arg="a") # type: ignore
i2: Iterable[tuple[str, int]] = [("a", 1), ("b", 2)]
assert_type(dict(i2, arg=1), Dict[str, int])
i3: Iterable[str] = ["a.b"]
i4: Iterable[bytes] = [b"a.b"]
assert_type(dict(string.split(".") for string in i3), Dict[str, str])
assert_type(dict(string.split(b".") for string in i4), Dict[bytes, bytes])
dict(["foo", "bar", "baz"]) # type: ignore
dict([b"foo", b"bar", b"baz"]) # type: ignore


@@ -0,0 +1,323 @@
from __future__ import annotations
import sys
from typing import TypeVar
from typing_extensions import assert_type
if sys.version_info >= (3, 11):
# This can be removed later, but right now Flake8 does not know
# about these two classes:
from builtins import BaseExceptionGroup, ExceptionGroup
# BaseExceptionGroup
# ==================
# `BaseExceptionGroup` can work with `BaseException`:
beg = BaseExceptionGroup("x", [SystemExit(), SystemExit()])
assert_type(beg, BaseExceptionGroup[SystemExit])
assert_type(beg.exceptions, tuple[SystemExit | BaseExceptionGroup[SystemExit], ...])
# Covariance works:
_beg1: BaseExceptionGroup[BaseException] = beg
# `BaseExceptionGroup` can work with `Exception`:
beg2 = BaseExceptionGroup("x", [ValueError()])
# FIXME: this is not right; at runtime an `ExceptionGroup` instance is returned instead,
# but I am unable to represent this with types right now.
assert_type(beg2, BaseExceptionGroup[ValueError])
# .subgroup()
# -----------
assert_type(beg.subgroup(KeyboardInterrupt), BaseExceptionGroup[KeyboardInterrupt] | None)
assert_type(beg.subgroup((KeyboardInterrupt,)), BaseExceptionGroup[KeyboardInterrupt] | None)
def is_base_exc(exc: BaseException) -> bool:
return isinstance(exc, BaseException)
def is_specific(exc: SystemExit | BaseExceptionGroup[SystemExit]) -> bool:
return isinstance(exc, SystemExit)
# This one does not accept the `BaseExceptionGroup` part,
# which is why we treat it as an error.
def is_system_exit(exc: SystemExit) -> bool:
return isinstance(exc, SystemExit)
def unrelated_subgroup(exc: KeyboardInterrupt) -> bool:
return False
assert_type(beg.subgroup(is_base_exc), BaseExceptionGroup[SystemExit] | None)
assert_type(beg.subgroup(is_specific), BaseExceptionGroup[SystemExit] | None)
beg.subgroup(is_system_exit) # type: ignore
beg.subgroup(unrelated_subgroup) # type: ignore
# An `Exception` subgroup returns `ExceptionGroup`:
assert_type(beg.subgroup(ValueError), ExceptionGroup[ValueError] | None)
assert_type(beg.subgroup((ValueError,)), ExceptionGroup[ValueError] | None)
# Callables are harder; we don't support a cast to `ExceptionGroup` here,
# because a callable might return `True` the first time, and `BaseExceptionGroup`
# will stick no matter what the arguments are.
def is_exception(exc: Exception) -> bool:
return isinstance(exc, Exception)
def is_exception_or_beg(exc: Exception | BaseExceptionGroup[SystemExit]) -> bool:
return isinstance(exc, Exception)
# This is an error because of the `Exception` argument type;
# `SystemExit` is needed instead.
beg.subgroup(is_exception_or_beg) # type: ignore
# This is an error because `BaseExceptionGroup` is not an `Exception`
# subclass, which is required here.
beg.subgroup(is_exception) # type: ignore
# .split()
# --------
assert_type(
beg.split(KeyboardInterrupt), tuple[BaseExceptionGroup[KeyboardInterrupt] | None, BaseExceptionGroup[SystemExit] | None]
)
assert_type(
beg.split((KeyboardInterrupt,)),
tuple[BaseExceptionGroup[KeyboardInterrupt] | None, BaseExceptionGroup[SystemExit] | None],
)
assert_type(
beg.split(ValueError), # there are no `ValueError` items in there, but anyway
tuple[ExceptionGroup[ValueError] | None, BaseExceptionGroup[SystemExit] | None],
)
excs_to_split: list[ValueError | KeyError | SystemExit] = [ValueError(), KeyError(), SystemExit()]
to_split = BaseExceptionGroup("x", excs_to_split)
assert_type(to_split, BaseExceptionGroup[ValueError | KeyError | SystemExit])
# Ideally the first part should be `ExceptionGroup[ValueError]` (done)
# and the second part should be `BaseExceptionGroup[KeyError | SystemExit]`,
# but we cannot subtract a type from a union.
# For the same reason we also cannot change `BaseExceptionGroup` to `ExceptionGroup`
# in the second part here, even where that would be needed.
assert_type(
to_split.split(ValueError),
tuple[ExceptionGroup[ValueError] | None, BaseExceptionGroup[ValueError | KeyError | SystemExit] | None],
)
def split_callable1(exc: ValueError | KeyError | SystemExit | BaseExceptionGroup[ValueError | KeyError | SystemExit]) -> bool:
return True
assert_type(
to_split.split(split_callable1), # Concrete type is ok
tuple[
BaseExceptionGroup[ValueError | KeyError | SystemExit] | None,
BaseExceptionGroup[ValueError | KeyError | SystemExit] | None,
],
)
assert_type(
to_split.split(is_base_exc), # Base class is ok
tuple[
BaseExceptionGroup[ValueError | KeyError | SystemExit] | None,
BaseExceptionGroup[ValueError | KeyError | SystemExit] | None,
],
)
# `Exception` cannot be used: `BaseExceptionGroup` is not a subtype of it.
to_split.split(is_exception) # type: ignore
# .derive()
# ---------
assert_type(beg.derive([ValueError()]), ExceptionGroup[ValueError])
assert_type(beg.derive([KeyboardInterrupt()]), BaseExceptionGroup[KeyboardInterrupt])
# ExceptionGroup
# ==============
# `ExceptionGroup` can work with `Exception`:
excs: list[ValueError | KeyError] = [ValueError(), KeyError()]
eg = ExceptionGroup("x", excs)
assert_type(eg, ExceptionGroup[ValueError | KeyError])
assert_type(eg.exceptions, tuple[ValueError | KeyError | ExceptionGroup[ValueError | KeyError], ...])
# Covariance works:
_eg1: ExceptionGroup[Exception] = eg
# `ExceptionGroup` cannot work with `BaseException`:
ExceptionGroup("x", [SystemExit()]) # type: ignore
# .subgroup()
# -----------
# Our decision is to ban cases like::
#
# >>> eg = ExceptionGroup('x', [ValueError()])
# >>> eg.subgroup(BaseException)
# ExceptionGroup('e', [ValueError()])
#
# even though they are possible at runtime.
# We do this because it does not make sense for all other base exception types,
# and supporting just `BaseException` looks like overkill.
eg.subgroup(BaseException) # type: ignore
eg.subgroup((KeyboardInterrupt, SystemExit)) # type: ignore
assert_type(eg.subgroup(Exception), ExceptionGroup[Exception] | None)
assert_type(eg.subgroup(ValueError), ExceptionGroup[ValueError] | None)
assert_type(eg.subgroup((ValueError,)), ExceptionGroup[ValueError] | None)
def subgroup_eg1(exc: ValueError | KeyError | ExceptionGroup[ValueError | KeyError]) -> bool:
return True
def subgroup_eg2(exc: ValueError | KeyError) -> bool:
return True
assert_type(eg.subgroup(subgroup_eg1), ExceptionGroup[ValueError | KeyError] | None)
assert_type(eg.subgroup(is_exception), ExceptionGroup[ValueError | KeyError] | None)
assert_type(eg.subgroup(is_base_exc), ExceptionGroup[ValueError | KeyError] | None)
assert_type(eg.subgroup(is_base_exc), ExceptionGroup[ValueError | KeyError] | None)
# Does not have `ExceptionGroup` part:
eg.subgroup(subgroup_eg2) # type: ignore
# .split()
# --------
assert_type(eg.split(TypeError), tuple[ExceptionGroup[TypeError] | None, ExceptionGroup[ValueError | KeyError] | None])
assert_type(eg.split((TypeError,)), tuple[ExceptionGroup[TypeError] | None, ExceptionGroup[ValueError | KeyError] | None])
assert_type(
eg.split(is_exception), tuple[ExceptionGroup[ValueError | KeyError] | None, ExceptionGroup[ValueError | KeyError] | None]
)
assert_type(
eg.split(is_base_exc),
# is not converted, because `ExceptionGroup` cannot have
# direct `BaseException` subclasses inside.
tuple[ExceptionGroup[ValueError | KeyError] | None, ExceptionGroup[ValueError | KeyError] | None],
)
# It does not include `ExceptionGroup` itself, so it will fail:
def value_or_key_error(exc: ValueError | KeyError) -> bool:
return isinstance(exc, (ValueError, KeyError))
eg.split(value_or_key_error) # type: ignore
# `ExceptionGroup` cannot have direct `BaseException` subclasses inside.
eg.split(BaseException) # type: ignore
eg.split((SystemExit, GeneratorExit)) # type: ignore
# .derive()
# ---------
assert_type(eg.derive([ValueError()]), ExceptionGroup[ValueError])
assert_type(eg.derive([KeyboardInterrupt()]), BaseExceptionGroup[KeyboardInterrupt])
# BaseExceptionGroup Custom Subclass
# ==================================
# In some cases the `Self` type can be preserved at runtime,
# but it is impossible to express in the stubs. That's why we always fall back to
# `BaseExceptionGroup` and `ExceptionGroup`.
_BE = TypeVar("_BE", bound=BaseException)
class CustomBaseGroup(BaseExceptionGroup[_BE]): ...
cb1 = CustomBaseGroup("x", [SystemExit()])
assert_type(cb1, CustomBaseGroup[SystemExit])
cb2 = CustomBaseGroup("x", [ValueError()])
assert_type(cb2, CustomBaseGroup[ValueError])
# .subgroup()
# -----------
assert_type(cb1.subgroup(KeyboardInterrupt), BaseExceptionGroup[KeyboardInterrupt] | None)
assert_type(cb2.subgroup((KeyboardInterrupt,)), BaseExceptionGroup[KeyboardInterrupt] | None)
assert_type(cb1.subgroup(ValueError), ExceptionGroup[ValueError] | None)
assert_type(cb2.subgroup((KeyError,)), ExceptionGroup[KeyError] | None)
def cb_subgroup1(exc: SystemExit | CustomBaseGroup[SystemExit]) -> bool:
return True
def cb_subgroup2(exc: ValueError | CustomBaseGroup[ValueError]) -> bool:
return True
assert_type(cb1.subgroup(cb_subgroup1), BaseExceptionGroup[SystemExit] | None)
assert_type(cb2.subgroup(cb_subgroup2), BaseExceptionGroup[ValueError] | None)
cb1.subgroup(cb_subgroup2) # type: ignore
cb2.subgroup(cb_subgroup1) # type: ignore
# .split()
# --------
assert_type(
cb1.split(KeyboardInterrupt), tuple[BaseExceptionGroup[KeyboardInterrupt] | None, BaseExceptionGroup[SystemExit] | None]
)
assert_type(cb1.split(TypeError), tuple[ExceptionGroup[TypeError] | None, BaseExceptionGroup[SystemExit] | None])
assert_type(cb2.split((TypeError,)), tuple[ExceptionGroup[TypeError] | None, BaseExceptionGroup[ValueError] | None])
def cb_split1(exc: SystemExit | CustomBaseGroup[SystemExit]) -> bool:
return True
def cb_split2(exc: ValueError | CustomBaseGroup[ValueError]) -> bool:
return True
assert_type(cb1.split(cb_split1), tuple[BaseExceptionGroup[SystemExit] | None, BaseExceptionGroup[SystemExit] | None])
assert_type(cb2.split(cb_split2), tuple[BaseExceptionGroup[ValueError] | None, BaseExceptionGroup[ValueError] | None])
cb1.split(cb_split2) # type: ignore
cb2.split(cb_split1) # type: ignore
# .derive()
# ---------
# Note that the `Self` type is not preserved at runtime.
assert_type(cb1.derive([ValueError()]), ExceptionGroup[ValueError])
assert_type(cb1.derive([KeyboardInterrupt()]), BaseExceptionGroup[KeyboardInterrupt])
assert_type(cb2.derive([ValueError()]), ExceptionGroup[ValueError])
assert_type(cb2.derive([KeyboardInterrupt()]), BaseExceptionGroup[KeyboardInterrupt])
# ExceptionGroup Custom Subclass
# ==============================
_E = TypeVar("_E", bound=Exception)
class CustomGroup(ExceptionGroup[_E]): ...
CustomGroup("x", [SystemExit()]) # type: ignore
cg1 = CustomGroup("x", [ValueError()])
assert_type(cg1, CustomGroup[ValueError])
# .subgroup()
# -----------
cg1.subgroup(BaseException) # type: ignore
cg1.subgroup((KeyboardInterrupt, SystemExit)) # type: ignore
assert_type(cg1.subgroup(ValueError), ExceptionGroup[ValueError] | None)
assert_type(cg1.subgroup((KeyError,)), ExceptionGroup[KeyError] | None)
def cg_subgroup1(exc: ValueError | CustomGroup[ValueError]) -> bool:
return True
def cg_subgroup2(exc: ValueError) -> bool:
return True
assert_type(cg1.subgroup(cg_subgroup1), ExceptionGroup[ValueError] | None)
cg1.subgroup(cb_subgroup2) # type: ignore
# .split()
# --------
assert_type(cg1.split(TypeError), tuple[ExceptionGroup[TypeError] | None, ExceptionGroup[ValueError] | None])
assert_type(cg1.split((TypeError,)), tuple[ExceptionGroup[TypeError] | None, ExceptionGroup[ValueError] | None])
cg1.split(BaseException) # type: ignore
def cg_split1(exc: ValueError | CustomGroup[ValueError]) -> bool:
return True
def cg_split2(exc: ValueError) -> bool:
return True
assert_type(cg1.split(cg_split1), tuple[ExceptionGroup[ValueError] | None, ExceptionGroup[ValueError] | None])
cg1.split(cg_split2) # type: ignore
# .derive()
# ---------
# Note that the `Self` type is not preserved at runtime.
assert_type(cg1.derive([ValueError()]), ExceptionGroup[ValueError])
assert_type(cg1.derive([KeyboardInterrupt()]), BaseExceptionGroup[KeyboardInterrupt])


@@ -0,0 +1,16 @@
from __future__ import annotations
from typing import Iterator
from typing_extensions import assert_type
class OldStyleIter:
def __getitem__(self, index: int) -> str:
return str(index)
for x in iter(OldStyleIter()):
assert_type(x, str)
assert_type(iter(OldStyleIter()), Iterator[str])
assert_type(next(iter(OldStyleIter())), str)
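# A small negative sketch: a class with neither __iter__ nor __getitem__
# should be rejected by iter().
class NotIterable:
    pass
iter(NotIterable())  # type: ignore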


@@ -0,0 +1,21 @@
from __future__ import annotations
from typing import List, Union
from typing_extensions import assert_type
# list.__add__ example from #8292
class Foo:
def asd(self) -> int:
return 1
class Bar:
def asd(self) -> int:
return 2
combined = [Foo()] + [Bar()]
assert_type(combined, List[Union[Foo, Bar]])
for item in combined:
assert_type(item.asd(), int)


@@ -0,0 +1,13 @@
from __future__ import annotations
from typing import Any
# The following should pass without error (see #6661):
class Diagnostic:
def __reduce__(self) -> str | tuple[Any, ...]:
res = super().__reduce__()
if isinstance(res, tuple) and len(res) >= 3:
res[2]["_info"] = 42
return res


@@ -0,0 +1,91 @@
from __future__ import annotations
from decimal import Decimal
from fractions import Fraction
from typing import Any, Literal
from typing_extensions import assert_type
# See #7163
assert_type(pow(1, 0), Literal[1])
assert_type(1**0, Literal[1])
assert_type(pow(1, 0, None), Literal[1])
# TODO: We don't have a good way of expressing the fact
# that passing 0 for the third argument will lead to an exception being raised
# (see discussion in #8566)
#
# assert_type(pow(2, 4, 0), NoReturn)
assert_type(pow(2, 4), int)
assert_type(2**4, int)
assert_type(pow(4, 6, None), int)
assert_type(pow(5, -7), float)
assert_type(5**-7, float)
assert_type(pow(2, 4, 5), int) # pow(<smallint>, <smallint>, <smallint>)
assert_type(pow(2, 35, 3), int) # pow(<smallint>, <bigint>, <smallint>)
assert_type(pow(2, 8.5), float)
assert_type(2**8.6, float)
assert_type(pow(2, 8.6, None), float)
# TODO: Why does this pass pyright but not mypy??
# assert_type((-2) ** 0.5, complex)
assert_type(pow((-5), 8.42, None), complex)
assert_type(pow(4.6, 8), float)
assert_type(4.6**8, float)
assert_type(pow(5.1, 4, None), float)
assert_type(pow(complex(6), 6.2), complex)
assert_type(complex(6) ** 6.2, complex)
assert_type(pow(complex(9), 7.3, None), complex)
assert_type(pow(Fraction(), 4, None), Fraction)
assert_type(Fraction() ** 4, Fraction)
assert_type(pow(Fraction(3, 7), complex(1, 8)), complex)
assert_type(Fraction(3, 7) ** complex(1, 8), complex)
assert_type(pow(complex(4, -8), Fraction(2, 3)), complex)
assert_type(complex(4, -8) ** Fraction(2, 3), complex)
assert_type(pow(Decimal("1.0"), Decimal("1.6")), Decimal)
assert_type(Decimal("1.0") ** Decimal("1.6"), Decimal)
assert_type(pow(Decimal("1.0"), Decimal("1.0"), Decimal("1.0")), Decimal)
assert_type(pow(Decimal("4.6"), 7, None), Decimal)
assert_type(Decimal("4.6") ** 7, Decimal)
# These would ideally be more precise, but `Any` is acceptable.
# They have to be `Any` because type checkers can't distinguish
# between positive and negative numbers for the second argument to `pow()`.
#
# int for positive 2nd-arg, float otherwise
assert_type(pow(4, 65), Any)
assert_type(pow(2, -45), Any)
assert_type(pow(3, 57, None), Any)
assert_type(pow(67, 0.98, None), Any)
assert_type(87**7.32, Any)
# pow(<pos-float>, <pos-or-neg-float>) -> float
# pow(<neg-float>, <pos-or-neg-float>) -> complex
assert_type(pow(4.7, 7.4), Any)
assert_type(pow(-9.8, 8.3), Any)
assert_type(pow(-9.3, -88.2), Any)
assert_type(pow(8.2, -9.8), Any)
assert_type(pow(4.7, 9.2, None), Any)
# See #7046 -- float for a positive 1st arg, complex otherwise
assert_type((-95) ** 8.42, Any)
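# A short runtime illustration of why the results above stay `Any`
# (these are plain CPython results, not stub assertions):
#   pow(2, 1)      -> 2      (int)
#   pow(2, -1)     -> 0.5    (float)
#   (-2.0) ** 0.5  -> a complex number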
# All of the following cases should fail a type-checker.
pow(1.9, 4, 6) # type: ignore
pow(4, 7, 4.32) # type: ignore
pow(6.2, 5.9, 73) # type: ignore
pow(complex(6), 6.2, 7) # type: ignore
pow(Fraction(), 5, 8) # type: ignore
Decimal("8.7") ** 3.14 # type: ignore
# TODO: This fails at runtime, but currently passes mypy and pyright:
pow(Decimal("8.5"), 3.21)


@@ -0,0 +1,34 @@
from __future__ import annotations
from collections.abc import Iterator
from typing import Generic, TypeVar
from typing_extensions import assert_type
x: list[int] = []
assert_type(list(reversed(x)), "list[int]")
class MyReversible:
def __iter__(self) -> Iterator[str]:
yield "blah"
def __reversed__(self) -> Iterator[str]:
yield "blah"
assert_type(list(reversed(MyReversible())), "list[str]")
_T = TypeVar("_T")
class MyLenAndGetItem(Generic[_T]):
def __len__(self) -> int:
return 0
def __getitem__(self, item: int) -> _T:
raise KeyError
len_and_get_item: MyLenAndGetItem[int] = MyLenAndGetItem()
assert_type(list(reversed(len_and_get_item)), "list[int]")
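# A negative sketch: __getitem__ alone is not enough for reversed();
# the sequence-protocol fallback also needs __len__.
class MyGetItemOnly(Generic[_T]):
    def __getitem__(self, item: int) -> _T:
        raise KeyError
get_item_only: MyGetItemOnly[int] = MyGetItemOnly()
reversed(get_item_only)  # type: ignore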


@@ -0,0 +1,68 @@
from __future__ import annotations
from typing import overload
from typing_extensions import assert_type
class CustomIndex:
def __index__(self) -> int:
return 1
# float:
assert_type(round(5.5), int)
assert_type(round(5.5, None), int)
assert_type(round(5.5, 0), float)
assert_type(round(5.5, 1), float)
assert_type(round(5.5, 5), float)
assert_type(round(5.5, CustomIndex()), float)
# int:
assert_type(round(1), int)
assert_type(round(1, 1), int)
assert_type(round(1, None), int)
assert_type(round(1, CustomIndex()), int)
# Protocols:
class WithCustomRound1:
def __round__(self) -> str:
return "a"
assert_type(round(WithCustomRound1()), str)
assert_type(round(WithCustomRound1(), None), str)
# Errors:
round(WithCustomRound1(), 1) # type: ignore
round(WithCustomRound1(), CustomIndex()) # type: ignore
class WithCustomRound2:
def __round__(self, digits: int) -> str:
return "a"
assert_type(round(WithCustomRound2(), 1), str)
assert_type(round(WithCustomRound2(), CustomIndex()), str)
# Errors:
round(WithCustomRound2(), None) # type: ignore
round(WithCustomRound2()) # type: ignore
class WithOverloadedRound:
@overload
def __round__(self, ndigits: None = ...) -> str: ...
@overload
def __round__(self, ndigits: int) -> bytes: ...
def __round__(self, ndigits: int | None = None) -> str | bytes:
return b"" if ndigits is None else ""
assert_type(round(WithOverloadedRound()), str)
assert_type(round(WithOverloadedRound(), None), str)
assert_type(round(WithOverloadedRound(), 1), bytes)
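# And a final negative sketch: objects without a __round__ method are rejected.
round(object())  # type: ignore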


@@ -0,0 +1,55 @@
from __future__ import annotations
from typing import Any, List, Literal, Union
from typing_extensions import assert_type
class Foo:
def __add__(self, other: Any) -> Foo:
return Foo()
class Bar:
def __radd__(self, other: Any) -> Bar:
return Bar()
class Baz:
def __add__(self, other: Any) -> Baz:
return Baz()
def __radd__(self, other: Any) -> Baz:
return Baz()
literal_list: list[Literal[0, 1]] = [0, 1, 1]
assert_type(sum([2, 4]), int)
assert_type(sum([3, 5], 4), int)
assert_type(sum([True, False]), int)
assert_type(sum([True, False], True), int)
assert_type(sum(literal_list), int)
assert_type(sum([["foo"], ["bar"]], ["baz"]), List[str])
assert_type(sum([Foo(), Foo()], Foo()), Foo)
assert_type(sum([Baz(), Baz()]), Union[Baz, Literal[0]])
# mypy and pyright infer the types differently for these, so we can't use assert_type
# Just test that no error is emitted for any of these
sum([("foo",), ("bar", "baz")], ()) # mypy: `tuple[str, ...]`; pyright: `tuple[()] | tuple[str] | tuple[str, str]`
sum([5.6, 3.2]) # mypy: `float`; pyright: `float | Literal[0]`
sum([2.5, 5.8], 5) # mypy: `float`; pyright: `float | int`
# These all fail at runtime
sum("abcde") # type: ignore
sum([["foo"], ["bar"]]) # type: ignore
sum([("foo",), ("bar", "baz")]) # type: ignore
sum([Foo(), Foo()]) # type: ignore
sum([Bar(), Bar()], Bar()) # type: ignore
sum([Bar(), Bar()]) # type: ignore
# TODO: these pass pyright with the current stubs, but mypy erroneously emits an error:
# sum([3, Fraction(7, 22), complex(8, 0), 9.83])
# sum([3, Decimal('0.98')])


@@ -0,0 +1,13 @@
from __future__ import annotations
from typing import Tuple
from typing_extensions import assert_type
# Empty tuples, see #8275
class TupleSub(Tuple[int, ...]):
pass
assert_type(TupleSub(), TupleSub)
assert_type(TupleSub([1, 2, 3]), TupleSub)


@@ -0,0 +1,13 @@
from __future__ import annotations
import codecs
from typing_extensions import assert_type
assert_type(codecs.decode("x", "unicode-escape"), str)
assert_type(codecs.decode(b"x", "unicode-escape"), str)
assert_type(codecs.decode(b"x", "utf-8"), str)
codecs.decode("x", "utf-8") # type: ignore
assert_type(codecs.decode("ab", "hex"), bytes)
assert_type(codecs.decode(b"ab", "hex"), bytes)


@@ -0,0 +1,30 @@
from __future__ import annotations
from collections.abc import Callable, Iterator
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from typing_extensions import assert_type
class Parent: ...
class Child(Parent): ...
def check_as_completed_covariance() -> None:
with ThreadPoolExecutor() as executor:
f1 = executor.submit(lambda: Parent())
f2 = executor.submit(lambda: Child())
fs: list[Future[Parent] | Future[Child]] = [f1, f2]
assert_type(as_completed(fs), Iterator[Future[Parent]])
for future in as_completed(fs):
assert_type(future.result(), Parent)
def check_future_invariance() -> None:
def execute_callback(callback: Callable[[], Parent], future: Future[Parent]) -> None:
future.set_result(callback())
fut: Future[Child] = Future()
execute_callback(lambda: Parent(), fut) # type: ignore
assert isinstance(fut.result(), Child)


@@ -0,0 +1,20 @@
from __future__ import annotations
from contextlib import ExitStack
from typing_extensions import assert_type
# See issue #7961
class Thing(ExitStack):
pass
stack = ExitStack()
thing = Thing()
assert_type(stack.enter_context(Thing()), Thing)
assert_type(thing.enter_context(ExitStack()), ExitStack)
with stack as cm:
assert_type(cm, ExitStack)
with thing as cm2:
assert_type(cm2, Thing)


@@ -0,0 +1,101 @@
from __future__ import annotations
import dataclasses as dc
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Tuple, Type, Union
from typing_extensions import Annotated, assert_type
if TYPE_CHECKING:
from _typeshed import DataclassInstance
@dc.dataclass
class Foo:
attr: str
assert_type(dc.fields(Foo), Tuple[dc.Field[Any], ...])
# Mypy correctly emits errors on these
# due to the fact it's a dataclass class, not an instance.
# Pyright, however, handles ClassVar members in protocols differently.
# See https://github.com/microsoft/pyright/issues/4339
#
# dc.asdict(Foo)
# dc.astuple(Foo)
# dc.replace(Foo)
# See #9723 for why we can't make this assertion
# if dc.is_dataclass(Foo):
# assert_type(Foo, Type[Foo])
f = Foo(attr="attr")
assert_type(dc.fields(f), Tuple[dc.Field[Any], ...])
assert_type(dc.asdict(f), Dict[str, Any])
assert_type(dc.astuple(f), Tuple[Any, ...])
assert_type(dc.replace(f, attr="new"), Foo)
if dc.is_dataclass(f):
# The inferred type doesn't change
# if it's already known to be a subtype of _DataclassInstance
assert_type(f, Foo)
def check_other_isdataclass_overloads(x: type, y: object) -> None:
# TODO: pyright correctly emits an error on this, but mypy does not -- why?
# dc.fields(x)
dc.fields(y) # type: ignore
dc.asdict(x) # type: ignore
dc.asdict(y) # type: ignore
dc.astuple(x) # type: ignore
dc.astuple(y) # type: ignore
dc.replace(x) # type: ignore
dc.replace(y) # type: ignore
if dc.is_dataclass(x):
assert_type(x, Type["DataclassInstance"])
assert_type(dc.fields(x), Tuple[dc.Field[Any], ...])
# Mypy correctly emits an error on these due to the fact
# that it's a dataclass class, not a dataclass instance.
# Pyright, however, handles ClassVar members in protocols differently.
# See https://github.com/microsoft/pyright/issues/4339
#
# dc.asdict(x)
# dc.astuple(x)
# dc.replace(x)
if dc.is_dataclass(y):
assert_type(y, Union["DataclassInstance", Type["DataclassInstance"]])
assert_type(dc.fields(y), Tuple[dc.Field[Any], ...])
# Mypy correctly emits an error on these due to the fact we don't know
# whether it's a dataclass class or a dataclass instance.
# Pyright, however, handles ClassVar members in protocols differently.
# See https://github.com/microsoft/pyright/issues/4339
#
# dc.asdict(y)
# dc.astuple(y)
# dc.replace(y)
if dc.is_dataclass(y) and not isinstance(y, type):
assert_type(y, "DataclassInstance")
assert_type(dc.fields(y), Tuple[dc.Field[Any], ...])
assert_type(dc.asdict(y), Dict[str, Any])
assert_type(dc.astuple(y), Tuple[Any, ...])
dc.replace(y)
# Regression test for #11653
D = dc.make_dataclass(
"D", [("a", Union[int, None]), "y", ("z", Annotated[FrozenSet[bytes], "metadata"], dc.field(default=frozenset({b"foo"})))]
)
# Check that it's inferred by the type checker as a class object of some kind
# (but don't assert the exact type that `D` is inferred as,
# in case a type checker decides to add some special-casing for
# `make_dataclass` in the future)
assert_type(D.__mro__, Tuple[type, ...])


@@ -0,0 +1,38 @@
from __future__ import annotations
import enum
import sys
from typing import Literal, Type
from typing_extensions import assert_type
A = enum.Enum("A", "spam eggs bacon")
B = enum.Enum("B", ["spam", "eggs", "bacon"])
C = enum.Enum("Bar", [("spam", 1), ("eggs", 2), ("bacon", 3)])
D = enum.Enum("Bar", {"spam": 1, "eggs": 2})
assert_type(A, Type[A])
assert_type(B, Type[B])
assert_type(C, Type[C])
assert_type(D, Type[D])
class EnumOfTuples(enum.Enum):
X = 1, 2, 3
Y = 4, 5, 6
assert_type(EnumOfTuples((1, 2, 3)), EnumOfTuples)
# TODO: ideally this test would pass:
#
# if sys.version_info >= (3, 12):
# assert_type(EnumOfTuples(1, 2, 3), EnumOfTuples)
if sys.version_info >= (3, 11):
class Foo(enum.StrEnum):
X = enum.auto()
assert_type(Foo.X, Literal[Foo.X])
assert_type(Foo.X.value, str)


@@ -0,0 +1,67 @@
from __future__ import annotations
from functools import cached_property, wraps
from typing import Callable, TypeVar
from typing_extensions import ParamSpec, assert_type
P = ParamSpec("P")
T_co = TypeVar("T_co", covariant=True)
def my_decorator(func: Callable[P, T_co]) -> Callable[P, T_co]:
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co:
print(args)
return func(*args, **kwargs)
# verify that the wrapped function has all these attributes
wrapper.__annotations__ = func.__annotations__
wrapper.__doc__ = func.__doc__
wrapper.__module__ = func.__module__
wrapper.__name__ = func.__name__
wrapper.__qualname__ = func.__qualname__
return wrapper
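# A brief usage sketch: the ParamSpec-based decorator preserves the signature,
# so the decorated function keeps its parameter and return types.
@my_decorator
def add_one(x: int) -> int:
    return x + 1
assert_type(add_one(1), int)
add_one("a")  # type: ignore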
class A:
def __init__(self, x: int):
self.x = x
@cached_property
def x(self) -> int:
return 0
assert_type(A(x=1).x, int)
class B:
@cached_property
def x(self) -> int:
return 0
def check_cached_property_settable(x: int) -> None:
b = B()
assert_type(b.x, int)
b.x = x
assert_type(b.x, int)
# https://github.com/python/typeshed/issues/10048
class Parent: ...
class Child(Parent): ...
class X:
@cached_property
def some(self) -> Parent:
return Parent()
class Y(X):
@cached_property
def some(self) -> Child: # safe override
return Child()


@@ -0,0 +1,13 @@
import importlib.abc
import pathlib
import sys
import zipfile
# Assert that some Path classes are Traversable.
if sys.version_info >= (3, 9):
def traverse(t: importlib.abc.Traversable) -> None:
pass
traverse(pathlib.Path())
traverse(zipfile.Path(""))


@@ -0,0 +1,33 @@
from __future__ import annotations
import sys
from _typeshed import StrPath
from os import PathLike
from pathlib import Path
from typing import Any
from zipfile import Path as ZipPath
if sys.version_info >= (3, 10):
from importlib.metadata._meta import SimplePath
# Simplified version of zipfile.Path
class MyPath:
@property
def parent(self) -> PathLike[str]: ... # undocumented
def read_text(self, encoding: str | None = ..., errors: str | None = ...) -> str: ...
def joinpath(self, *other: StrPath) -> MyPath: ...
def __truediv__(self, add: StrPath) -> MyPath: ...
if sys.version_info >= (3, 12):
def takes_simple_path(p: SimplePath[Any]) -> None: ...
else:
def takes_simple_path(p: SimplePath) -> None: ...
takes_simple_path(Path())
takes_simple_path(ZipPath(""))
takes_simple_path(MyPath())
takes_simple_path("some string") # type: ignore


@@ -0,0 +1,6 @@
from gzip import GzipFile
from io import FileIO, TextIOWrapper
TextIOWrapper(FileIO(""))
TextIOWrapper(FileIO(13))
TextIOWrapper(GzipFile(""))


@@ -0,0 +1,30 @@
from __future__ import annotations
import logging
import logging.handlers
import multiprocessing
import queue
from typing import Any
# This pattern comes from the logging docs, and should therefore pass a type checker
# See https://docs.python.org/3/library/logging.html#logrecord-objects
old_factory = logging.getLogRecordFactory()
def record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord:
record = old_factory(*args, **kwargs)
record.custom_attribute = 0xDECAFBAD
return record
logging.setLogRecordFactory(record_factory)
# The logging docs say that QueueHandler and QueueListener can take "any queue-like object"
# We test that here (regression test for #10168)
logging.handlers.QueueHandler(queue.Queue())
logging.handlers.QueueHandler(queue.SimpleQueue())
logging.handlers.QueueHandler(multiprocessing.Queue())
logging.handlers.QueueListener(queue.Queue())
logging.handlers.QueueListener(queue.SimpleQueue())
logging.handlers.QueueListener(multiprocessing.Queue())


@@ -0,0 +1,14 @@
from __future__ import annotations
from ctypes import c_char, c_float
from multiprocessing import Array, Value
from multiprocessing.sharedctypes import Synchronized, SynchronizedString
from typing_extensions import assert_type
string = Array(c_char, 12)
assert_type(string, SynchronizedString)
assert_type(string.value, bytes)
field = Value(c_float, 0.0)
assert_type(field, Synchronized[float])
field.value = 1.2
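# A small extra check (assuming Synchronized exposes the underlying Python value):
assert_type(field.value, float)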


@@ -0,0 +1,20 @@
from __future__ import annotations
from pathlib import Path, PureWindowsPath
if Path("asdf") == Path("asdf"):
...
# https://github.com/python/typeshed/issues/10661
# Provide a true positive error when comparing Path to str
# mypy should report a comparison-overlap error with --strict-equality,
# and pyright should report a reportUnnecessaryComparison error
if Path("asdf") == "asdf": # type: ignore
...
# Errors on comparison here are technically false positives. However, this comparison is a little
# interesting: it can never hold true on Posix, but could hold true on Windows. We should experiment
# with more accurate __new__, such that we only get an error for such comparisons on platforms
# where they can never hold true.
if PureWindowsPath("asdf") == Path("asdf"): # type: ignore
...


@@ -0,0 +1,26 @@
from __future__ import annotations
import mmap
import re
import typing as t
from typing_extensions import assert_type
def check_search(str_pat: re.Pattern[str], bytes_pat: re.Pattern[bytes]) -> None:
assert_type(str_pat.search("x"), t.Optional[t.Match[str]])
assert_type(bytes_pat.search(b"x"), t.Optional[t.Match[bytes]])
assert_type(bytes_pat.search(bytearray(b"x")), t.Optional[t.Match[bytes]])
assert_type(bytes_pat.search(mmap.mmap(0, 10)), t.Optional[t.Match[bytes]])
def check_search_with_AnyStr(pattern: re.Pattern[t.AnyStr], string: t.AnyStr) -> re.Match[t.AnyStr]:
"""See issue #9591"""
match = pattern.search(string)
if match is None:
raise ValueError(f"'{string!r}' does not match {pattern!r}")
return match
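# A usage sketch for the AnyStr helper above: a str pattern should yield a str
# match and a bytes pattern a bytes match.
def check_search_with_AnyStr_usage(str_pat: re.Pattern[str], bytes_pat: re.Pattern[bytes]) -> None:
    assert_type(check_search_with_AnyStr(str_pat, "x"), t.Match[str])
    assert_type(check_search_with_AnyStr(bytes_pat, b"x"), t.Match[bytes])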
def check_no_ReadableBuffer_false_negatives() -> None:
re.compile("foo").search(bytearray(b"foo")) # type: ignore
re.compile("foo").search(mmap.mmap(0, 10)) # type: ignore


@@ -0,0 +1,26 @@
from __future__ import annotations
import sqlite3
from typing_extensions import assert_type
class MyConnection(sqlite3.Connection):
pass
# Default return-type is Connection.
assert_type(sqlite3.connect(":memory:"), sqlite3.Connection)
# Providing an alternate factory changes the return-type.
assert_type(sqlite3.connect(":memory:", factory=MyConnection), MyConnection)
# Provides a true positive error. When checking the connect() function,
# mypy should report an arg-type error for the factory argument.
with sqlite3.connect(":memory:", factory=None) as con: # type: ignore
pass
# The Connection class also accepts a `factory` arg but it does not affect
# the return-type. This use case is not idiomatic--connections should be
# established using the `connect()` function, not directly (as shown here).
assert_type(sqlite3.Connection(":memory:", factory=None), sqlite3.Connection)
assert_type(sqlite3.Connection(":memory:", factory=MyConnection), sqlite3.Connection)


@@ -0,0 +1,13 @@
import tarfile
with tarfile.open("test.tar.xz", "w:xz") as tar:
pass
# Test with valid preset values
tarfile.open("test.tar.xz", "w:xz", preset=0)
tarfile.open("test.tar.xz", "w:xz", preset=5)
tarfile.open("test.tar.xz", "w:xz", preset=9)
# Test with invalid preset values
tarfile.open("test.tar.xz", "w:xz", preset=-1) # type: ignore
tarfile.open("test.tar.xz", "w:xz", preset=10) # type: ignore


@@ -0,0 +1,31 @@
from __future__ import annotations
import io
import sys
from tempfile import TemporaryFile, _TemporaryFileWrapper
from typing_extensions import assert_type
if sys.platform == "win32":
assert_type(TemporaryFile(), _TemporaryFileWrapper[bytes])
assert_type(TemporaryFile("w+"), _TemporaryFileWrapper[str])
assert_type(TemporaryFile("w+b"), _TemporaryFileWrapper[bytes])
assert_type(TemporaryFile("wb"), _TemporaryFileWrapper[bytes])
assert_type(TemporaryFile("rb"), _TemporaryFileWrapper[bytes])
assert_type(TemporaryFile("wb", 0), _TemporaryFileWrapper[bytes])
assert_type(TemporaryFile(mode="w+"), _TemporaryFileWrapper[str])
assert_type(TemporaryFile(mode="w+b"), _TemporaryFileWrapper[bytes])
assert_type(TemporaryFile(mode="wb"), _TemporaryFileWrapper[bytes])
assert_type(TemporaryFile(mode="rb"), _TemporaryFileWrapper[bytes])
assert_type(TemporaryFile(buffering=0), _TemporaryFileWrapper[bytes])
else:
assert_type(TemporaryFile(), io.BufferedRandom)
assert_type(TemporaryFile("w+"), io.TextIOWrapper)
assert_type(TemporaryFile("w+b"), io.BufferedRandom)
assert_type(TemporaryFile("wb"), io.BufferedWriter)
assert_type(TemporaryFile("rb"), io.BufferedReader)
assert_type(TemporaryFile("wb", 0), io.FileIO)
assert_type(TemporaryFile(mode="w+"), io.TextIOWrapper)
assert_type(TemporaryFile(mode="w+b"), io.BufferedRandom)
assert_type(TemporaryFile(mode="wb"), io.BufferedWriter)
assert_type(TemporaryFile(mode="rb"), io.BufferedReader)
assert_type(TemporaryFile(buffering=0), io.FileIO)


@@ -0,0 +1,14 @@
from __future__ import annotations
import _threading_local
import threading
loc = threading.local()
loc.foo = 42
del loc.foo
loc.baz = ["spam", "eggs"]
del loc.baz
l2 = _threading_local.local()
l2.asdfasdf = 56
del l2.asdfasdf


@@ -0,0 +1,30 @@
from __future__ import annotations
import tkinter
import traceback
import types
def custom_handler(exc: type[BaseException], val: BaseException, tb: types.TracebackType | None) -> None:
print("oh no")
root = tkinter.Tk()
root.report_callback_exception = traceback.print_exception
root.report_callback_exception = custom_handler
def foo(x: int, y: str) -> None:
pass
root.after(1000, foo, 10, "lol")
root.after(1000, foo, 10, 10) # type: ignore
# Font size must be an integer
label = tkinter.Label()
label.config(font=("", 12))
label.config(font=("", 12.34)) # type: ignore
label.config(font=("", 12, "bold"))
label.config(font=("", 12.34, "bold")) # type: ignore


@@ -0,0 +1,173 @@
from __future__ import annotations
import unittest
from collections.abc import Iterator, Mapping
from datetime import datetime, timedelta
from decimal import Decimal
from fractions import Fraction
from typing import TypedDict
from typing_extensions import assert_type
from unittest.mock import MagicMock, Mock, patch
case = unittest.TestCase()
###
# Tests for assertAlmostEqual
###
case.assertAlmostEqual(1, 2.4)
case.assertAlmostEqual(2.4, 2.41)
case.assertAlmostEqual(Fraction(49, 50), Fraction(48, 50))
case.assertAlmostEqual(3.14, complex(5, 6))
case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), delta=timedelta(hours=1))
case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), None, "foo", timedelta(hours=1))
case.assertAlmostEqual(Decimal("1.1"), Decimal("1.11"))
case.assertAlmostEqual(2.4, 2.41, places=8)
case.assertAlmostEqual(2.4, 2.41, delta=0.02)
case.assertAlmostEqual(2.4, 2.41, None, "foo", 0.02)
case.assertAlmostEqual(2.4, 2.41, places=9, delta=0.02) # type: ignore
case.assertAlmostEqual("foo", "bar") # type: ignore
case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1)) # type: ignore
case.assertAlmostEqual(Decimal("0.4"), Fraction(1, 2)) # type: ignore
case.assertAlmostEqual(complex(2, 3), Decimal("0.9")) # type: ignore
###
# Tests for assertNotAlmostEqual
###
case.assertAlmostEqual(1, 2.4)
case.assertNotAlmostEqual(Fraction(49, 50), Fraction(48, 50))
case.assertAlmostEqual(3.14, complex(5, 6))
case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), delta=timedelta(hours=1))
case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), None, "foo", timedelta(hours=1))
case.assertNotAlmostEqual(2.4, 2.41, places=9, delta=0.02) # type: ignore
case.assertNotAlmostEqual("foo", "bar") # type: ignore
case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1)) # type: ignore
case.assertNotAlmostEqual(Decimal("0.4"), Fraction(1, 2)) # type: ignore
case.assertNotAlmostEqual(complex(2, 3), Decimal("0.9")) # type: ignore
###
# Tests for assertGreater
###
class Spam:
def __lt__(self, other: object) -> bool:
return True
class Eggs:
def __gt__(self, other: object) -> bool:
return True
class Ham:
def __lt__(self, other: Ham) -> bool:
if not isinstance(other, Ham):
return NotImplemented
return True
class Bacon:
def __gt__(self, other: Bacon) -> bool:
if not isinstance(other, Bacon):
return NotImplemented
return True
case.assertGreater(5.8, 3)
case.assertGreater(Decimal("4.5"), Fraction(3, 2))
case.assertGreater(Fraction(3, 2), 0.9)
case.assertGreater(Eggs(), object())
case.assertGreater(object(), Spam())
case.assertGreater(Ham(), Ham())
case.assertGreater(Bacon(), Bacon())
case.assertGreater(object(), object()) # type: ignore
case.assertGreater(datetime(1999, 1, 2), 1) # type: ignore
case.assertGreater(Spam(), Eggs()) # type: ignore
case.assertGreater(Ham(), Bacon()) # type: ignore
case.assertGreater(Bacon(), Ham()) # type: ignore
###
# Tests for assertDictEqual
###
class TD1(TypedDict):
x: int
y: str
class TD2(TypedDict):
a: bool
b: bool
class MyMapping(Mapping[str, int]):
def __getitem__(self, __key: str) -> int:
return 42
def __iter__(self) -> Iterator[str]:
return iter([])
def __len__(self) -> int:
return 0
td1: TD1 = {"x": 1, "y": "foo"}
td2: TD2 = {"a": True, "b": False}
m = MyMapping()
case.assertDictEqual({}, {})
case.assertDictEqual({"x": 1, "y": 2}, {"x": 1, "y": 2})
case.assertDictEqual({"x": 1, "y": "foo"}, {"y": "foo", "x": 1})
case.assertDictEqual({"x": 1}, {})
case.assertDictEqual({}, {"x": 1})
case.assertDictEqual({1: "x"}, {"y": 222})
case.assertDictEqual({1: "x"}, td1)
case.assertDictEqual(td1, {1: "x"})
case.assertDictEqual(td1, td2)
case.assertDictEqual(1, {}) # type: ignore
case.assertDictEqual({}, 1) # type: ignore
# These should fail, but don't due to TypedDict limitations:
# case.assertDictEqual(m, {"": 0}) # xtype: ignore
# case.assertDictEqual({"": 0}, m) # xtype: ignore
###
# Tests for mock.patch
###
@patch("sys.exit")
def f_default_new(i: int, mock: MagicMock) -> str:
return "asdf"
@patch("sys.exit", new=42)
def f_explicit_new(i: int) -> str:
return "asdf"
assert_type(f_default_new(1), str)
f_default_new("a") # Not an error due to ParamSpec limitations
assert_type(f_explicit_new(1), str)
f_explicit_new("a") # type: ignore[arg-type]
@patch("sys.exit", new=Mock())
class TestXYZ(unittest.TestCase):
attr: int = 5
@staticmethod
def method() -> int:
return 123
assert_type(TestXYZ.attr, int)
assert_type(TestXYZ.method(), int)


@@ -0,0 +1,35 @@
from __future__ import annotations
import sys
from typing_extensions import assert_type
from xml.dom.minidom import Document
document = Document()
assert_type(document.toxml(), str)
assert_type(document.toxml(encoding=None), str)
assert_type(document.toxml(encoding="UTF8"), bytes)
assert_type(document.toxml("UTF8"), bytes)
if sys.version_info >= (3, 9):
assert_type(document.toxml(standalone=True), str)
assert_type(document.toxml("UTF8", True), bytes)
assert_type(document.toxml(encoding="UTF8", standalone=True), bytes)
# Because toprettyxml can mix positional and keyword variants of the "encoding" argument, which
# determines the return type, the proper stub typing isn't immediately obvious. This is a basic
# brute-force sanity check.
# Test cases like toxml
assert_type(document.toprettyxml(), str)
assert_type(document.toprettyxml(encoding=None), str)
assert_type(document.toprettyxml(encoding="UTF8"), bytes)
if sys.version_info >= (3, 9):
assert_type(document.toprettyxml(standalone=True), str)
assert_type(document.toprettyxml(encoding="UTF8", standalone=True), bytes)
# Test cases unique to toprettyxml
assert_type(document.toprettyxml(" "), str)
assert_type(document.toprettyxml(" ", "\r\n"), str)
assert_type(document.toprettyxml(" ", "\r\n", "UTF8"), bytes)
if sys.version_info >= (3, 9):
assert_type(document.toprettyxml(" ", "\r\n", "UTF8", True), bytes)
assert_type(document.toprettyxml(" ", "\r\n", standalone=True), str)


@@ -0,0 +1,69 @@
"""
Tests for `defaultdict.__or__` and `defaultdict.__ror__`.
These methods were only added in py39.
"""
from __future__ import annotations
import os
import sys
from collections import defaultdict
from typing import Mapping, TypeVar, Union
from typing_extensions import Self, assert_type
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
if sys.version_info >= (3, 9):
class CustomDefaultDictSubclass(defaultdict[_KT, _VT]):
pass
class CustomMappingWithDunderOr(Mapping[_KT, _VT]):
def __or__(self, other: Mapping[_KT, _VT]) -> dict[_KT, _VT]:
return {}
def __ror__(self, other: Mapping[_KT, _VT]) -> dict[_KT, _VT]:
return {}
def __ior__(self, other: Mapping[_KT, _VT]) -> Self:
return self
def test_defaultdict_dot_or(
a: defaultdict[int, int],
b: CustomDefaultDictSubclass[int, int],
c: defaultdict[str, str],
d: Mapping[int, int],
e: CustomMappingWithDunderOr[str, str],
) -> None:
assert_type(a | b, defaultdict[int, int])
# In contrast to `dict.__or__`, `defaultdict.__or__` returns `Self` if called on a subclass of `defaultdict`:
assert_type(b | a, CustomDefaultDictSubclass[int, int])
assert_type(a | c, defaultdict[Union[int, str], Union[int, str]])
# arbitrary mappings are not accepted by `defaultdict.__or__`;
# it has to be a subclass of `dict`
a | d # type: ignore
# but Mappings such as `os._Environ` or `CustomMappingWithDunderOr`,
# which define `__ror__` methods that accept `dict`, are fine
# (`os._Environ.__(r)or__` always returns `dict`, even if a `defaultdict` is passed):
assert_type(a | os.environ, dict[Union[str, int], Union[str, int]])
assert_type(os.environ | a, dict[Union[str, int], Union[str, int]])
assert_type(c | os.environ, dict[str, str])
assert_type(c | e, dict[str, str])
assert_type(os.environ | c, dict[str, str])
assert_type(e | c, dict[str, str])
e |= c
e |= a # type: ignore
# TODO: this test passes mypy, but fails pyright for some reason:
# c |= e
c |= a # type: ignore


@@ -0,0 +1,6 @@
from email.headerregistry import Address
from email.message import EmailMessage
msg = EmailMessage()
msg["To"] = "receiver@example.com"
msg["From"] = Address("Sender Name", "sender", "example.com")


@@ -0,0 +1,410 @@
"""Type-annotated versions of the recipes from the itertools docs.
These are all meant to be examples of idiomatic itertools usage,
so they should all type-check without error.
"""
from __future__ import annotations
import collections
import math
import operator
import sys
from itertools import chain, combinations, count, cycle, filterfalse, groupby, islice, product, repeat, starmap, tee, zip_longest
from typing import (
Any,
Callable,
Collection,
Hashable,
Iterable,
Iterator,
Literal,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from typing_extensions import TypeAlias, TypeVarTuple, Unpack
_T = TypeVar("_T")
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_HashableT = TypeVar("_HashableT", bound=Hashable)
_Ts = TypeVarTuple("_Ts")
def take(n: int, iterable: Iterable[_T]) -> list[_T]:
"Return first n items of the iterable as a list"
return list(islice(iterable, n))
# Note: the itertools docs use the parameter name "iterator",
# but the function actually accepts any iterable
# as its second argument
def prepend(value: _T1, iterator: Iterable[_T2]) -> Iterator[_T1 | _T2]:
"Prepend a single value in front of an iterator"
# prepend(1, [2, 3, 4]) --> 1 2 3 4
return chain([value], iterator)
def tabulate(function: Callable[[int], _T], start: int = 0) -> Iterator[_T]:
"Return function(0), function(1), ..."
return map(function, count(start))
def repeatfunc(func: Callable[[Unpack[_Ts]], _T], times: int | None = None, *args: Unpack[_Ts]) -> Iterator[_T]:
"""Repeat calls to func with specified arguments.
Example: repeatfunc(random.random)
"""
if times is None:
return starmap(func, repeat(args))
return starmap(func, repeat(args, times))
def flatten(list_of_lists: Iterable[Iterable[_T]]) -> Iterator[_T]:
"Flatten one level of nesting"
return chain.from_iterable(list_of_lists)
def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]:
"Returns the sequence elements n times"
return chain.from_iterable(repeat(tuple(iterable), n))
def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]:
"Return an iterator over the last n items"
# tail(3, 'ABCDEFG') --> E F G
return iter(collections.deque(iterable, maxlen=n))
# This function *accepts* any iterable,
# but it only *makes sense* to use it with an iterator
def consume(iterator: Iterator[object], n: int | None = None) -> None:
"Advance the iterator n-steps ahead. If n is None, consume entirely."
# Use functions that consume iterators at C speed.
if n is None:
# feed the entire iterator into a zero-length deque
collections.deque(iterator, maxlen=0)
else:
# advance to the empty slice starting at position n
next(islice(iterator, n, n), None)
@overload
def nth(iterable: Iterable[_T], n: int, default: None = None) -> _T | None: ...
@overload
def nth(iterable: Iterable[_T], n: int, default: _T1) -> _T | _T1: ...
def nth(iterable: Iterable[object], n: int, default: object = None) -> object:
"Returns the nth item or a default value"
return next(islice(iterable, n, None), default)
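# A brief usage sketch for the overloads above: without a default the result is
# optional, and an explicit default widens the return type to include it.
_nth_default_none: int | None = nth([1, 2, 3], 1)
_nth_with_default: int | str = nth([1, 2, 3], 5, "missing")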
@overload
def quantify(iterable: Iterable[object]) -> int: ...
@overload
def quantify(iterable: Iterable[_T], pred: Callable[[_T], bool]) -> int: ...
def quantify(iterable: Iterable[object], pred: Callable[[Any], bool] = bool) -> int:
"Given a predicate that returns True or False, count the True results."
return sum(map(pred, iterable))
@overload
def first_true(
iterable: Iterable[_T], default: Literal[False] = False, pred: Callable[[_T], bool] | None = None
) -> _T | Literal[False]: ...
@overload
def first_true(iterable: Iterable[_T], default: _T1, pred: Callable[[_T], bool] | None = None) -> _T | _T1: ...
def first_true(iterable: Iterable[object], default: object = False, pred: Callable[[Any], bool] | None = None) -> object:
"""Returns the first true value in the iterable.
If no true value is found, returns *default*
If *pred* is not None, returns the first item
for which pred(item) is true.
"""
# first_true([a,b,c], x) --> a or b or c or x
# first_true([a,b], x, f) --> a if f(a) else b if f(b) else x
return next(filter(pred, iterable), default)
_ExceptionOrExceptionTuple: TypeAlias = Union[Type[BaseException], Tuple[Type[BaseException], ...]]
@overload
def iter_except(func: Callable[[], _T], exception: _ExceptionOrExceptionTuple, first: None = None) -> Iterator[_T]: ...
@overload
def iter_except(
func: Callable[[], _T], exception: _ExceptionOrExceptionTuple, first: Callable[[], _T1]
) -> Iterator[_T | _T1]: ...
def iter_except(
func: Callable[[], object], exception: _ExceptionOrExceptionTuple, first: Callable[[], object] | None = None
) -> Iterator[object]:
"""Call a function repeatedly until an exception is raised.
Converts a call-until-exception interface to an iterator interface.
Like builtins.iter(func, sentinel) but uses an exception instead
of a sentinel to end the loop.
Examples:
iter_except(functools.partial(heappop, h), IndexError) # priority queue iterator
iter_except(d.popitem, KeyError) # non-blocking dict iterator
iter_except(d.popleft, IndexError) # non-blocking deque iterator
iter_except(q.get_nowait, Queue.Empty) # loop over a producer Queue
iter_except(s.pop, KeyError) # non-blocking set iterator
"""
try:
if first is not None:
yield first() # For database APIs needing an initial cast to db.first()
while True:
yield func()
except exception:
pass
def sliding_window(iterable: Iterable[_T], n: int) -> Iterator[tuple[_T, ...]]:
# sliding_window('ABCDEFG', 4) --> ABCD BCDE CDEF DEFG
it = iter(iterable)
window = collections.deque(islice(it, n - 1), maxlen=n)
for x in it:
window.append(x)
yield tuple(window)
def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]:
"roundrobin('ABC', 'D', 'EF') --> A D E B F C"
# Recipe credited to George Sakkis
num_active = len(iterables)
nexts: Iterator[Callable[[], _T]] = cycle(iter(it).__next__ for it in iterables)
while num_active:
try:
for next in nexts:
yield next()
except StopIteration:
# Remove the iterator we just exhausted from the cycle.
num_active -= 1
nexts = cycle(islice(nexts, num_active))
def partition(pred: Callable[[_T], bool], iterable: Iterable[_T]) -> tuple[Iterator[_T], Iterator[_T]]:
"""Partition entries into false entries and true entries.
If *pred* is slow, consider wrapping it with functools.lru_cache().
"""
# partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
t1, t2 = tee(iterable)
return filterfalse(pred, t1), filter(pred, t2)
def subslices(seq: Sequence[_T]) -> Iterator[Sequence[_T]]:
"Return all contiguous non-empty subslices of a sequence"
# subslices('ABCD') --> A AB ABC ABCD B BC BCD C CD D
slices = starmap(slice, combinations(range(len(seq) + 1), 2))
return map(operator.getitem, repeat(seq), slices)
def before_and_after(predicate: Callable[[_T], bool], it: Iterable[_T]) -> tuple[Iterator[_T], Iterator[_T]]:
"""Variant of takewhile() that allows complete
access to the remainder of the iterator.
>>> it = iter('ABCdEfGhI')
>>> all_upper, remainder = before_and_after(str.isupper, it)
>>> ''.join(all_upper)
'ABC'
>>> ''.join(remainder) # takewhile() would lose the 'd'
'dEfGhI'
Note that the first iterator must be fully
consumed before the second iterator can
generate valid results.
"""
it = iter(it)
transition: list[_T] = []
def true_iterator() -> Iterator[_T]:
for elem in it:
if predicate(elem):
yield elem
else:
transition.append(elem)
return
def remainder_iterator() -> Iterator[_T]:
yield from transition
yield from it
return true_iterator(), remainder_iterator()
@overload
def unique_everseen(iterable: Iterable[_HashableT], key: None = None) -> Iterator[_HashableT]: ...
@overload
def unique_everseen(iterable: Iterable[_T], key: Callable[[_T], Hashable]) -> Iterator[_T]: ...
def unique_everseen(iterable: Iterable[_T], key: Callable[[_T], Hashable] | None = None) -> Iterator[_T]:
"List unique elements, preserving order. Remember all elements ever seen."
# unique_everseen('AAAABBBCCDAABBB') --> A B C D
# unique_everseen('ABBcCAD', str.lower) --> A B c D
seen: set[Hashable] = set()
if key is None:
for element in filterfalse(seen.__contains__, iterable):
seen.add(element)
yield element
# For order preserving deduplication,
# a faster but non-lazy solution is:
# yield from dict.fromkeys(iterable)
else:
for element in iterable:
k = key(element)
if k not in seen:
seen.add(k)
yield element
# For use cases that allow the last matching element to be returned,
# a faster but non-lazy solution is:
# t1, t2 = tee(iterable)
# yield from dict(zip(map(key, t1), t2)).values()
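# Illustrative checks mirroring the comments above (not in the original file):
assert "".join(unique_everseen("AAAABBBCCDAABBB")) == "ABCD"
assert "".join(unique_everseen("ABBcCAD", str.lower)) == "ABcD"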
# Slightly adapted from the docs recipe; a one-liner was a bit much for pyright
def unique_justseen(iterable: Iterable[_T], key: Callable[[_T], bool] | None = None) -> Iterator[_T]:
"List unique elements, preserving order. Remember only the element just seen."
# unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
# unique_justseen('ABBcCAD', str.lower) --> A B c A D
g: groupby[_T | bool, _T] = groupby(iterable, key)
return map(next, map(operator.itemgetter(1), g))
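# Illustrative check of the first comment above:
assert "".join(unique_justseen("AAAABBBCCDAABBB")) == "ABCDAB"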
def powerset(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]:
"powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
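# Illustrative check on a smaller input than the docstring example:
assert list(powerset([1, 2])) == [(), (1,), (2,), (1, 2)]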
def polynomial_derivative(coefficients: Sequence[float]) -> list[float]:
"""Compute the first derivative of a polynomial.
f(x) = x³ -4x² -17x + 60
f'(x) = 3x² -8x -17
"""
# polynomial_derivative([1, -4, -17, 60]) -> [3, -8, -17]
n = len(coefficients)
powers = reversed(range(1, n))
return list(map(operator.mul, coefficients, powers))
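# Illustrative worked example from the docstring: d/dx(x³ - 4x² - 17x + 60) = 3x² - 8x - 17.
assert polynomial_derivative([1, -4, -17, 60]) == [3, -8, -17]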
def nth_combination(iterable: Iterable[_T], r: int, index: int) -> tuple[_T, ...]:
"Equivalent to list(combinations(iterable, r))[index]"
pool = tuple(iterable)
n = len(pool)
c = math.comb(n, r)
if index < 0:
index += c
if index < 0 or index >= c:
raise IndexError
result: list[_T] = []
while r:
c, n, r = c * r // n, n - 1, r - 1
while index >= c:
index -= c
c, n = c * (n - r) // n, n - 1
result.append(pool[-1 - n])
return tuple(result)
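# Illustrative cross-check against itertools.combinations (hypothetical, not in the
# original file): index 4 of C("ABCDE", 2) is ("B", "C").
assert nth_combination("ABCDE", 2, 4) == tuple(combinations("ABCDE", 2))[4]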
if sys.version_info >= (3, 10):
@overload
def grouper(
iterable: Iterable[_T], n: int, *, incomplete: Literal["fill"] = "fill", fillvalue: None = None
) -> Iterator[tuple[_T | None, ...]]: ...
@overload
def grouper(
iterable: Iterable[_T], n: int, *, incomplete: Literal["fill"] = "fill", fillvalue: _T1
) -> Iterator[tuple[_T | _T1, ...]]: ...
@overload
def grouper(
iterable: Iterable[_T], n: int, *, incomplete: Literal["strict", "ignore"], fillvalue: None = None
) -> Iterator[tuple[_T, ...]]: ...
def grouper(
iterable: Iterable[object], n: int, *, incomplete: Literal["fill", "strict", "ignore"] = "fill", fillvalue: object = None
) -> Iterator[tuple[object, ...]]:
"Collect data into non-overlapping fixed-length chunks or blocks"
# grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
# grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
# grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
args = [iter(iterable)] * n
if incomplete == "fill":
return zip_longest(*args, fillvalue=fillvalue)
if incomplete == "strict":
return zip(*args, strict=True)
if incomplete == "ignore":
return zip(*args)
else:
raise ValueError("Expected fill, strict, or ignore")
def transpose(it: Iterable[Iterable[_T]]) -> Iterator[tuple[_T, ...]]:
"Swap the rows and columns of the input."
# transpose([(1, 2, 3), (11, 22, 33)]) --> (1, 11) (2, 22) (3, 33)
return zip(*it, strict=True)
if sys.version_info >= (3, 12):
from itertools import batched
def sum_of_squares(it: Iterable[float]) -> float:
"Add up the squares of the input values."
# sum_of_squares([10, 20, 30]) -> 1400
return math.sumprod(*tee(it))
def convolve(signal: Iterable[float], kernel: Iterable[float]) -> Iterator[float]:
"""Discrete linear convolution of two iterables.
The kernel is fully consumed before the calculations begin.
The signal is consumed lazily and can be infinite.
Convolutions are mathematically commutative.
If the signal and kernel are swapped,
the output will be the same.
Article: https://betterexplained.com/articles/intuitive-convolution/
Video: https://www.youtube.com/watch?v=KuXjwB4LzSA
"""
# convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur)
# convolve(data, [1/2, 0, -1/2]) --> 1st derivative estimate
# convolve(data, [1, -2, 1]) --> 2nd derivative estimate
kernel = tuple(kernel)[::-1]
n = len(kernel)
padded_signal = chain(repeat(0, n - 1), signal, repeat(0, n - 1))
windowed_signal = sliding_window(padded_signal, n)
return map(math.sumprod, repeat(kernel), windowed_signal)
def polynomial_eval(coefficients: Sequence[float], x: float) -> float:
"""Evaluate a polynomial at a specific value.
Computes with better numeric stability than Horner's method.
"""
# Evaluate x³ -4x² -17x + 60 at x = 2.5
# polynomial_eval([1, -4, -17, 60], x=2.5) --> 8.125
n = len(coefficients)
if not n:
return type(x)(0)
powers = map(pow, repeat(x), reversed(range(n)))
return math.sumprod(coefficients, powers)
def matmul(m1: Sequence[Collection[float]], m2: Sequence[Collection[float]]) -> Iterator[tuple[float, ...]]:
"Multiply two matrices."
# matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) --> (49, 80), (41, 60)
n = len(m2[0])
return batched(starmap(math.sumprod, product(m1, transpose(m2))), n)

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Any, Union
from typing_extensions import assert_type
def check_setdefault_method() -> None:
d: dict[int, str] = {}
d2: dict[int, str | None] = {}
d3: dict[int, Any] = {}
d.setdefault(1) # type: ignore
assert_type(d.setdefault(1, "x"), str)
assert_type(d2.setdefault(1), Union[str, None])
assert_type(d2.setdefault(1, None), Union[str, None])
assert_type(d2.setdefault(1, "x"), Union[str, None])
assert_type(d3.setdefault(1), Union[Any, None])
assert_type(d3.setdefault(1, "x"), Any)
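# A small runtime counterpart (illustrative, not in the original test): setdefault
# only inserts the default when the key is missing.
def check_setdefault_runtime() -> None:
    d: dict[int, str] = {1: "a"}
    assert d.setdefault(1, "b") == "a"  # existing key: stored value wins
    assert d.setdefault(2, "b") == "b"  # missing key: default is inserted
    assert d == {1: "a", 2: "b"}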

View File

@@ -0,0 +1,14 @@
# pyright: reportWildcardImportFromLibrary=false
"""
This tests that star imports work when using "__all__ += " syntax.
"""
from __future__ import annotations
import sys
from typing import *
from zipfile import *
if sys.version_info >= (3, 9):
x: Annotated[int, 42]
p: Path

View File

@@ -0,0 +1,16 @@
from __future__ import annotations
import typing as t
KT = t.TypeVar("KT")
class MyKeysView(t.KeysView[KT]):
pass
d: dict[t.Any, t.Any] = {}
dict_keys = type(d.keys())
# This should not cause an error like `Member "register" is unknown`:
MyKeysView.register(dict_keys)
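# Illustrative follow-up (assumption, not in the original test): after register(),
# the ABC machinery reports dict_keys objects as instances of MyKeysView at runtime.
assert isinstance(d.keys(), MyKeysView)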

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
import mmap
from typing import IO, AnyStr
def check_write(io_bytes: IO[bytes], io_str: IO[str], io_anystr: IO[AnyStr], any_str: AnyStr, buf: mmap.mmap) -> None:
io_bytes.write(b"")
io_bytes.write(buf)
io_bytes.write("") # type: ignore
io_bytes.write(any_str) # type: ignore
io_str.write(b"") # type: ignore
io_str.write(buf) # type: ignore
io_str.write("")
io_str.write(any_str) # type: ignore
io_anystr.write(b"") # type: ignore
io_anystr.write(buf) # type: ignore
io_anystr.write("") # type: ignore
io_anystr.write(any_str)