multiprocessing Pool (and context manager) fixes/improvements (#1562)

* Use typing.ContextManager for multiprocessing context managers

Prior to this commit, the types for __enter__ and __exit__ were not
fully defined; this addresses that.

* Move Pool class stub to multiprocessing.pool

This is where the class is actually defined in the stdlib.

* Ensure that __enter__ on Pool subclasses returns the subclass

This ensures that:

```py
class MyPool(Pool):
    def my_method(self): pass

with MyPool() as pool:
    pool.my_method()
```

type-checks correctly.

* Update the signature of BaseContext.Pool to match Pool.__init__

* Restore multiprocessing.Pool as a function

And also add comments to note that it should have an identical signature
to multiprocessing.context.BaseContext.Pool (because it is just that
method partially applied).
This commit is contained in:
Daniel Watkins
2017-08-29 05:19:08 +01:00
committed by Jelle Zijlstra
parent 87009939a5
commit 85a788dbca
4 changed files with 45 additions and 67 deletions

View File

@@ -1,8 +1,12 @@
# Stubs for multiprocessing
from typing import Any, Callable, Iterable, Mapping, Optional, Dict, List, Union, TypeVar
from typing import (
Any, Callable, ContextManager, Iterable, Mapping, Optional, Dict, List,
Union, TypeVar,
)
from logging import Logger
from multiprocessing import pool
from multiprocessing.context import BaseContext
from multiprocessing.managers import SyncManager
from multiprocessing.pool import AsyncResult
@@ -12,11 +16,9 @@ import queue
_T = TypeVar('_T')
class Lock():
class Lock(ContextManager[Lock]):
def acquire(self, block: bool = ..., timeout: int = ...) -> None: ...
def release(self) -> None: ...
def __enter__(self) -> 'Lock': ...
def __exit__(self, exc_type, exc_value, tb) -> None: ...
class Event(object):
def __init__(self, *, ctx: BaseContext) -> None: ...
@@ -25,54 +27,13 @@ class Event(object):
def clear(self) -> None: ...
def wait(self, timeout: Optional[int] = ...) -> bool: ...
class Pool():
def __init__(self, processes: Optional[int] = ...,
initializer: Optional[Callable[..., None]] = ...,
initargs: Iterable[Any] = ...,
maxtasksperchild: Optional[int] = ...,
context: Optional[Any] = None) -> None: ...
def apply(self,
func: Callable[..., Any],
args: Iterable[Any] = ...,
kwds: Dict[str, Any]=...) -> Any: ...
def apply_async(self,
func: Callable[..., Any],
args: Iterable[Any] = ...,
kwds: Dict[str, Any] = ...,
callback: Optional[Callable[..., None]] = None,
error_callback: Optional[Callable[[BaseException], None]] = None) -> AsyncResult: ...
def map(self,
func: Callable[..., Any],
iterable: Iterable[Any] = ...,
chunksize: Optional[int] = ...) -> List[Any]: ...
def map_async(self, func: Callable[..., Any],
iterable: Iterable[Any] = ...,
chunksize: Optional[int] = ...,
callback: Optional[Callable[..., None]] = None,
error_callback: Optional[Callable[[BaseException], None]] = None) -> AsyncResult: ...
def imap(self,
func: Callable[..., Any],
iterable: Iterable[Any] = ...,
chunksize: Optional[int] = None) -> Iterable[Any]: ...
def imap_unordered(self,
func: Callable[..., Any],
iterable: Iterable[Any] = ...,
chunksize: Optional[int] = None) -> Iterable[Any]: ...
def starmap(self,
func: Callable[..., Any],
iterable: Iterable[Iterable[Any]] = ...,
chunksize: Optional[int] = None) -> List[Any]: ...
def starmap_async(self,
func: Callable[..., Any],
iterable: Iterable[Iterable[Any]] = ...,
chunksize: Optional[int] = ...,
callback: Optional[Callable[..., None]] = None,
error_callback: Optional[Callable[[BaseException], None]] = None) -> AsyncResult: ...
def close(self) -> None: ...
def terminate(self) -> None: ...
def join(self) -> None: ...
def __enter__(self) -> 'Pool': ...
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
# N.B. This is generated at runtime by partially applying
# multiprocessing.context.BaseContext.Pool, so the two signatures should be
# identical (modulo self).
def Pool(processes: Optional[int] = ...,
initializer: Optional[Callable[..., Any]] = ...,
initargs: Iterable[Any] = ...,
maxtasksperchild: Optional[int] = ...) -> pool.Pool: ...
class Process():
name: str

View File

@@ -3,7 +3,9 @@
from logging import Logger
import multiprocessing
import sys
from typing import Any, Callable, Optional, List, Sequence, Tuple, Type, Union
from typing import (
Any, Callable, Iterable, Optional, List, Sequence, Tuple, Type, Union,
)
class ProcessError(Exception): ...
@@ -49,13 +51,16 @@ class BaseContext(object):
def JoinableQueue(self, maxsize: int = ...) -> Any: ...
# TODO: change return to SimpleQueue once a stub exists in multiprocessing.queues
def SimpleQueue(self) -> Any: ...
# N.B. This method is partially applied at runtime to generate
# multiprocessing.Pool, so the two signatures should be identical (modulo
# self).
def Pool(
self,
processes: Optional[int] = ...,
initializer: Optional[Callable[..., Any]] = ...,
initargs: Tuple = ...,
initargs: Iterable[Any] = ...,
maxtasksperchild: Optional[int] = ...
) -> multiprocessing.Pool: ...
) -> multiprocessing.pool.Pool: ...
# TODO: typecode_or_type param is a ctype with a base class of _SimpleCData or array.typecode Need to figure out
# how to handle the ctype
# TODO: change return to RawValue once a stub exists in multiprocessing.sharedctypes

