From 09668963a1d7f5f524a6fbe30fb9828518363ad8 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Wed, 29 Nov 2023 11:10:26 +0000 Subject: [PATCH] Make `itertools.starmap` covariant (#11037) --- stdlib/itertools.pyi | 6 +++--- .../stdlib/itertools/check_itertools_recipes.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/stdlib/itertools.pyi b/stdlib/itertools.pyi index de49ee3c5..1bc0b2ec7 100644 --- a/stdlib/itertools.pyi +++ b/stdlib/itertools.pyi @@ -101,10 +101,10 @@ class islice(Iterator[_T]): def __iter__(self) -> Self: ... def __next__(self) -> _T: ... -class starmap(Iterator[_T]): - def __init__(self, __function: Callable[..., _T], __iterable: Iterable[Iterable[Any]]) -> None: ... +class starmap(Iterator[_T_co]): + def __new__(cls, __function: Callable[..., _T], __iterable: Iterable[Iterable[Any]]) -> starmap[_T]: ... def __iter__(self) -> Self: ... - def __next__(self) -> _T: ... + def __next__(self) -> _T_co: ... class takewhile(Iterator[_T]): def __init__(self, __predicate: _Predicate[_T], __iterable: Iterable[_T]) -> None: ... diff --git a/test_cases/stdlib/itertools/check_itertools_recipes.py b/test_cases/stdlib/itertools/check_itertools_recipes.py index 340811eec..9f31ba70f 100644 --- a/test_cases/stdlib/itertools/check_itertools_recipes.py +++ b/test_cases/stdlib/itertools/check_itertools_recipes.py @@ -9,8 +9,8 @@ import collections import math import operator import sys -from itertools import chain, combinations, count, cycle, filterfalse, groupby, islice, repeat, starmap, tee, zip_longest -from typing import Any, Callable, Hashable, Iterable, Iterator, Sequence, Tuple, Type, TypeVar, Union, overload +from itertools import chain, combinations, count, cycle, filterfalse, groupby, islice, product, repeat, starmap, tee, zip_longest +from typing import Any, Callable, Collection, Hashable, Iterable, Iterator, Sequence, Tuple, Type, TypeVar, Union, overload from typing_extensions import Literal, TypeAlias, TypeVarTuple, Unpack _T = TypeVar("_T") @@ -363,6 +363,7 @@ if sys.version_info >= (3, 10): if sys.version_info >= (3, 12): + from itertools import batched def sum_of_squares(it: Iterable[float]) -> float: "Add up the squares of the input values." @@ -399,3 +400,9 @@ if sys.version_info >= (3, 12): return type(x)(0) powers = map(pow, repeat(x), reversed(range(n))) return math.sumprod(coefficients, powers) + + def matmul(m1: Sequence[Collection[float]], m2: Sequence[Collection[float]]) -> Iterator[tuple[float, ...]]: + "Multiply two matrices." + # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) --> (49, 80), (41, 60) + n = len(m2[0]) + return batched(starmap(math.sumprod, product(m1, transpose(m2))), n)