Add SupportsRichComparison type to _typeshed (#6583)

Use it to improve types of `max()` and other functions.

Also make some other tweaks to types related to comparison dunders.

Fixes #6575
This commit is contained in:
Alex Waygood
2021-12-14 14:12:23 +00:00
committed by GitHub
parent 968fd6d01d
commit 5670ca2f75
9 changed files with 73 additions and 69 deletions

View File

@@ -1,21 +1,31 @@
import sys
from _typeshed import SupportsLessThan
from _typeshed import SupportsRichComparison
from typing import Callable, MutableSequence, Sequence, TypeVar
_T = TypeVar("_T")
if sys.version_info >= (3, 10):
def bisect_left(
a: Sequence[_T], x: _T, lo: int = ..., hi: int | None = ..., *, key: Callable[[_T], SupportsLessThan] | None = ...
a: Sequence[_T], x: _T, lo: int = ..., hi: int | None = ..., *, key: Callable[[_T], SupportsRichComparison] | None = ...
) -> int: ...
def bisect_right(
a: Sequence[_T], x: _T, lo: int = ..., hi: int | None = ..., *, key: Callable[[_T], SupportsLessThan] | None = ...
a: Sequence[_T], x: _T, lo: int = ..., hi: int | None = ..., *, key: Callable[[_T], SupportsRichComparison] | None = ...
) -> int: ...
def insort_left(
a: MutableSequence[_T], x: _T, lo: int = ..., hi: int | None = ..., *, key: Callable[[_T], SupportsLessThan] | None = ...
a: MutableSequence[_T],
x: _T,
lo: int = ...,
hi: int | None = ...,
*,
key: Callable[[_T], SupportsRichComparison] | None = ...,
) -> None: ...
def insort_right(
a: MutableSequence[_T], x: _T, lo: int = ..., hi: int | None = ..., *, key: Callable[[_T], SupportsLessThan] | None = ...
a: MutableSequence[_T],
x: _T,
lo: int = ...,
hi: int | None = ...,
*,
key: Callable[[_T], SupportsRichComparison] | None = ...,
) -> None: ...
else:

View File

@@ -1,4 +1,5 @@
import sys
from _typeshed import SupportsAnyComparison
from typing import (
Any,
AnyStr,
@@ -14,7 +15,6 @@ from typing import (
SupportsAbs,
Tuple,
TypeVar,
Union,
overload,
)
from typing_extensions import ParamSpec, SupportsIndex, final
@@ -35,29 +35,13 @@ class _SupportsNeg(Protocol[_T_co]):
class _SupportsPos(Protocol[_T_co]):
def __pos__(self) -> _T_co: ...
# Different to _typeshed.SupportsLessThan
class _SupportsLT(Protocol):
def __lt__(self, __other: Any) -> Any: ...
class _SupportsGT(Protocol):
def __gt__(self, __other: Any) -> Any: ...
class _SupportsLE(Protocol):
def __le__(self, __other: Any) -> Any: ...
class _SupportsGE(Protocol):
def __ge__(self, __other: Any) -> Any: ...
# We get false-positive errors if e.g. `lt` does not have the same signature as `le`,
# so a broad union type is required for all four comparison methods
_SupportsComparison = Union[_SupportsLE, _SupportsGE, _SupportsGT, _SupportsLT]
def lt(__a: _SupportsComparison, __b: _SupportsComparison) -> Any: ...
def le(__a: _SupportsComparison, __b: _SupportsComparison) -> Any: ...
# All four comparison functions must have the same signature, or we get false-positive errors
def lt(__a: SupportsAnyComparison, __b: SupportsAnyComparison) -> Any: ...
def le(__a: SupportsAnyComparison, __b: SupportsAnyComparison) -> Any: ...
def eq(__a: object, __b: object) -> Any: ...
def ne(__a: object, __b: object) -> Any: ...
def ge(__a: _SupportsComparison, __b: _SupportsComparison) -> Any: ...
def gt(__a: _SupportsComparison, __b: _SupportsComparison) -> Any: ...
def ge(__a: SupportsAnyComparison, __b: SupportsAnyComparison) -> Any: ...
def gt(__a: SupportsAnyComparison, __b: SupportsAnyComparison) -> Any: ...
def not_(__a: object) -> bool: ...
def truth(__a: object) -> bool: ...
def is_(__a: object, __b: object) -> bool: ...

View File

@@ -35,15 +35,25 @@ class SupportsNext(Protocol[_T_co]):
class SupportsAnext(Protocol[_T_co]):
def __anext__(self) -> Awaitable[_T_co]: ...
class SupportsLessThan(Protocol):
def __lt__(self, __other: Any) -> bool: ...
# Comparison protocols
SupportsLessThanT = TypeVar("SupportsLessThanT", bound=SupportsLessThan) # noqa: Y001
class SupportsDunderLT(Protocol):
def __lt__(self, __other: Any) -> Any: ...
class SupportsGreaterThan(Protocol):
def __gt__(self, __other: Any) -> bool: ...
class SupportsDunderGT(Protocol):
def __gt__(self, __other: Any) -> Any: ...
SupportsGreaterThanT = TypeVar("SupportsGreaterThanT", bound=SupportsGreaterThan) # noqa: Y001
class SupportsDunderLE(Protocol):
def __le__(self, __other: Any) -> Any: ...
class SupportsDunderGE(Protocol):
def __ge__(self, __other: Any) -> Any: ...
class SupportsAllComparisons(SupportsDunderLT, SupportsDunderGT, SupportsDunderLE, SupportsDunderGE, Protocol): ...
SupportsRichComparison = Union[SupportsDunderLT, SupportsDunderGT]
SupportsRichComparisonT = TypeVar("SupportsRichComparisonT", bound=SupportsRichComparison) # noqa: Y001
SupportsAnyComparison = Union[SupportsDunderLE, SupportsDunderGE, SupportsDunderGT, SupportsDunderLT]
class SupportsDivMod(Protocol[_T_contra, _T_co]):
def __divmod__(self, __other: _T_contra) -> _T_co: ...

View File

@@ -13,14 +13,12 @@ from _typeshed import (
StrOrBytesPath,
SupportsAnext,
SupportsDivMod,
SupportsGreaterThan,
SupportsGreaterThanT,
SupportsKeysAndGetItem,
SupportsLenAndGetItem,
SupportsLessThan,
SupportsLessThanT,
SupportsNext,
SupportsRDivMod,
SupportsRichComparison,
SupportsRichComparisonT,
SupportsTrunc,
SupportsWrite,
)
@@ -783,9 +781,9 @@ class list(MutableSequence[_T], Generic[_T]):
def remove(self, __value: _T) -> None: ...
# Signature of `list.sort` should be kept inline with `collections.UserList.sort()`
@overload
def sort(self: list[SupportsLessThanT], *, key: None = ..., reverse: bool = ...) -> None: ...
def sort(self: list[SupportsRichComparisonT], *, key: None = ..., reverse: bool = ...) -> None: ...
@overload
def sort(self, *, key: Callable[[_T], SupportsLessThan], reverse: bool = ...) -> None: ...
def sort(self, *, key: Callable[[_T], SupportsRichComparison], reverse: bool = ...) -> None: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T]: ...
def __str__(self) -> str: ...
@@ -1143,32 +1141,32 @@ class map(Iterator[_S], Generic[_S]):
@overload
def max(
__arg1: SupportsGreaterThanT, __arg2: SupportsGreaterThanT, *_args: SupportsGreaterThanT, key: None = ...
) -> SupportsGreaterThanT: ...
__arg1: SupportsRichComparisonT, __arg2: SupportsRichComparisonT, *_args: SupportsRichComparisonT, key: None = ...
) -> SupportsRichComparisonT: ...
@overload
def max(__arg1: _T, __arg2: _T, *_args: _T, key: Callable[[_T], SupportsGreaterThan]) -> _T: ...
def max(__arg1: _T, __arg2: _T, *_args: _T, key: Callable[[_T], SupportsRichComparisonT]) -> _T: ...
@overload
def max(__iterable: Iterable[SupportsGreaterThanT], *, key: None = ...) -> SupportsGreaterThanT: ...
def max(__iterable: Iterable[SupportsRichComparisonT], *, key: None = ...) -> SupportsRichComparisonT: ...
@overload
def max(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsGreaterThan]) -> _T: ...
def max(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsRichComparison]) -> _T: ...
@overload
def max(__iterable: Iterable[SupportsGreaterThanT], *, key: None = ..., default: _T) -> SupportsGreaterThanT | _T: ...
def max(__iterable: Iterable[SupportsRichComparisonT], *, key: None = ..., default: _T) -> SupportsRichComparisonT | _T: ...
@overload
def max(__iterable: Iterable[_T1], *, key: Callable[[_T1], SupportsGreaterThan], default: _T2) -> _T1 | _T2: ...
def max(__iterable: Iterable[_T1], *, key: Callable[[_T1], SupportsRichComparison], default: _T2) -> _T1 | _T2: ...
@overload
def min(
__arg1: SupportsLessThanT, __arg2: SupportsLessThanT, *_args: SupportsLessThanT, key: None = ...
) -> SupportsLessThanT: ...
__arg1: SupportsRichComparisonT, __arg2: SupportsRichComparisonT, *_args: SupportsRichComparisonT, key: None = ...
) -> SupportsRichComparisonT: ...
@overload
def min(__arg1: _T, __arg2: _T, *_args: _T, key: Callable[[_T], SupportsLessThan]) -> _T: ...
def min(__arg1: _T, __arg2: _T, *_args: _T, key: Callable[[_T], SupportsRichComparison]) -> _T: ...
@overload
def min(__iterable: Iterable[SupportsLessThanT], *, key: None = ...) -> SupportsLessThanT: ...
def min(__iterable: Iterable[SupportsRichComparisonT], *, key: None = ...) -> SupportsRichComparisonT: ...
@overload
def min(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsLessThan]) -> _T: ...
def min(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsRichComparison]) -> _T: ...
@overload
def min(__iterable: Iterable[SupportsLessThanT], *, key: None = ..., default: _T) -> SupportsLessThanT | _T: ...
def min(__iterable: Iterable[SupportsRichComparisonT], *, key: None = ..., default: _T) -> SupportsRichComparisonT | _T: ...
@overload
def min(__iterable: Iterable[_T1], *, key: Callable[[_T1], SupportsLessThan], default: _T2) -> _T1 | _T2: ...
def min(__iterable: Iterable[_T1], *, key: Callable[[_T1], SupportsRichComparison], default: _T2) -> _T1 | _T2: ...
@overload
def next(__i: SupportsNext[_T]) -> _T: ...
@overload
@@ -1379,9 +1377,11 @@ def round(number: SupportsRound[_T], ndigits: SupportsIndex) -> _T: ...
# for why arg 3 of `setattr` should be annotated with `Any` and not `object`
def setattr(__obj: object, __name: str, __value: Any) -> None: ...
@overload
def sorted(__iterable: Iterable[SupportsLessThanT], *, key: None = ..., reverse: bool = ...) -> list[SupportsLessThanT]: ...
def sorted(
__iterable: Iterable[SupportsRichComparisonT], *, key: None = ..., reverse: bool = ...
) -> list[SupportsRichComparisonT]: ...
@overload
def sorted(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsLessThan], reverse: bool = ...) -> list[_T]: ...
def sorted(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsRichComparison], reverse: bool = ...) -> list[_T]: ...
if sys.version_info >= (3, 8):
@overload

