From 275fc28733f8e74a55e70f279234119ab50a7c17 Mon Sep 17 00:00:00 2001 From: Nikita Sobolev Date: Sat, 30 Jul 2022 14:50:39 +0300 Subject: [PATCH] Improve `redis.cluster` annotations (#8379) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alex Waygood --- stubs/redis/redis/cluster.pyi | 74 ++++++++++++++++++----------------- 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/stubs/redis/redis/cluster.pyi b/stubs/redis/redis/cluster.pyi index 2e09af694..c32379de6 100644 --- a/stubs/redis/redis/cluster.pyi +++ b/stubs/redis/redis/cluster.pyi @@ -1,16 +1,18 @@ -from _typeshed import Incomplete +from _typeshed import Incomplete, Self from collections.abc import Callable from threading import Lock +from types import TracebackType from typing import Any, ClassVar, Generic +from typing_extensions import Literal -from redis.client import PubSub -from redis.commands import RedisClusterCommands +from redis.client import CaseInsensitiveDict, PubSub, Redis +from redis.commands import CommandsParser, RedisClusterCommands from redis.commands.core import _StrType -from redis.connection import BaseParser, Encoder +from redis.connection import BaseParser, Connection, Encoder from redis.exceptions import RedisError -def get_node_name(host, port): ... -def get_connection(redis_node, *args, **options): ... +def get_node_name(host: str, port: str | int) -> str: ... +def get_connection(redis_node: Redis[Any], *args, **options) -> Connection: ... def parse_scan_result(command, res, **options): ... def parse_pubsub_numsub(command, res, **options): ... def parse_cluster_slots(resp, **options): ... @@ -18,16 +20,15 @@ def parse_cluster_slots(resp, **options): ... PRIMARY: str REPLICA: str SLOT_ID: str -REDIS_ALLOWED_KEYS: Any -KWARGS_DISABLED_KEYS: Any -READ_COMMANDS: Any +REDIS_ALLOWED_KEYS: tuple[str, ...] +KWARGS_DISABLED_KEYS: tuple[str, ...] def cleanup_kwargs(**kwargs): ... # It uses `DefaultParser` in real life, but it is a dynamic base class. class ClusterParser(BaseParser): ... -class RedisCluster(RedisClusterCommands[_StrType], Generic[_StrType]): +class AbstractRedisCluster: RedisClusterRequestTTL: ClassVar[int] PRIMARIES: ClassVar[str] REPLICAS: ClassVar[str] @@ -35,41 +36,45 @@ class RedisCluster(RedisClusterCommands[_StrType], Generic[_StrType]): RANDOM: ClassVar[str] DEFAULT_NODE: ClassVar[str] NODE_FLAGS: ClassVar[set[str]] - COMMAND_FLAGS: ClassVar[Any] + COMMAND_FLAGS: ClassVar[dict[str, str]] CLUSTER_COMMANDS_RESPONSE_CALLBACKS: ClassVar[dict[str, Any]] - RESULT_CALLBACKS: ClassVar[Any] + RESULT_CALLBACKS: ClassVar[dict[str, Callable[[Incomplete, Incomplete], Incomplete]]] ERRORS_ALLOW_RETRY: ClassVar[tuple[type[RedisError], ...]] - user_on_connect_func: Any - encoder: Any - cluster_error_retry_attempts: Any - command_flags: Any - node_flags: Any - read_from_replicas: Any + +class RedisCluster(AbstractRedisCluster, RedisClusterCommands[_StrType], Generic[_StrType]): + user_on_connect_func: Callable[[Connection], object] | None + encoder: Encoder + cluster_error_retry_attempts: int + command_flags: dict[str, str] + node_flags: set[str] + read_from_replicas: bool reinitialize_counter: int reinitialize_steps: int - nodes_manager: Any - cluster_response_callbacks: Any - result_callbacks: Any - commands_parser: Any - def __init__( + nodes_manager: NodesManager + cluster_response_callbacks: CaseInsensitiveDict[str, Callable[..., Incomplete]] + result_callbacks: CaseInsensitiveDict[str, Callable[[Incomplete, Incomplete], Incomplete]] + commands_parser: CommandsParser + def __init__( # TODO: make @overloads, either `url` or `host:port` can be passed self, - host: Incomplete | None = ..., - port: int = ..., - startup_nodes: Incomplete | None = ..., + host: str | None = ..., + port: int | None = ..., + startup_nodes: list[ClusterNode] | None = ..., cluster_error_retry_attempts: int = ..., require_full_coverage: bool = ..., reinitialize_steps: int = ..., read_from_replicas: bool = ..., dynamic_startup_nodes: bool = ..., - url: Incomplete | None = ..., + url: str | None = ..., **kwargs, ) -> None: ... - def __enter__(self): ... - def __exit__(self, exc_type, exc_value, traceback) -> None: ... + def __enter__(self: Self) -> Self: ... + def __exit__( + self, type: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None + ) -> None: ... def __del__(self) -> None: ... def disconnect_connection_pools(self) -> None: ... @classmethod - def from_url(cls, url, **kwargs): ... + def from_url(cls: type[Self], url: str, **kwargs) -> Self: ... def on_connect(self, connection) -> None: ... def get_redis_connection(self, node): ... def get_node(self, host: Any | None = ..., port: Any | None = ..., node_name: Any | None = ...): ... @@ -184,12 +189,9 @@ class ClusterPipeline(RedisCluster[_StrType], Generic[_StrType]): lock: Lock | None = ..., **kwargs, ) -> None: ... - def __enter__(self): ... - def __exit__(self, exc_type, exc_value, traceback) -> None: ... - def __del__(self) -> None: ... - def __len__(self): ... - def __nonzero__(self): ... - def __bool__(self): ... + def __len__(self) -> int: ... + def __nonzero__(self) -> Literal[True]: ... + def __bool__(self) -> Literal[True]: ... def execute_command(self, *args, **kwargs): ... def pipeline_execute_command(self, *args, **options): ... def raise_first_error(self, stack) -> None: ...