Use overloads for more precise get_context type (#3605)

This commit is contained in:
Alan Du
2020-01-15 12:43:46 -05:00
committed by Sebastian Rittau
parent 91d000e434
commit 42d68dd765
2 changed files with 56 additions and 4 deletions

View File

@@ -10,8 +10,10 @@ from multiprocessing.context import (
AuthenticationError as AuthenticationError,
BaseContext,
BufferTooShort as BufferTooShort,
DefaultContext,
Process as Process,
ProcessError as ProcessError,
SpawnContext,
TimeoutError as TimeoutError,
)
from multiprocessing.managers import SyncManager
@@ -22,6 +24,13 @@ from multiprocessing.spawn import set_executable as set_executable
if sys.version_info >= (3, 8):
from multiprocessing.process import parent_process as parent_process
from typing import Literal
else:
from typing_extensions import Literal
if sys.platform != "win32":
from multiprocessing.context import ForkContext, ForkServerContext
# N.B. The functions below are generated at runtime by partially applying
# multiprocessing.context.BaseContext's methods, so the two signatures should
@@ -79,6 +88,25 @@ def log_to_stderr(level: Optional[Union[str, int]] = ...) -> Logger: ...
def Manager() -> SyncManager: ...
def set_forkserver_preload(module_names: List[str]) -> None: ...
def get_all_start_methods() -> List[str]: ...
def get_context(method: Optional[str] = ...) -> BaseContext: ...
def get_start_method(allow_none: bool = ...) -> Optional[str]: ...
def set_start_method(method: str, force: Optional[bool] = ...) -> None: ...
if sys.platform != "win32":
@overload
def get_context(method: None = ...) -> DefaultContext: ...
@overload
def get_context(method: Literal["spawn"]) -> SpawnContext: ...
@overload
def get_context(method: Literal["fork"]) -> ForkContext: ...
@overload
def get_context(method: Literal["forkserver"]) -> ForkServerContext: ...
@overload
def get_context(method: str) -> BaseContext: ...
else:
@overload
def get_context(method: None = ...) -> DefaultContext: ...
@overload
def get_context(method: Literal["spawn"]) -> SpawnContext: ...
@overload
def get_context(method: str) -> BaseContext: ...

View File

@@ -6,7 +6,12 @@ from multiprocessing import synchronize
from multiprocessing import queues
from multiprocessing.process import BaseProcess
import sys
from typing import Any, Callable, Iterable, Optional, List, Mapping, Sequence, Type, Union
from typing import Any, Callable, Iterable, Optional, List, Mapping, Sequence, Type, Union, overload
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
_LockLike = Union[synchronize.Lock, synchronize.RLock]
@@ -19,6 +24,7 @@ class TimeoutError(ProcessError): ...
class AuthenticationError(ProcessError): ...
class BaseContext(object):
Process: Type[BaseProcess]
ProcessError: Type[Exception]
BufferTooShort: Type[Exception]
TimeoutError: Type[Exception]
@@ -96,7 +102,26 @@ class BaseContext(object):
def allow_connection_pickling(self) -> None: ...
def set_executable(self, executable: str) -> None: ...
def set_forkserver_preload(self, module_names: List[str]) -> None: ...
def get_context(self, method: Optional[str] = ...) -> BaseContext: ...
if sys.platform != "win32":
@overload
def get_context(self, method: None = ...) -> DefaultContext: ...
@overload
def get_context(self, method: Literal["spawn"]) -> SpawnContext: ...
@overload
def get_context(self, method: Literal["fork"]) -> ForkContext: ...
@overload
def get_context(self, method: Literal["forkserver"]) -> ForkServerContext: ...
@overload
def get_context(self, method: str) -> BaseContext: ...
else:
@overload
def get_context(self, method: None = ...) -> DefaultContext: ...
@overload
def get_context(self, method: Literal["spawn"]) -> SpawnContext: ...
@overload
def get_context(self, method: str) -> BaseContext: ...
def get_start_method(self, allow_none: bool = ...) -> str: ...
def set_start_method(self, method: Optional[str], force: bool = ...) -> None: ...
@property
@@ -114,7 +139,6 @@ class DefaultContext(BaseContext):
Process: Type[multiprocessing.Process]
def __init__(self, context: BaseContext) -> None: ...
def get_context(self, method: Optional[str] = ...) -> BaseContext: ...
def set_start_method(self, method: Optional[str], force: bool = ...) -> None: ...
def get_start_method(self, allow_none: bool = ...) -> str: ...
def get_all_start_methods(self) -> List[str]: ...