View File

@@ -1,6 +1,6 @@
import sys
from _collections_abc import dict_items, dict_keys, dict_values
from _typeshed import Self, SupportsKeysAndGetItem, SupportsLessThan, SupportsLessThanT
from _typeshed import Self, SupportsKeysAndGetItem, SupportsRichComparison, SupportsRichComparisonT
from typing import Any, Dict, Generic, NoReturn, Tuple, Type, TypeVar, overload
from typing_extensions import SupportsIndex, final
@@ -99,9 +99,9 @@ class UserList(MutableSequence[_T]):
def reverse(self) -> None: ...
# All arguments are passed to `list.sort` at runtime, so the signature should be kept in line with `list.sort`.
@overload
def sort(self: UserList[SupportsLessThanT], *, key: None = ..., reverse: bool = ...) -> None: ...
def sort(self: UserList[SupportsRichComparisonT], *, key: None = ..., reverse: bool = ...) -> None: ...
@overload
def sort(self, *, key: Callable[[_T], SupportsLessThan], reverse: bool = ...) -> None: ...
def sort(self, *, key: Callable[[_T], SupportsRichComparison], reverse: bool = ...) -> None: ...
def extend(self, other: Iterable[_T]) -> None: ...
_UserStringT = TypeVar("_UserStringT", bound=UserString)