View File

@@ -5,7 +5,8 @@
import queue
import threading
from typing import (
Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, TypeVar
Any, Callable, ContextManager, Dict, Iterable, List, Mapping, Optional,
Sequence, TypeVar,
)
_T = TypeVar('_T')
@@ -16,13 +17,11 @@ class Namespace: ...
_Namespace = Namespace
class BaseManager:
class BaseManager(ContextManager[BaseManager]):
def register(self, typeid: str, callable: Any = ...) -> None: ...
def shutdown(self) -> None: ...
def start(self, initializer: Optional[Callable[..., Any]] = ...,
initargs: Iterable[Any] = ...) -> None: ...
def __enter__(self) -> 'BaseManager': ...
def __exit__(self, exc_type, exc_value, tb) -> None: ...
class SyncManager(BaseManager):
def BoundedSemaphore(self, value: Any = ...) -> threading.BoundedSemaphore: ...

View File

@@ -2,7 +2,12 @@
# NOTE: These are incomplete!
from typing import Any, Callable, Iterable, Mapping, Optional, Dict, List
from typing import (
Any, Callable, ContextManager, Iterable, Mapping, Optional, Dict, List,
TypeVar,
)
_T = TypeVar('_T', bound='Pool')
class AsyncResult():
def get(self, timeout: float = ...) -> Any: ...
@@ -10,10 +15,12 @@ class AsyncResult():
def ready(self) -> bool: ...
def successful(self) -> bool: ...
class ThreadPool():
def __init__(self, processes: Optional[int] = None,
initializer: Optional[Callable[..., Any]] = None,
initargs: Iterable[Any] = ...) -> None: ...
class Pool(ContextManager[Pool]):
def __init__(self, processes: Optional[int] = ...,
initializer: Optional[Callable[..., None]] = ...,
initargs: Iterable[Any] = ...,
maxtasksperchild: Optional[int] = ...,
context: Optional[Any] = None) -> None: ...
def apply(self,
func: Callable[..., Any],
args: Iterable[Any] = ...,
@@ -30,7 +37,7 @@ class ThreadPool():
chunksize: Optional[int] = None) -> List[Any]: ...
def map_async(self, func: Callable[..., Any],
iterable: Iterable[Any] = ...,
chunksize: Optional[Optional[int]] = None,
chunksize: Optional[int] = None,
callback: Optional[Callable[..., None]] = None,
error_callback: Optional[Callable[[BaseException], None]] = None) -> AsyncResult: ...
def imap(self,
@@ -54,5 +61,11 @@ class ThreadPool():
def close(self) -> None: ...
def terminate(self) -> None: ...
def join(self) -> None: ...
def __enter__(self) -> 'ThreadPool': ...
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
def __enter__(self: _T) -> _T: ...
class ThreadPool(Pool, ContextManager[ThreadPool]):
def __init__(self, processes: Optional[int] = None,
initializer: Optional[Callable[..., Any]] = None,
initargs: Iterable[Any] = ...) -> None: ...