diff --git a/stdlib/contextlib.pyi b/stdlib/contextlib.pyi index 509bcb6e4..19ef45bed 100644 --- a/stdlib/contextlib.pyi +++ b/stdlib/contextlib.pyi @@ -1,6 +1,7 @@ import sys from types import TracebackType from typing import IO, Any, Callable, ContextManager, Iterable, Iterator, Optional, Type, TypeVar, overload +from typing_extensions import Protocol if sys.version_info >= (3, 5): from typing import AsyncContextManager, AsyncIterator @@ -34,8 +35,20 @@ if sys.version_info >= (3, 7): if sys.version_info < (3,): def nested(*mgr: ContextManager[Any]) -> ContextManager[Iterable[Any]]: ... -class closing(ContextManager[_T]): - def __init__(self, thing: _T) -> None: ... +class _SupportsClose(Protocol): + def close(self) -> None: ... + +_SupportsCloseT = TypeVar("_SupportsCloseT", bound=_SupportsClose) + +class closing(ContextManager[_SupportsCloseT]): + def __init__(self, thing: _SupportsCloseT) -> None: ... + +if sys.version_info >= (3, 10): + class _SupportsAclose(Protocol): + async def aclose(self) -> None: ... + _SupportsAcloseT = TypeVar("_SupportsAcloseT", bound=_SupportsAclose) + class aclosing(AsyncContextManager[_SupportsAcloseT]): + def __init__(self, thing: _SupportsAcloseT) -> None: ... if sys.version_info >= (3, 4): class suppress(ContextManager[None]):