View File

@@ -1,6 +1,6 @@
import sys
import types
from _typeshed import SupportsItems, SupportsLessThan
from _typeshed import SupportsAllComparisons, SupportsItems
from typing import Any, Callable, Generic, Hashable, Iterable, NamedTuple, Sequence, Sized, Tuple, Type, TypeVar, overload
from typing_extensions import final
@@ -45,7 +45,7 @@ WRAPPER_UPDATES: Sequence[str]
def update_wrapper(wrapper: _T, wrapped: _AnyCallable, assigned: Sequence[str] = ..., updated: Sequence[str] = ...) -> _T: ...
def wraps(wrapped: _AnyCallable, assigned: Sequence[str] = ..., updated: Sequence[str] = ...) -> Callable[[_T], _T]: ...
def total_ordering(cls: Type[_T]) -> Type[_T]: ...
def cmp_to_key(mycmp: Callable[[_T, _T], int]) -> Callable[[_T], SupportsLessThan]: ...
def cmp_to_key(mycmp: Callable[[_T, _T], int]) -> Callable[[_T], SupportsAllComparisons]: ...
class partial(Generic[_T]):
func: Callable[..., _T]

View File

@@ -1,5 +1,5 @@
import os
from _typeshed import BytesPath, StrOrBytesPath, StrPath, SupportsLessThanT
from _typeshed import BytesPath, StrOrBytesPath, StrPath, SupportsRichComparisonT
from typing import Sequence, Tuple, overload
from typing_extensions import Literal
@@ -11,9 +11,9 @@ def commonprefix(m: Sequence[StrPath]) -> str: ...
@overload
def commonprefix(m: Sequence[BytesPath]) -> bytes | Literal[""]: ...
@overload
def commonprefix(m: Sequence[list[SupportsLessThanT]]) -> Sequence[SupportsLessThanT]: ...
def commonprefix(m: Sequence[list[SupportsRichComparisonT]]) -> Sequence[SupportsRichComparisonT]: ...
@overload
def commonprefix(m: Sequence[Tuple[SupportsLessThanT, ...]]) -> Sequence[SupportsLessThanT]: ...
def commonprefix(m: Sequence[Tuple[SupportsRichComparisonT, ...]]) -> Sequence[SupportsRichComparisonT]: ...
def exists(path: StrOrBytesPath | int) -> bool: ...
def getsize(filename: StrOrBytesPath | int) -> int: ...
def isfile(path: StrOrBytesPath | int) -> bool: ...

