Make multiprocessing pipes generic (#11137)

This commit is contained in:
Avasam
2024-10-01 21:11:42 -04:00
committed by GitHub
parent 44aa63330b
commit bdb5b52d50
6 changed files with 72 additions and 27 deletions

View File

@@ -19,8 +19,9 @@ _global_shutdown: bool
class _ThreadWakeup:
_closed: bool
_reader: Connection
_writer: Connection
# Any: Unused send and recv methods
_reader: Connection[Any, Any]
_writer: Connection[Any, Any]
def close(self) -> None: ...
def wakeup(self) -> None: ...
def clear(self) -> None: ...

View File

@@ -1,9 +1,9 @@
import socket
import sys
import types
from _typeshed import ReadableBuffer
from _typeshed import Incomplete, ReadableBuffer
from collections.abc import Iterable
from typing import Any, SupportsIndex
from types import TracebackType
from typing import Any, Generic, SupportsIndex, TypeVar
from typing_extensions import Self, TypeAlias
__all__ = ["Client", "Listener", "Pipe", "wait"]
@@ -11,7 +11,11 @@ __all__ = ["Client", "Listener", "Pipe", "wait"]
# https://docs.python.org/3/library/multiprocessing.html#address-formats
_Address: TypeAlias = str | tuple[str, int]
class _ConnectionBase:
# Defaulting to Any to avoid forcing generics on a lot of pre-existing code
_SendT = TypeVar("_SendT", contravariant=True, default=Any)
_RecvT = TypeVar("_RecvT", covariant=True, default=Any)
class _ConnectionBase(Generic[_SendT, _RecvT]):
def __init__(self, handle: SupportsIndex, readable: bool = True, writable: bool = True) -> None: ...
@property
def closed(self) -> bool: ... # undocumented
@@ -22,27 +26,27 @@ class _ConnectionBase:
def fileno(self) -> int: ...
def close(self) -> None: ...
def send_bytes(self, buf: ReadableBuffer, offset: int = 0, size: int | None = None) -> None: ...
def send(self, obj: Any) -> None: ...
def send(self, obj: _SendT) -> None: ...
def recv_bytes(self, maxlength: int | None = None) -> bytes: ...
def recv_bytes_into(self, buf: Any, offset: int = 0) -> int: ...
def recv(self) -> Any: ...
def recv(self) -> _RecvT: ...
def poll(self, timeout: float | None = 0.0) -> bool: ...
def __enter__(self) -> Self: ...
def __exit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_tb: types.TracebackType | None
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_tb: TracebackType | None
) -> None: ...
def __del__(self) -> None: ...
class Connection(_ConnectionBase): ...
class Connection(_ConnectionBase[_SendT, _RecvT]): ...
if sys.platform == "win32":
class PipeConnection(_ConnectionBase): ...
class PipeConnection(_ConnectionBase[_SendT, _RecvT]): ...
class Listener:
def __init__(
self, address: _Address | None = None, family: str | None = None, backlog: int = 1, authkey: bytes | None = None
) -> None: ...
def accept(self) -> Connection: ...
def accept(self) -> Connection[Incomplete, Incomplete]: ...
def close(self) -> None: ...
@property
def address(self) -> _Address: ...
@@ -50,26 +54,30 @@ class Listener:
def last_accepted(self) -> _Address | None: ...
def __enter__(self) -> Self: ...
def __exit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_tb: types.TracebackType | None
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_tb: TracebackType | None
) -> None: ...
# Any: send and recv methods unused
if sys.version_info >= (3, 12):
def deliver_challenge(connection: Connection, authkey: bytes, digest_name: str = "sha256") -> None: ...
def deliver_challenge(connection: Connection[Any, Any], authkey: bytes, digest_name: str = "sha256") -> None: ...
else:
def deliver_challenge(connection: Connection, authkey: bytes) -> None: ...
def deliver_challenge(connection: Connection[Any, Any], authkey: bytes) -> None: ...
def answer_challenge(connection: Connection, authkey: bytes) -> None: ...
def answer_challenge(connection: Connection[Any, Any], authkey: bytes) -> None: ...
def wait(
object_list: Iterable[Connection | socket.socket | int], timeout: float | None = None
) -> list[Connection | socket.socket | int]: ...
def Client(address: _Address, family: str | None = None, authkey: bytes | None = None) -> Connection: ...
object_list: Iterable[Connection[_SendT, _RecvT] | socket.socket | int], timeout: float | None = None
) -> list[Connection[_SendT, _RecvT] | socket.socket | int]: ...
def Client(address: _Address, family: str | None = None, authkey: bytes | None = None) -> Connection[Any, Any]: ...
# N.B. Keep this in sync with multiprocessing.context.BaseContext.Pipe.
# _ConnectionBase is the common base class of Connection and PipeConnection
# and can be used in cross-platform code.
#
# The two connections should have the same generic types but inverted (Connection[_T1, _T2], Connection[_T2, _T1]).
# However, TypeVars scoped entirely within a return annotation is unspecified in the spec.
if sys.platform != "win32":
def Pipe(duplex: bool = True) -> tuple[Connection, Connection]: ...
def Pipe(duplex: bool = True) -> tuple[Connection[Any, Any], Connection[Any, Any]]: ...
else:
def Pipe(duplex: bool = True) -> tuple[PipeConnection, PipeConnection]: ...
def Pipe(duplex: bool = True) -> tuple[PipeConnection[Any, Any], PipeConnection[Any, Any]]: ...

