Add several itertools recipes to the test_cases directory (#10992)
test_cases/stdlib/itertools/check_itertools_recipes.py (new file, 371 lines)
@@ -0,0 +1,371 @@
"""Type-annotated versions of the recipes from the itertools docs.

These are all meant to be examples of idiomatic itertools usage,
so they should all type-check without error.
"""
from __future__ import annotations

import collections
import math
import operator
import sys
from itertools import chain, combinations, count, cycle, filterfalse, islice, repeat, starmap, tee, zip_longest
from typing import Any, Callable, Hashable, Iterable, Iterator, Sequence, Tuple, Type, TypeVar, Union, overload
from typing_extensions import Literal, TypeAlias, TypeVarTuple, Unpack

_T = TypeVar("_T")
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_HashableT = TypeVar("_HashableT", bound=Hashable)
_Ts = TypeVarTuple("_Ts")


def take(n: int, iterable: Iterable[_T]) -> list[_T]:
    "Return first n items of the iterable as a list"
    return list(islice(iterable, n))
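
# Illustrative example (added, in the style of the recipe comments):
# take(3, "ABCDEFG") --> ['A', 'B', 'C']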


# Note: the itertools docs use the parameter name "iterator",
# but the function actually accepts any iterable
# as its second argument
def prepend(value: _T1, iterator: Iterable[_T2]) -> Iterator[_T1 | _T2]:
    "Prepend a single value in front of an iterator"
    # prepend(1, [2, 3, 4]) --> 1 2 3 4
    return chain([value], iterator)


def tabulate(function: Callable[[int], _T], start: int = 0) -> Iterator[_T]:
    "Return function(0), function(1), ..."
    return map(function, count(start))
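
# Illustrative example (added): tabulate(lambda i: i * i) --> 0 1 4 9 16 ...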


def repeatfunc(func: Callable[[Unpack[_Ts]], _T], times: int | None = None, *args: Unpack[_Ts]) -> Iterator[_T]:
    """Repeat calls to func with specified arguments.

    Example: repeatfunc(random.random)
    """
    if times is None:
        return starmap(func, repeat(args))
    return starmap(func, repeat(args, times))
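
# Illustrative example (added): repeatfunc(pow, 3, 2, 5) --> 32 32 32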


def flatten(list_of_lists: Iterable[Iterable[_T]]) -> Iterator[_T]:
    "Flatten one level of nesting"
    return chain.from_iterable(list_of_lists)
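
# Illustrative example (added): flatten([['a', 'b'], ['c', 'd']]) --> a b c d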


def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]:
    "Returns the sequence elements n times"
    return chain.from_iterable(repeat(tuple(iterable), n))
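
# Illustrative example (added): ncycles('AB', 3) --> A B A B A B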


def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]:
    "Return an iterator over the last n items"
    # tail(3, 'ABCDEFG') --> E F G
    return iter(collections.deque(iterable, maxlen=n))


# This function *accepts* any iterable,
# but it only *makes sense* to use it with an iterator
def consume(iterator: Iterator[object], n: int | None = None) -> None:
    "Advance the iterator n-steps ahead. If n is None, consume entirely."
    # Use functions that consume iterators at C speed.
    if n is None:
        # feed the entire iterator into a zero-length deque
        collections.deque(iterator, maxlen=0)
    else:
        # advance to the empty slice starting at position n
        next(islice(iterator, n, n), None)
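
# Illustrative example (added): given it = iter(range(10)), consume(it, 3)
# advances it so that next(it) == 3; a later consume(it) exhausts it entirely.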


@overload
def nth(iterable: Iterable[_T], n: int, default: None = None) -> _T | None:
    ...


@overload
def nth(iterable: Iterable[_T], n: int, default: _T1) -> _T | _T1:
    ...


def nth(iterable: Iterable[object], n: int, default: object = None) -> object:
    "Returns the nth item or a default value"
    return next(islice(iterable, n, None), default)
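
# Illustrative example (added): nth('ABCDE', 2) --> 'C'; nth('ABCDE', 10, '?') --> '?'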


@overload
def quantify(iterable: Iterable[object]) -> int:
    ...


@overload
def quantify(iterable: Iterable[_T], pred: Callable[[_T], bool]) -> int:
    ...


def quantify(iterable: Iterable[object], pred: Callable[[Any], bool] = bool) -> int:
    "Given a predicate that returns True or False, count the True results."
    return sum(map(pred, iterable))
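
# Illustrative example (added): quantify([0, 1, 2, 0]) --> 2; quantify('AbCd', str.isupper) --> 2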


@overload
def first_true(
    iterable: Iterable[_T], default: Literal[False] = False, pred: Callable[[_T], bool] | None = None
) -> _T | Literal[False]:
    ...


@overload
def first_true(iterable: Iterable[_T], default: _T1, pred: Callable[[_T], bool] | None = None) -> _T | _T1:
    ...


def first_true(iterable: Iterable[object], default: object = False, pred: Callable[[Any], bool] | None = None) -> object:
    """Returns the first true value in the iterable.
    If no true value is found, returns *default*
    If *pred* is not None, returns the first item
    for which pred(item) is true.
    """
    # first_true([a,b,c], x) --> a or b or c or x
    # first_true([a,b], x, f) --> a if f(a) else b if f(b) else x
    return next(filter(pred, iterable), default)


_ExceptionOrExceptionTuple: TypeAlias = Union[Type[BaseException], Tuple[Type[BaseException], ...]]


@overload
def iter_except(func: Callable[[], _T], exception: _ExceptionOrExceptionTuple, first: None = None) -> Iterator[_T]:
    ...


@overload
def iter_except(func: Callable[[], _T], exception: _ExceptionOrExceptionTuple, first: Callable[[], _T1]) -> Iterator[_T | _T1]:
    ...


def iter_except(
    func: Callable[[], object], exception: _ExceptionOrExceptionTuple, first: Callable[[], object] | None = None
) -> Iterator[object]:
    """Call a function repeatedly until an exception is raised.
    Converts a call-until-exception interface to an iterator interface.
    Like builtins.iter(func, sentinel) but uses an exception instead
    of a sentinel to end the loop.
    Examples:
        iter_except(functools.partial(heappop, h), IndexError)  # priority queue iterator
        iter_except(d.popitem, KeyError)  # non-blocking dict iterator
        iter_except(d.popleft, IndexError)  # non-blocking deque iterator
        iter_except(q.get_nowait, Queue.Empty)  # loop over a producer Queue
        iter_except(s.pop, KeyError)  # non-blocking set iterator
    """
    try:
        if first is not None:
            yield first()  # For database APIs needing an initial cast to db.first()
        while True:
            yield func()
    except exception:
        pass


def sliding_window(iterable: Iterable[_T], n: int) -> Iterator[tuple[_T, ...]]:
    # sliding_window('ABCDEFG', 4) --> ABCD BCDE CDEF DEFG
    it = iter(iterable)
    window = collections.deque(islice(it, n - 1), maxlen=n)
    for x in it:
        window.append(x)
        yield tuple(window)


def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]:
    "roundrobin('ABC', 'D', 'EF') --> A D E B F C"
    # Recipe credited to George Sakkis
    num_active = len(iterables)
    nexts: Iterator[Callable[[], _T]] = cycle(iter(it).__next__ for it in iterables)
    while num_active:
        try:
            for next in nexts:
                yield next()
        except StopIteration:
            # Remove the iterator we just exhausted from the cycle.
            num_active -= 1
            nexts = cycle(islice(nexts, num_active))


def partition(pred: Callable[[_T], bool], iterable: Iterable[_T]) -> tuple[Iterator[_T], Iterator[_T]]:
    """Partition entries into false entries and true entries.
    If *pred* is slow, consider wrapping it with functools.lru_cache().
    """
    # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
    t1, t2 = tee(iterable)
    return filterfalse(pred, t1), filter(pred, t2)
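
# Illustrative example (added): evens, odds = partition(lambda x: x % 2, range(10))
# list(evens) --> [0, 2, 4, 6, 8]; list(odds) --> [1, 3, 5, 7, 9]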


def subslices(seq: Sequence[_T]) -> Iterator[Sequence[_T]]:
    "Return all contiguous non-empty subslices of a sequence"
    # subslices('ABCD') --> A AB ABC ABCD B BC BCD C CD D
    slices = starmap(slice, combinations(range(len(seq) + 1), 2))
    return map(operator.getitem, repeat(seq), slices)


def before_and_after(predicate: Callable[[_T], bool], it: Iterable[_T]) -> tuple[Iterator[_T], Iterator[_T]]:
    """Variant of takewhile() that allows complete
    access to the remainder of the iterator.
    >>> it = iter('ABCdEfGhI')
    >>> all_upper, remainder = before_and_after(str.isupper, it)
    >>> ''.join(all_upper)
    'ABC'
    >>> ''.join(remainder)  # takewhile() would lose the 'd'
    'dEfGhI'
    Note that the first iterator must be fully
    consumed before the second iterator can
    generate valid results.
    """
    it = iter(it)
    transition: list[_T] = []

    def true_iterator() -> Iterator[_T]:
        for elem in it:
            if predicate(elem):
                yield elem
            else:
                transition.append(elem)
                return

    def remainder_iterator() -> Iterator[_T]:
        yield from transition
        yield from it

    return true_iterator(), remainder_iterator()


@overload
def unique_everseen(iterable: Iterable[_HashableT], key: None = None) -> Iterator[_HashableT]:
    ...


@overload
def unique_everseen(iterable: Iterable[_T], key: Callable[[_T], Hashable]) -> Iterator[_T]:
    ...


def unique_everseen(iterable: Iterable[_T], key: Callable[[_T], Hashable] | None = None) -> Iterator[_T]:
    "List unique elements, preserving order. Remember all elements ever seen."
    # unique_everseen('AAAABBBCCDAABBB') --> A B C D
    # unique_everseen('ABBcCAD', str.lower) --> A B c D
    seen: set[Hashable] = set()
    if key is None:
        for element in filterfalse(seen.__contains__, iterable):
            seen.add(element)
            yield element
        # For order preserving deduplication,
        # a faster but non-lazy solution is:
        #     yield from dict.fromkeys(iterable)
    else:
        for element in iterable:
            k = key(element)
            if k not in seen:
                seen.add(k)
                yield element
        # For use cases that allow the last matching element to be returned,
        # a faster but non-lazy solution is:
        #     t1, t2 = tee(iterable)
        #     yield from dict(zip(map(key, t1), t2)).values()


def powerset(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]:
    "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))


def polynomial_derivative(coefficients: Sequence[float]) -> list[float]:
    """Compute the first derivative of a polynomial.
    f(x) = x³ -4x² -17x + 60
    f'(x) = 3x² -8x -17
    """
    # polynomial_derivative([1, -4, -17, 60]) -> [3, -8, -17]
    n = len(coefficients)
    powers = reversed(range(1, n))
    return list(map(operator.mul, coefficients, powers))


if sys.version_info >= (3, 10):

    @overload
    def grouper(
        iterable: Iterable[_T], n: int, *, incomplete: Literal["fill"] = "fill", fillvalue: None = None
    ) -> Iterator[tuple[_T | None, ...]]:
        ...

    @overload
    def grouper(
        iterable: Iterable[_T], n: int, *, incomplete: Literal["fill"] = "fill", fillvalue: _T1
    ) -> Iterator[tuple[_T | _T1, ...]]:
        ...

    @overload
    def grouper(
        iterable: Iterable[_T], n: int, *, incomplete: Literal["strict", "ignore"], fillvalue: None = None
    ) -> Iterator[tuple[_T, ...]]:
        ...

    def grouper(
        iterable: Iterable[object], n: int, *, incomplete: Literal["fill", "strict", "ignore"] = "fill", fillvalue: object = None
    ) -> Iterator[tuple[object, ...]]:
        "Collect data into non-overlapping fixed-length chunks or blocks"
        # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
        # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
        # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
        args = [iter(iterable)] * n
        if incomplete == "fill":
            return zip_longest(*args, fillvalue=fillvalue)
        if incomplete == "strict":
            return zip(*args, strict=True)
        if incomplete == "ignore":
            return zip(*args)
        else:
            raise ValueError("Expected fill, strict, or ignore")

    def transpose(it: Iterable[Iterable[_T]]) -> Iterator[tuple[_T, ...]]:
        "Swap the rows and columns of the input."
        # transpose([(1, 2, 3), (11, 22, 33)]) --> (1, 11) (2, 22) (3, 33)
        return zip(*it, strict=True)


if sys.version_info >= (3, 12):

    def sum_of_squares(it: Iterable[float]) -> float:
        "Add up the squares of the input values."
        # sum_of_squares([10, 20, 30]) -> 1400
        return math.sumprod(*tee(it))

    def convolve(signal: Iterable[float], kernel: Iterable[float]) -> Iterator[float]:
        """Discrete linear convolution of two iterables.
        The kernel is fully consumed before the calculations begin.
        The signal is consumed lazily and can be infinite.
        Convolutions are mathematically commutative.
        If the signal and kernel are swapped,
        the output will be the same.
        Article: https://betterexplained.com/articles/intuitive-convolution/
        Video: https://www.youtube.com/watch?v=KuXjwB4LzSA
        """
        # convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur)
        # convolve(data, [1/2, 0, -1/2]) --> 1st derivative estimate
        # convolve(data, [1, -2, 1]) --> 2nd derivative estimate
        kernel = tuple(kernel)[::-1]
        n = len(kernel)
        padded_signal = chain(repeat(0, n - 1), signal, repeat(0, n - 1))
        windowed_signal = sliding_window(padded_signal, n)
        return map(math.sumprod, repeat(kernel), windowed_signal)

    def polynomial_eval(coefficients: Sequence[float], x: float) -> float:
        """Evaluate a polynomial at a specific value.
        Computes with better numeric stability than Horner's method.
        """
        # Evaluate x³ -4x² -17x + 60 at x = 2.5
        # polynomial_eval([1, -4, -17, 60], x=2.5) --> 8.125
        n = len(coefficients)
        if not n:
            return type(x)(0)
        powers = map(pow, repeat(x), reversed(range(n)))
        return math.sumprod(coefficients, powers)
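

# Illustrative smoke test (an addition, not one of the documented recipes):
# exercises a few of the functions above at runtime, using only the recipes
# that are defined unconditionally on all supported Python versions.
if __name__ == "__main__":
    assert take(3, "ABCDEFG") == ["A", "B", "C"]
    assert list(prepend(1, [2, 3, 4])) == [1, 2, 3, 4]
    assert list(ncycles("AB", 2)) == ["A", "B", "A", "B"]
    assert nth("ABCDE", 2) == "C"
    assert quantify([0, 1, 2, 0]) == 2
    assert list(sliding_window("ABCDE", 3)) == [("A", "B", "C"), ("B", "C", "D"), ("C", "D", "E")]
    assert list(roundrobin("ABC", "D", "EF")) == ["A", "D", "E", "B", "F", "C"]
    assert polynomial_derivative([1, -4, -17, 60]) == [3, -8, -17]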