diff --git a/stdlib/asyncio/base_events.pyi b/stdlib/asyncio/base_events.pyi index 310a9f585..e413730bc 100644 --- a/stdlib/asyncio/base_events.pyi +++ b/stdlib/asyncio/base_events.pyi @@ -1,7 +1,7 @@ import ssl import sys from _typeshed import FileDescriptorLike, WriteableBuffer -from asyncio.events import AbstractEventLoop, AbstractServer, Handle, TimerHandle +from asyncio.events import AbstractEventLoop, AbstractServer, Handle, TimerHandle, _TaskFactory from asyncio.futures import Future from asyncio.protocols import BaseProtocol from asyncio.tasks import Task @@ -107,8 +107,8 @@ class BaseEventLoop(AbstractEventLoop): else: def create_task(self, coro: Coroutine[Any, Any, _T] | Generator[Any, None, _T]) -> Task[_T]: ... - def set_task_factory(self, factory: Callable[[AbstractEventLoop, Generator[Any, None, _T]], Future[_T]] | None) -> None: ... - def get_task_factory(self) -> Callable[[AbstractEventLoop, Generator[Any, None, _T]], Future[_T]] | None: ... + def set_task_factory(self, factory: _TaskFactory | None) -> None: ... + def get_task_factory(self) -> _TaskFactory | None: ... # Methods for interacting with threads if sys.version_info >= (3, 7): def call_soon_threadsafe(self, callback: Callable[..., Any], *args: Any, context: Context | None = ...) -> Handle: ... diff --git a/stdlib/asyncio/events.pyi b/stdlib/asyncio/events.pyi index 8396f0957..fb4dac56f 100644 --- a/stdlib/asyncio/events.pyi +++ b/stdlib/asyncio/events.pyi @@ -4,7 +4,7 @@ from _typeshed import FileDescriptorLike, Self, WriteableBuffer from abc import ABCMeta, abstractmethod from collections.abc import Awaitable, Callable, Coroutine, Generator, Sequence from socket import AddressFamily, SocketKind, _Address, _RetAddress, socket -from typing import IO, Any, TypeVar, overload +from typing import IO, Any, Protocol, TypeVar, overload from typing_extensions import Literal, TypeAlias from .base_events import Server @@ -81,6 +81,11 @@ _ExceptionHandler: TypeAlias = Callable[[AbstractEventLoop, _Context], Any] _ProtocolFactory: TypeAlias = Callable[[], BaseProtocol] _SSLContext: TypeAlias = bool | None | ssl.SSLContext +class _TaskFactory(Protocol): + def __call__( + self, __loop: AbstractEventLoop, __factory: Coroutine[Any, Any, _T] | Generator[Any, None, _T] + ) -> Future[_T]: ... + class Handle: _cancelled: bool _args: Sequence[Any] @@ -203,9 +208,9 @@ class AbstractEventLoop: def create_task(self, coro: Coroutine[Any, Any, _T] | Generator[Any, None, _T]) -> Task[_T]: ... @abstractmethod - def set_task_factory(self, factory: Callable[[AbstractEventLoop, Generator[Any, None, _T]], Future[_T]] | None) -> None: ... + def set_task_factory(self, factory: _TaskFactory | None) -> None: ... @abstractmethod - def get_task_factory(self) -> Callable[[AbstractEventLoop, Generator[Any, None, _T]], Future[_T]] | None: ... + def get_task_factory(self) -> _TaskFactory | None: ... # Methods for interacting with threads if sys.version_info >= (3, 9): # "context" added in 3.9.10/3.10.2 @abstractmethod