diff --git a/stdlib/3/multiprocessing/__init__.pyi b/stdlib/3/multiprocessing/__init__.pyi index f797b548a..569e6e394 100644 --- a/stdlib/3/multiprocessing/__init__.pyi +++ b/stdlib/3/multiprocessing/__init__.pyi @@ -10,8 +10,10 @@ from multiprocessing.context import ( AuthenticationError as AuthenticationError, BaseContext, BufferTooShort as BufferTooShort, + DefaultContext, Process as Process, ProcessError as ProcessError, + SpawnContext, TimeoutError as TimeoutError, ) from multiprocessing.managers import SyncManager @@ -22,6 +24,13 @@ from multiprocessing.spawn import set_executable as set_executable if sys.version_info >= (3, 8): from multiprocessing.process import parent_process as parent_process + from typing import Literal +else: + from typing_extensions import Literal + +if sys.platform != "win32": + from multiprocessing.context import ForkContext, ForkServerContext + # N.B. The functions below are generated at runtime by partially applying # multiprocessing.context.BaseContext's methods, so the two signatures should @@ -79,6 +88,25 @@ def log_to_stderr(level: Optional[Union[str, int]] = ...) -> Logger: ... def Manager() -> SyncManager: ... def set_forkserver_preload(module_names: List[str]) -> None: ... def get_all_start_methods() -> List[str]: ... -def get_context(method: Optional[str] = ...) -> BaseContext: ... def get_start_method(allow_none: bool = ...) -> Optional[str]: ... def set_start_method(method: str, force: Optional[bool] = ...) -> None: ... + + +if sys.platform != "win32": + @overload + def get_context(method: None = ...) -> DefaultContext: ... + @overload + def get_context(method: Literal["spawn"]) -> SpawnContext: ... + @overload + def get_context(method: Literal["fork"]) -> ForkContext: ... + @overload + def get_context(method: Literal["forkserver"]) -> ForkServerContext: ... + @overload + def get_context(method: str) -> BaseContext: ... +else: + @overload + def get_context(method: None = ...) -> DefaultContext: ... + @overload + def get_context(method: Literal["spawn"]) -> SpawnContext: ... + @overload + def get_context(method: str) -> BaseContext: ... diff --git a/stdlib/3/multiprocessing/context.pyi b/stdlib/3/multiprocessing/context.pyi index 627c760f0..49b4e8d3d 100644 --- a/stdlib/3/multiprocessing/context.pyi +++ b/stdlib/3/multiprocessing/context.pyi @@ -6,7 +6,12 @@ from multiprocessing import synchronize from multiprocessing import queues from multiprocessing.process import BaseProcess import sys -from typing import Any, Callable, Iterable, Optional, List, Mapping, Sequence, Type, Union +from typing import Any, Callable, Iterable, Optional, List, Mapping, Sequence, Type, Union, overload + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal _LockLike = Union[synchronize.Lock, synchronize.RLock] @@ -19,6 +24,7 @@ class TimeoutError(ProcessError): ... class AuthenticationError(ProcessError): ... class BaseContext(object): + Process: Type[BaseProcess] ProcessError: Type[Exception] BufferTooShort: Type[Exception] TimeoutError: Type[Exception] @@ -96,7 +102,26 @@ class BaseContext(object): def allow_connection_pickling(self) -> None: ... def set_executable(self, executable: str) -> None: ... def set_forkserver_preload(self, module_names: List[str]) -> None: ... - def get_context(self, method: Optional[str] = ...) -> BaseContext: ... + + if sys.platform != "win32": + @overload + def get_context(self, method: None = ...) -> DefaultContext: ... + @overload + def get_context(self, method: Literal["spawn"]) -> SpawnContext: ... + @overload + def get_context(self, method: Literal["fork"]) -> ForkContext: ... + @overload + def get_context(self, method: Literal["forkserver"]) -> ForkServerContext: ... + @overload + def get_context(self, method: str) -> BaseContext: ... + else: + @overload + def get_context(self, method: None = ...) -> DefaultContext: ... + @overload + def get_context(self, method: Literal["spawn"]) -> SpawnContext: ... + @overload + def get_context(self, method: str) -> BaseContext: ... + def get_start_method(self, allow_none: bool = ...) -> str: ... def set_start_method(self, method: Optional[str], force: bool = ...) -> None: ... @property @@ -114,7 +139,6 @@ class DefaultContext(BaseContext): Process: Type[multiprocessing.Process] def __init__(self, context: BaseContext) -> None: ... - def get_context(self, method: Optional[str] = ...) -> BaseContext: ... def set_start_method(self, method: Optional[str], force: bool = ...) -> None: ... def get_start_method(self, allow_none: bool = ...) -> str: ... def get_all_start_methods(self) -> List[str]: ...