View File

@@ -47,10 +47,13 @@ class BaseContext:
# N.B. Keep this in sync with multiprocessing.connection.Pipe.
# _ConnectionBase is the common base class of Connection and PipeConnection
# and can be used in cross-platform code.
#
# The two connections should have the same generic types but inverted (Connection[_T1, _T2], Connection[_T2, _T1]).
# However, TypeVars scoped entirely within a return annotation is unspecified in the spec.
if sys.platform != "win32":
def Pipe(self, duplex: bool = True) -> tuple[Connection, Connection]: ...
def Pipe(self, duplex: bool = True) -> tuple[Connection[Any, Any], Connection[Any, Any]]: ...
else:
def Pipe(self, duplex: bool = True) -> tuple[PipeConnection, PipeConnection]: ...
def Pipe(self, duplex: bool = True) -> tuple[PipeConnection[Any, Any], PipeConnection[Any, Any]]: ...
def Barrier(
self, parties: int, action: Callable[..., object] | None = None, timeout: float | None = None

View File

@@ -1,7 +1,7 @@
import queue
import sys
import threading
from _typeshed import SupportsKeysAndGetItem, SupportsRichComparison, SupportsRichComparisonT
from _typeshed import Incomplete, SupportsKeysAndGetItem, SupportsRichComparison, SupportsRichComparisonT
from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping, MutableSequence, Sequence
from types import TracebackType
from typing import Any, AnyStr, ClassVar, Generic, SupportsIndex, TypeVar, overload
@@ -129,7 +129,9 @@ class Server:
self, registry: dict[str, tuple[Callable[..., Any], Any, Any, Any]], address: Any, authkey: bytes, serializer: str
) -> None: ...
def serve_forever(self) -> None: ...
def accept_connection(self, c: Connection, name: str) -> None: ...
def accept_connection(
self, c: Connection[tuple[str, str | None], tuple[str, str, Iterable[Incomplete], Mapping[str, Incomplete]]], name: str
) -> None: ...
class BaseManager:
if sys.version_info >= (3, 11):

View File

@@ -35,8 +35,8 @@ if sys.platform == "win32":
handle: int, target_process: int | None = None, inheritable: bool = False, *, source_process: int | None = None
) -> int: ...
def steal_handle(source_pid: int, handle: int) -> int: ...
def send_handle(conn: connection.PipeConnection, handle: int, destination_pid: int) -> None: ...
def recv_handle(conn: connection.PipeConnection) -> int: ...
def send_handle(conn: connection.PipeConnection[DupHandle, Any], handle: int, destination_pid: int) -> None: ...
def recv_handle(conn: connection.PipeConnection[Any, DupHandle]) -> int: ...
class DupHandle:
def __init__(self, handle: int, access: int, pid: int | None = None) -> None: ...

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
import sys
from multiprocessing.connection import Pipe
if sys.platform != "win32":
from multiprocessing.connection import Connection
else:
from multiprocessing.connection import PipeConnection as Connection
# Unfortunately, we cannot validate that both connections have the same, but inverted generic types,
# since TypeVars scoped entirely within a return annotation is unspecified in the spec.
# Pipe[str, int]() -> tuple[Connection[str, int], Connection[int, str]]
a: Connection[str, int]
b: Connection[int, str]
a, b = Pipe()
connections: tuple[Connection[str, int], Connection[int, str]] = Pipe()
a, b = connections
a.send("test")
a.send(0) # type: ignore
test1: str = b.recv()
test2: int = b.recv() # type: ignore
b.send("test") # type: ignore
b.send(0)
test3: str = a.recv() # type: ignore
test4: int = a.recv()