Make itertools.starmap covariant (#11037)

This commit is contained in:
Alex Waygood
2023-11-29 11:10:26 +00:00
committed by GitHub
parent e7c57b5a6d
commit 09668963a1
2 changed files with 12 additions and 5 deletions

View File

@@ -9,8 +9,8 @@ import collections
import math
import operator
import sys
from itertools import chain, combinations, count, cycle, filterfalse, groupby, islice, repeat, starmap, tee, zip_longest
from typing import Any, Callable, Hashable, Iterable, Iterator, Sequence, Tuple, Type, TypeVar, Union, overload
from itertools import chain, combinations, count, cycle, filterfalse, groupby, islice, product, repeat, starmap, tee, zip_longest
from typing import Any, Callable, Collection, Hashable, Iterable, Iterator, Sequence, Tuple, Type, TypeVar, Union, overload
from typing_extensions import Literal, TypeAlias, TypeVarTuple, Unpack
_T = TypeVar("_T")
@@ -363,6 +363,7 @@ if sys.version_info >= (3, 10):
if sys.version_info >= (3, 12):
from itertools import batched
def sum_of_squares(it: Iterable[float]) -> float:
"Add up the squares of the input values."
@@ -399,3 +400,9 @@ if sys.version_info >= (3, 12):
return type(x)(0)
powers = map(pow, repeat(x), reversed(range(n)))
return math.sumprod(coefficients, powers)
def matmul(m1: Sequence[Collection[float]], m2: Sequence[Collection[float]]) -> Iterator[tuple[float, ...]]:
"Multiply two matrices."
# matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) --> (49, 80), (41, 60)
n = len(m2[0])
return batched(starmap(math.sumprod, product(m1, transpose(m2))), n)