Make itertools.starmap covariant (#11037)

This commit is contained in:
Alex Waygood
2023-11-29 11:10:26 +00:00
committed by GitHub
parent e7c57b5a6d
commit 09668963a1
2 changed files with 12 additions and 5 deletions

View File

@@ -101,10 +101,10 @@ class islice(Iterator[_T]):
def __iter__(self) -> Self: ...
def __next__(self) -> _T: ...
class starmap(Iterator[_T]):
def __init__(self, __function: Callable[..., _T], __iterable: Iterable[Iterable[Any]]) -> None: ...
class starmap(Iterator[_T_co]):
def __new__(cls, __function: Callable[..., _T], __iterable: Iterable[Iterable[Any]]) -> starmap[_T]: ...
def __iter__(self) -> Self: ...
def __next__(self) -> _T: ...
def __next__(self) -> _T_co: ...
class takewhile(Iterator[_T]):
def __init__(self, __predicate: _Predicate[_T], __iterable: Iterable[_T]) -> None: ...

View File

@@ -9,8 +9,8 @@ import collections
import math
import operator
import sys
from itertools import chain, combinations, count, cycle, filterfalse, groupby, islice, repeat, starmap, tee, zip_longest
from typing import Any, Callable, Hashable, Iterable, Iterator, Sequence, Tuple, Type, TypeVar, Union, overload
from itertools import chain, combinations, count, cycle, filterfalse, groupby, islice, product, repeat, starmap, tee, zip_longest
from typing import Any, Callable, Collection, Hashable, Iterable, Iterator, Sequence, Tuple, Type, TypeVar, Union, overload
from typing_extensions import Literal, TypeAlias, TypeVarTuple, Unpack
_T = TypeVar("_T")
@@ -363,6 +363,7 @@ if sys.version_info >= (3, 10):
if sys.version_info >= (3, 12):
from itertools import batched
def sum_of_squares(it: Iterable[float]) -> float:
"Add up the squares of the input values."
@@ -399,3 +400,9 @@ if sys.version_info >= (3, 12):
return type(x)(0)
powers = map(pow, repeat(x), reversed(range(n)))
return math.sumprod(coefficients, powers)
def matmul(m1: Sequence[Collection[float]], m2: Sequence[Collection[float]]) -> Iterator[tuple[float, ...]]:
"Multiply two matrices."
# matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) --> (49, 80), (41, 60)
n = len(m2[0])
return batched(starmap(math.sumprod, product(m1, transpose(m2))), n)