From de12b1413d3f5c37d50fcfa39fc7e9f226f3241c Mon Sep 17 00:00:00 2001 From: Bruce Merry <1963944+bmerry@users.noreply.github.com> Date: Wed, 16 Nov 2022 17:00:38 +0200 Subject: [PATCH] Refine types for asyncio transports (#9209) - Change the return type of create_connection, start_tls, connect_accepted_socket, create_unix_connection to Transport rather than BaseTransport (closes #9199). - Change the return type of create_datagram_endpoint to DatagramTransport rather than BaseTransport. - Change the argument of sendfile to WriteTransport rather than BaseTransport. I considered also changing the argument of start_tls to Transport, but I think that will give false positives for code that implements a custom transport class that inherits from both ReadTransport and WriteTransport but not from Transport, and I'm not sure if typing has a way to express an intersection of types. Since users are not normally expected to implement transports that may be overthinking things. --- stdlib/asyncio/base_events.pyi | 28 ++++++++++++++-------------- stdlib/asyncio/events.pyi | 32 ++++++++++++++++---------------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/stdlib/asyncio/base_events.pyi b/stdlib/asyncio/base_events.pyi index c1ab114b6..83576ab64 100644 --- a/stdlib/asyncio/base_events.pyi +++ b/stdlib/asyncio/base_events.pyi @@ -5,7 +5,7 @@ from asyncio.events import AbstractEventLoop, AbstractServer, Handle, TimerHandl from asyncio.futures import Future from asyncio.protocols import BaseProtocol from asyncio.tasks import Task -from asyncio.transports import BaseTransport, ReadTransport, SubprocessTransport, WriteTransport +from asyncio.transports import BaseTransport, DatagramTransport, ReadTransport, SubprocessTransport, Transport, WriteTransport from collections.abc import Awaitable, Callable, Coroutine, Generator, Iterable, Sequence from contextvars import Context from socket import AddressFamily, SocketKind, _Address, _RetAddress, socket @@ -129,7 +129,7 @@ class BaseEventLoop(AbstractEventLoop): ssl_shutdown_timeout: float | None = ..., happy_eyeballs_delay: float | None = ..., interleave: int | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... @overload async def create_connection( self, @@ -148,7 +148,7 @@ class BaseEventLoop(AbstractEventLoop): ssl_shutdown_timeout: float | None = ..., happy_eyeballs_delay: float | None = ..., interleave: int | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... elif sys.version_info >= (3, 8): @overload async def create_connection( @@ -167,7 +167,7 @@ class BaseEventLoop(AbstractEventLoop): ssl_handshake_timeout: float | None = ..., happy_eyeballs_delay: float | None = ..., interleave: int | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... @overload async def create_connection( self, @@ -185,7 +185,7 @@ class BaseEventLoop(AbstractEventLoop): ssl_handshake_timeout: float | None = ..., happy_eyeballs_delay: float | None = ..., interleave: int | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... else: @overload async def create_connection( @@ -202,7 +202,7 @@ class BaseEventLoop(AbstractEventLoop): local_addr: tuple[str, int] | None = ..., server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... @overload async def create_connection( self, @@ -218,7 +218,7 @@ class BaseEventLoop(AbstractEventLoop): local_addr: None = ..., server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... if sys.version_info >= (3, 11): @overload async def create_server( @@ -266,7 +266,7 @@ class BaseEventLoop(AbstractEventLoop): server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., ssl_shutdown_timeout: float | None = ..., - ) -> BaseTransport: ... + ) -> Transport: ... async def connect_accepted_socket( self, protocol_factory: Callable[[], _ProtocolT], @@ -275,7 +275,7 @@ class BaseEventLoop(AbstractEventLoop): ssl: _SSLContext = ..., ssl_handshake_timeout: float | None = ..., ssl_shutdown_timeout: float | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... else: @overload async def create_server( @@ -320,7 +320,7 @@ class BaseEventLoop(AbstractEventLoop): server_side: bool = ..., server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., - ) -> BaseTransport: ... + ) -> Transport: ... async def connect_accepted_socket( self, protocol_factory: Callable[[], _ProtocolT], @@ -328,13 +328,13 @@ class BaseEventLoop(AbstractEventLoop): *, ssl: _SSLContext = ..., ssl_handshake_timeout: float | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... async def sock_sendfile( self, sock: socket, file: IO[bytes], offset: int = ..., count: int | None = ..., *, fallback: bool | None = ... ) -> int: ... async def sendfile( - self, transport: BaseTransport, file: IO[bytes], offset: int = ..., count: int | None = ..., *, fallback: bool = ... + self, transport: WriteTransport, file: IO[bytes], offset: int = ..., count: int | None = ..., *, fallback: bool = ... ) -> int: ... if sys.version_info >= (3, 11): async def create_datagram_endpoint( # type: ignore[override] @@ -349,7 +349,7 @@ class BaseEventLoop(AbstractEventLoop): reuse_port: bool | None = ..., allow_broadcast: bool | None = ..., sock: socket | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[DatagramTransport, _ProtocolT]: ... else: async def create_datagram_endpoint( self, @@ -364,7 +364,7 @@ class BaseEventLoop(AbstractEventLoop): reuse_port: bool | None = ..., allow_broadcast: bool | None = ..., sock: socket | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[DatagramTransport, _ProtocolT]: ... # Pipes and subprocesses. async def connect_read_pipe( self, protocol_factory: Callable[[], _ProtocolT], pipe: Any diff --git a/stdlib/asyncio/events.pyi b/stdlib/asyncio/events.pyi index 280be4ab5..7241d5a29 100644 --- a/stdlib/asyncio/events.pyi +++ b/stdlib/asyncio/events.pyi @@ -12,7 +12,7 @@ from .base_events import Server from .futures import Future from .protocols import BaseProtocol from .tasks import Task -from .transports import BaseTransport, ReadTransport, SubprocessTransport, WriteTransport +from .transports import BaseTransport, DatagramTransport, ReadTransport, SubprocessTransport, Transport, WriteTransport from .unix_events import AbstractChildWatcher if sys.version_info >= (3, 8): @@ -223,7 +223,7 @@ class AbstractEventLoop: ssl_shutdown_timeout: float | None = ..., happy_eyeballs_delay: float | None = ..., interleave: int | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... @overload @abstractmethod async def create_connection( @@ -243,7 +243,7 @@ class AbstractEventLoop: ssl_shutdown_timeout: float | None = ..., happy_eyeballs_delay: float | None = ..., interleave: int | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... elif sys.version_info >= (3, 8): @overload @abstractmethod @@ -263,7 +263,7 @@ class AbstractEventLoop: ssl_handshake_timeout: float | None = ..., happy_eyeballs_delay: float | None = ..., interleave: int | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... @overload @abstractmethod async def create_connection( @@ -282,7 +282,7 @@ class AbstractEventLoop: ssl_handshake_timeout: float | None = ..., happy_eyeballs_delay: float | None = ..., interleave: int | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... else: @overload @abstractmethod @@ -300,7 +300,7 @@ class AbstractEventLoop: local_addr: tuple[str, int] | None = ..., server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... @overload @abstractmethod async def create_connection( @@ -317,7 +317,7 @@ class AbstractEventLoop: local_addr: None = ..., server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... if sys.version_info >= (3, 11): @overload @abstractmethod @@ -360,7 +360,7 @@ class AbstractEventLoop: @abstractmethod async def start_tls( self, - transport: BaseTransport, + transport: WriteTransport, protocol: BaseProtocol, sslcontext: ssl.SSLContext, *, @@ -368,7 +368,7 @@ class AbstractEventLoop: server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., ssl_shutdown_timeout: float | None = ..., - ) -> BaseTransport: ... + ) -> Transport: ... async def create_unix_server( self, protocol_factory: _ProtocolFactory, @@ -428,7 +428,7 @@ class AbstractEventLoop: server_side: bool = ..., server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., - ) -> BaseTransport: ... + ) -> Transport: ... async def create_unix_server( self, protocol_factory: _ProtocolFactory, @@ -449,7 +449,7 @@ class AbstractEventLoop: ssl: _SSLContext = ..., ssl_handshake_timeout: float | None = ..., ssl_shutdown_timeout: float | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... elif sys.version_info >= (3, 10): async def connect_accepted_socket( self, @@ -458,7 +458,7 @@ class AbstractEventLoop: *, ssl: _SSLContext = ..., ssl_handshake_timeout: float | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... if sys.version_info >= (3, 11): async def create_unix_connection( self, @@ -470,7 +470,7 @@ class AbstractEventLoop: server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., ssl_shutdown_timeout: float | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... else: async def create_unix_connection( self, @@ -481,7 +481,7 @@ class AbstractEventLoop: sock: socket | None = ..., server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[Transport, _ProtocolT]: ... @abstractmethod async def sock_sendfile( @@ -489,7 +489,7 @@ class AbstractEventLoop: ) -> int: ... @abstractmethod async def sendfile( - self, transport: BaseTransport, file: IO[bytes], offset: int = ..., count: int | None = ..., *, fallback: bool = ... + self, transport: WriteTransport, file: IO[bytes], offset: int = ..., count: int | None = ..., *, fallback: bool = ... ) -> int: ... @abstractmethod async def create_datagram_endpoint( @@ -505,7 +505,7 @@ class AbstractEventLoop: reuse_port: bool | None = ..., allow_broadcast: bool | None = ..., sock: socket | None = ..., - ) -> tuple[BaseTransport, _ProtocolT]: ... + ) -> tuple[DatagramTransport, _ProtocolT]: ... # Pipes and subprocesses. @abstractmethod async def connect_read_pipe(