View File

@@ -1,4 +1,4 @@
from _typeshed import SupportsLessThan
from _typeshed import SupportsRichComparison
from typing import Any, Callable, Iterable, TypeVar
_T = TypeVar("_T")
@@ -9,6 +9,6 @@ def heappushpop(__heap: list[_T], __item: _T) -> _T: ...
def heapify(__heap: list[Any]) -> None: ...
def heapreplace(__heap: list[_T], __item: _T) -> _T: ...
def merge(*iterables: Iterable[_T], key: Callable[[_T], Any] | None = ..., reverse: bool = ...) -> Iterable[_T]: ...
def nlargest(n: int, iterable: Iterable[_T], key: Callable[[_T], SupportsLessThan] | None = ...) -> list[_T]: ...
def nsmallest(n: int, iterable: Iterable[_T], key: Callable[[_T], SupportsLessThan] | None = ...) -> list[_T]: ...
def nlargest(n: int, iterable: Iterable[_T], key: Callable[[_T], SupportsRichComparison] | None = ...) -> list[_T]: ...
def nsmallest(n: int, iterable: Iterable[_T], key: Callable[[_T], SupportsRichComparison] | None = ...) -> list[_T]: ...
def _heapify_max(__x: list[Any]) -> None: ... # undocumented

View File

@@ -1,5 +1,5 @@
import sys
from _typeshed import SupportsLessThanT
from _typeshed import SupportsRichComparisonT
from decimal import Decimal
from fractions import Fraction
from typing import Any, Hashable, Iterable, NamedTuple, Sequence, SupportsFloat, Type, TypeVar, Union
@@ -27,8 +27,8 @@ else:
def harmonic_mean(data: Iterable[_NumberT]) -> _NumberT: ...
def median(data: Iterable[_NumberT]) -> _NumberT: ...
def median_low(data: Iterable[SupportsLessThanT]) -> SupportsLessThanT: ...
def median_high(data: Iterable[SupportsLessThanT]) -> SupportsLessThanT: ...
def median_low(data: Iterable[SupportsRichComparisonT]) -> SupportsRichComparisonT: ...
def median_high(data: Iterable[SupportsRichComparisonT]) -> SupportsRichComparisonT: ...
def median_grouped(data: Iterable[_NumberT], interval: _NumberT = ...) -> _NumberT: ...
def mode(data: Iterable[_HashableT]) -> _HashableT: ...