Improve multiprocessing stubs (#8202)

Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
This commit is contained in:
Alex Waygood
2022-07-01 19:20:39 +01:00
committed by GitHub
parent 7b54854c90
commit a2e8346d9a
8 changed files with 56 additions and 90 deletions

View File

@@ -1,18 +1,12 @@
import sys
from collections.abc import Callable, Iterable
from logging import Logger
from multiprocessing import connection, context, pool, reduction as reducer, synchronize
from multiprocessing import context, reduction as reducer, synchronize
from multiprocessing.context import (
AuthenticationError as AuthenticationError,
BaseContext,
BufferTooShort as BufferTooShort,
DefaultContext,
Process as Process,
ProcessError as ProcessError,
SpawnContext,
TimeoutError as TimeoutError,
)
from multiprocessing.managers import SyncManager
from multiprocessing.process import active_children as active_children, current_process as current_process
# These are technically functions that return instances of these Queue classes.
@@ -20,15 +14,12 @@ from multiprocessing.process import active_children as active_children, current_
# multiprocessing.queues or the aliases defined below. See #4266 for discussion.
from multiprocessing.queues import JoinableQueue as JoinableQueue, Queue as Queue, SimpleQueue as SimpleQueue
from multiprocessing.spawn import freeze_support as freeze_support
from typing import Any, TypeVar, overload
from typing_extensions import Literal, TypeAlias
from typing import TypeVar
from typing_extensions import TypeAlias
if sys.version_info >= (3, 8):
from multiprocessing.process import parent_process as parent_process
if sys.platform != "win32":
from multiprocessing.context import ForkContext, ForkServerContext
__all__ = [
"Array",
"AuthenticationError",
@@ -92,60 +83,29 @@ _LockType: TypeAlias = synchronize.Lock
_RLockType: TypeAlias = synchronize.RLock
_SemaphoreType: TypeAlias = synchronize.Semaphore
# N.B. The functions below are generated at runtime by partially applying
# multiprocessing.context.BaseContext's methods, so the two signatures should
# be identical (modulo self).
# Synchronization primitives
_LockLike: TypeAlias = synchronize.Lock | synchronize.RLock
# These functions (really bound methods)
# are all autogenerated at runtime here: https://github.com/python/cpython/blob/600c65c094b0b48704d8ec2416930648052ba715/Lib/multiprocessing/__init__.py#L23
RawValue = context._default_context.RawValue
RawArray = context._default_context.RawArray
Value = context._default_context.Value
Array = context._default_context.Array
def Barrier(parties: int, action: Callable[..., Any] | None = ..., timeout: float | None = ...) -> _BarrierType: ...
def BoundedSemaphore(value: int = ...) -> _BoundedSemaphoreType: ...
def Condition(lock: _LockLike | None = ...) -> _ConditionType: ...
def Event() -> _EventType: ...
def Lock() -> _LockType: ...
def RLock() -> _RLockType: ...
def Semaphore(value: int = ...) -> _SemaphoreType: ...
def Pipe(duplex: bool = ...) -> tuple[connection.Connection, connection.Connection]: ...
def Pool(
processes: int | None = ...,
initializer: Callable[..., Any] | None = ...,
initargs: Iterable[Any] = ...,
maxtasksperchild: int | None = ...,
) -> pool.Pool: ...
# ----- multiprocessing function stubs -----
def allow_connection_pickling() -> None: ...
def cpu_count() -> int: ...
def get_logger() -> Logger: ...
def log_to_stderr(level: str | int | None = ...) -> Logger: ...
def Manager() -> SyncManager: ...
def set_executable(executable: str) -> None: ...
def set_forkserver_preload(module_names: list[str]) -> None: ...
def get_all_start_methods() -> list[str]: ...
def get_start_method(allow_none: bool = ...) -> str | None: ...
def set_start_method(method: str, force: bool | None = ...) -> None: ...
if sys.platform != "win32":
@overload
def get_context(method: None = ...) -> DefaultContext: ...
@overload
def get_context(method: Literal["spawn"]) -> SpawnContext: ...
@overload
def get_context(method: Literal["fork"]) -> ForkContext: ...
@overload
def get_context(method: Literal["forkserver"]) -> ForkServerContext: ...
@overload
def get_context(method: str) -> BaseContext: ...
else:
@overload
def get_context(method: None = ...) -> DefaultContext: ...
@overload
def get_context(method: Literal["spawn"]) -> SpawnContext: ...
@overload
def get_context(method: str) -> BaseContext: ...
Barrier = context._default_context.Barrier
BoundedSemaphore = context._default_context.BoundedSemaphore
Condition = context._default_context.Condition
Event = context._default_context.Event
Lock = context._default_context.Lock
RLock = context._default_context.RLock
Semaphore = context._default_context.Semaphore
Pipe = context._default_context.Pipe
Pool = context._default_context.Pool
allow_connection_pickling = context._default_context.allow_connection_pickling
cpu_count = context._default_context.cpu_count
get_logger = context._default_context.get_logger
log_to_stderr = context._default_context.log_to_stderr
Manager = context._default_context.Manager
set_executable = context._default_context.set_executable
set_forkserver_preload = context._default_context.set_forkserver_preload
get_all_start_methods = context._default_context.get_all_start_methods
get_start_method = context._default_context.get_start_method
set_start_method = context._default_context.set_start_method
get_context = context._default_context.get_context

View File

@@ -58,4 +58,4 @@ def wait(
object_list: Iterable[Connection | socket.socket | int], timeout: float | None = ...
) -> list[Connection | socket.socket | int]: ...
def Client(address: _Address, family: str | None = ..., authkey: bytes | None = ...) -> Connection: ...
def Pipe(duplex: bool = ...) -> tuple[Connection, Connection]: ...
def Pipe(duplex: bool = ...) -> tuple[_ConnectionBase, _ConnectionBase]: ...

View File

@@ -5,6 +5,8 @@ from collections.abc import Callable, Iterable, Sequence
from ctypes import _CData
from logging import Logger
from multiprocessing import queues, synchronize
from multiprocessing.connection import _ConnectionBase
from multiprocessing.managers import SyncManager
from multiprocessing.pool import Pool as _Pool
from multiprocessing.process import BaseProcess
from multiprocessing.sharedctypes import SynchronizedArray, SynchronizedBase
@@ -42,12 +44,10 @@ class BaseContext:
@staticmethod
def active_children() -> list[BaseProcess]: ...
def cpu_count(self) -> int: ...
# TODO: change return to SyncManager once a stub exists in multiprocessing.managers
def Manager(self) -> Any: ...
# TODO: change return to Pipe once a stub exists in multiprocessing.connection
def Pipe(self, duplex: bool = ...) -> Any: ...
def Manager(self) -> SyncManager: ...
def Pipe(self, duplex: bool = ...) -> tuple[_ConnectionBase, _ConnectionBase]: ...
def Barrier(
self, parties: int, action: Callable[..., Any] | None = ..., timeout: float | None = ...
self, parties: int, action: Callable[..., object] | None = ..., timeout: float | None = ...
) -> synchronize.Barrier: ...
def BoundedSemaphore(self, value: int = ...) -> synchronize.BoundedSemaphore: ...
def Condition(self, lock: _LockLike | None = ...) -> synchronize.Condition: ...
@@ -61,7 +61,7 @@ class BaseContext:
def Pool(
self,
processes: int | None = ...,
initializer: Callable[..., Any] | None = ...,
initializer: Callable[..., object] | None = ...,
initargs: Iterable[Any] = ...,
maxtasksperchild: int | None = ...,
) -> _Pool: ...
@@ -120,7 +120,10 @@ class BaseContext:
@overload
def get_context(self, method: str) -> BaseContext: ...
def get_start_method(self, allow_none: bool = ...) -> str: ...
@overload
def get_start_method(self, allow_none: Literal[False] = ...) -> str: ...
@overload
def get_start_method(self, allow_none: bool) -> str | None: ...
def set_start_method(self, method: str | None, force: bool = ...) -> None: ...
@property
def reducer(self) -> str: ...

View File

@@ -3,6 +3,15 @@ import threading
import weakref
from collections.abc import Callable, Iterable, Mapping, Sequence
from queue import Queue as Queue
from threading import (
Barrier as Barrier,
BoundedSemaphore as BoundedSemaphore,
Condition as Condition,
Event as Event,
Lock as Lock,
RLock as RLock,
Semaphore as Semaphore,
)
from typing import Any
from typing_extensions import Literal
@@ -28,13 +37,6 @@ __all__ = [
]
JoinableQueue = Queue
Barrier = threading.Barrier
BoundedSemaphore = threading.BoundedSemaphore
Condition = threading.Condition
Event = threading.Event
Lock = threading.Lock
RLock = threading.RLock
Semaphore = threading.Semaphore
class DummyProcess(threading.Thread):
_children: weakref.WeakKeyDictionary[Any, Any]
@@ -46,7 +48,7 @@ class DummyProcess(threading.Thread):
def __init__(
self,
group: Any = ...,
target: Callable[..., Any] | None = ...,
target: Callable[..., object] | None = ...,
name: str | None = ...,
args: Iterable[Any] = ...,
kwargs: Mapping[str, Any] = ...,
@@ -67,8 +69,10 @@ class Value:
def Array(typecode: Any, sequence: Sequence[Any], lock: Any = ...) -> array.array[Any]: ...
def Manager() -> Any: ...
def Pool(processes: int | None = ..., initializer: Callable[..., Any] | None = ..., initargs: Iterable[Any] = ...) -> Any: ...
def Pool(processes: int | None = ..., initializer: Callable[..., object] | None = ..., initargs: Iterable[Any] = ...) -> Any: ...
def active_children() -> list[Any]: ...
def current_process() -> threading.Thread: ...
current_process = threading.current_thread
def freeze_support() -> None: ...
def shutdown() -> None: ...

View File

@@ -148,7 +148,7 @@ class BaseManager:
def get_server(self) -> Server: ...
def connect(self) -> None: ...
def start(self, initializer: Callable[..., Any] | None = ..., initargs: Iterable[Any] = ...) -> None: ...
def start(self, initializer: Callable[..., object] | None = ..., initargs: Iterable[Any] = ...) -> None: ...
def shutdown(self) -> None: ... # only available after start() was called
def join(self, timeout: float | None = ...) -> None: ... # undocumented
@property
@@ -157,7 +157,7 @@ class BaseManager:
def register(
cls,
typeid: str,
callable: Callable[..., Any] | None = ...,
callable: Callable[..., object] | None = ...,
proxytype: Any = ...,
exposed: Sequence[str] | None = ...,
method_to_typeid: Mapping[str, str] | None = ...,

View File

@@ -72,7 +72,7 @@ class Pool:
def __init__(
self,
processes: int | None = ...,
initializer: Callable[..., None] | None = ...,
initializer: Callable[..., object] | None = ...,
initargs: Iterable[Any] = ...,
maxtasksperchild: int | None = ...,
context: Any | None = ...,
@@ -118,7 +118,7 @@ class Pool:
class ThreadPool(Pool):
def __init__(
self, processes: int | None = ..., initializer: Callable[..., Any] | None = ..., initargs: Iterable[Any] = ...
self, processes: int | None = ..., initializer: Callable[..., object] | None = ..., initargs: Iterable[Any] = ...
) -> None: ...
# undocumented

View File

@@ -15,7 +15,7 @@ class BaseProcess:
def __init__(
self,
group: None = ...,
target: Callable[..., Any] | None = ...,
target: Callable[..., object] | None = ...,
name: str | None = ...,
args: Iterable[Any] = ...,
kwargs: Mapping[str, Any] = ...,

View File

@@ -4,7 +4,6 @@ from collections.abc import Callable
from contextlib import AbstractContextManager
from multiprocessing.context import BaseContext
from types import TracebackType
from typing import Any
from typing_extensions import TypeAlias
__all__ = ["Lock", "RLock", "Semaphore", "BoundedSemaphore", "Condition", "Event"]
@@ -13,7 +12,7 @@ _LockLike: TypeAlias = Lock | RLock
class Barrier(threading.Barrier):
def __init__(
self, parties: int, action: Callable[..., Any] | None = ..., timeout: float | None = ..., *ctx: BaseContext
self, parties: int, action: Callable[[], object] | None = ..., timeout: float | None = ..., *ctx: BaseContext
) -> None: ...
class BoundedSemaphore(Semaphore):