Refine types for asyncio transports (#9209)

- Change the return type of create_connection, start_tls,
  connect_accepted_socket, create_unix_connection to Transport
  rather than BaseTransport (closes #9199).
- Change the return type of create_datagram_endpoint to
  DatagramTransport rather than BaseTransport.
- Change the argument of sendfile to WriteTransport rather than
  BaseTransport.

I considered also changing the argument of start_tls to Transport, but
I think that will give false positives for code that implements a custom
transport class that inherits from both ReadTransport and WriteTransport
but not from Transport, and I'm not sure if typing has a way to express
an intersection of types. Since users are not normally expected to
implement transports that may be overthinking things.
This commit is contained in:
Bruce Merry
2022-11-16 17:00:38 +02:00
committed by GitHub
parent 41d508472a
commit de12b1413d
2 changed files with 30 additions and 30 deletions

View File

@@ -5,7 +5,7 @@ from asyncio.events import AbstractEventLoop, AbstractServer, Handle, TimerHandl
from asyncio.futures import Future
from asyncio.protocols import BaseProtocol
from asyncio.tasks import Task
from asyncio.transports import BaseTransport, ReadTransport, SubprocessTransport, WriteTransport
from asyncio.transports import BaseTransport, DatagramTransport, ReadTransport, SubprocessTransport, Transport, WriteTransport
from collections.abc import Awaitable, Callable, Coroutine, Generator, Iterable, Sequence
from contextvars import Context
from socket import AddressFamily, SocketKind, _Address, _RetAddress, socket
@@ -129,7 +129,7 @@ class BaseEventLoop(AbstractEventLoop):
ssl_shutdown_timeout: float | None = ...,
happy_eyeballs_delay: float | None = ...,
interleave: int | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
@overload
async def create_connection(
self,
@@ -148,7 +148,7 @@ class BaseEventLoop(AbstractEventLoop):
ssl_shutdown_timeout: float | None = ...,
happy_eyeballs_delay: float | None = ...,
interleave: int | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
elif sys.version_info >= (3, 8):
@overload
async def create_connection(
@@ -167,7 +167,7 @@ class BaseEventLoop(AbstractEventLoop):
ssl_handshake_timeout: float | None = ...,
happy_eyeballs_delay: float | None = ...,
interleave: int | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
@overload
async def create_connection(
self,
@@ -185,7 +185,7 @@ class BaseEventLoop(AbstractEventLoop):
ssl_handshake_timeout: float | None = ...,
happy_eyeballs_delay: float | None = ...,
interleave: int | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
else:
@overload
async def create_connection(
@@ -202,7 +202,7 @@ class BaseEventLoop(AbstractEventLoop):
local_addr: tuple[str, int] | None = ...,
server_hostname: str | None = ...,
ssl_handshake_timeout: float | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
@overload
async def create_connection(
self,
@@ -218,7 +218,7 @@ class BaseEventLoop(AbstractEventLoop):
local_addr: None = ...,
server_hostname: str | None = ...,
ssl_handshake_timeout: float | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
if sys.version_info >= (3, 11):
@overload
async def create_server(
@@ -266,7 +266,7 @@ class BaseEventLoop(AbstractEventLoop):
server_hostname: str | None = ...,
ssl_handshake_timeout: float | None = ...,
ssl_shutdown_timeout: float | None = ...,
) -> BaseTransport: ...
) -> Transport: ...
async def connect_accepted_socket(
self,
protocol_factory: Callable[[], _ProtocolT],
@@ -275,7 +275,7 @@ class BaseEventLoop(AbstractEventLoop):
ssl: _SSLContext = ...,
ssl_handshake_timeout: float | None = ...,
ssl_shutdown_timeout: float | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
else:
@overload
async def create_server(
@@ -320,7 +320,7 @@ class BaseEventLoop(AbstractEventLoop):
server_side: bool = ...,
server_hostname: str | None = ...,
ssl_handshake_timeout: float | None = ...,
) -> BaseTransport: ...
) -> Transport: ...
async def connect_accepted_socket(
self,
protocol_factory: Callable[[], _ProtocolT],
@@ -328,13 +328,13 @@ class BaseEventLoop(AbstractEventLoop):
*,
ssl: _SSLContext = ...,
ssl_handshake_timeout: float | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
async def sock_sendfile(
self, sock: socket, file: IO[bytes], offset: int = ..., count: int | None = ..., *, fallback: bool | None = ...
) -> int: ...
async def sendfile(
self, transport: BaseTransport, file: IO[bytes], offset: int = ..., count: int | None = ..., *, fallback: bool = ...
self, transport: WriteTransport, file: IO[bytes], offset: int = ..., count: int | None = ..., *, fallback: bool = ...
) -> int: ...
if sys.version_info >= (3, 11):
async def create_datagram_endpoint( # type: ignore[override]
@@ -349,7 +349,7 @@ class BaseEventLoop(AbstractEventLoop):
reuse_port: bool | None = ...,
allow_broadcast: bool | None = ...,
sock: socket | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[DatagramTransport, _ProtocolT]: ...
else:
async def create_datagram_endpoint(
self,
@@ -364,7 +364,7 @@ class BaseEventLoop(AbstractEventLoop):
reuse_port: bool | None = ...,
allow_broadcast: bool | None = ...,
sock: socket | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[DatagramTransport, _ProtocolT]: ...
# Pipes and subprocesses.
async def connect_read_pipe(
self, protocol_factory: Callable[[], _ProtocolT], pipe: Any

View File

@@ -12,7 +12,7 @@ from .base_events import Server
from .futures import Future
from .protocols import BaseProtocol
from .tasks import Task
from .transports import BaseTransport, ReadTransport, SubprocessTransport, WriteTransport
from .transports import BaseTransport, DatagramTransport, ReadTransport, SubprocessTransport, Transport, WriteTransport
from .unix_events import AbstractChildWatcher
if sys.version_info >= (3, 8):
@@ -223,7 +223,7 @@ class AbstractEventLoop:
ssl_shutdown_timeout: float | None = ...,
happy_eyeballs_delay: float | None = ...,
interleave: int | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
@overload
@abstractmethod
async def create_connection(
@@ -243,7 +243,7 @@ class AbstractEventLoop:
ssl_shutdown_timeout: float | None = ...,
happy_eyeballs_delay: float | None = ...,
interleave: int | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
elif sys.version_info >= (3, 8):
@overload
@abstractmethod
@@ -263,7 +263,7 @@ class AbstractEventLoop:
ssl_handshake_timeout: float | None = ...,
happy_eyeballs_delay: float | None = ...,
interleave: int | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
@overload
@abstractmethod
async def create_connection(
@@ -282,7 +282,7 @@ class AbstractEventLoop:
ssl_handshake_timeout: float | None = ...,
happy_eyeballs_delay: float | None = ...,
interleave: int | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
else:
@overload
@abstractmethod
@@ -300,7 +300,7 @@ class AbstractEventLoop:
local_addr: tuple[str, int] | None = ...,
server_hostname: str | None = ...,
ssl_handshake_timeout: float | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
@overload
@abstractmethod
async def create_connection(
@@ -317,7 +317,7 @@ class AbstractEventLoop:
local_addr: None = ...,
server_hostname: str | None = ...,
ssl_handshake_timeout: float | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
if sys.version_info >= (3, 11):
@overload
@abstractmethod
@@ -360,7 +360,7 @@ class AbstractEventLoop:
@abstractmethod
async def start_tls(
self,
transport: BaseTransport,
transport: WriteTransport,
protocol: BaseProtocol,
sslcontext: ssl.SSLContext,
*,
@@ -368,7 +368,7 @@ class AbstractEventLoop:
server_hostname: str | None = ...,
ssl_handshake_timeout: float | None = ...,
ssl_shutdown_timeout: float | None = ...,
) -> BaseTransport: ...
) -> Transport: ...
async def create_unix_server(
self,
protocol_factory: _ProtocolFactory,
@@ -428,7 +428,7 @@ class AbstractEventLoop:
server_side: bool = ...,
server_hostname: str | None = ...,
ssl_handshake_timeout: float | None = ...,
) -> BaseTransport: ...
) -> Transport: ...
async def create_unix_server(
self,
protocol_factory: _ProtocolFactory,
@@ -449,7 +449,7 @@ class AbstractEventLoop:
ssl: _SSLContext = ...,
ssl_handshake_timeout: float | None = ...,
ssl_shutdown_timeout: float | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
elif sys.version_info >= (3, 10):
async def connect_accepted_socket(
self,
@@ -458,7 +458,7 @@ class AbstractEventLoop:
*,
ssl: _SSLContext = ...,
ssl_handshake_timeout: float | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
if sys.version_info >= (3, 11):
async def create_unix_connection(
self,
@@ -470,7 +470,7 @@ class AbstractEventLoop:
server_hostname: str | None = ...,
ssl_handshake_timeout: float | None = ...,
ssl_shutdown_timeout: float | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
else:
async def create_unix_connection(
self,
@@ -481,7 +481,7 @@ class AbstractEventLoop:
sock: socket | None = ...,
server_hostname: str | None = ...,
ssl_handshake_timeout: float | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[Transport, _ProtocolT]: ...
@abstractmethod
async def sock_sendfile(
@@ -489,7 +489,7 @@ class AbstractEventLoop:
) -> int: ...
@abstractmethod
async def sendfile(
self, transport: BaseTransport, file: IO[bytes], offset: int = ..., count: int | None = ..., *, fallback: bool = ...
self, transport: WriteTransport, file: IO[bytes], offset: int = ..., count: int | None = ..., *, fallback: bool = ...
) -> int: ...
@abstractmethod
async def create_datagram_endpoint(
@@ -505,7 +505,7 @@ class AbstractEventLoop:
reuse_port: bool | None = ...,
allow_broadcast: bool | None = ...,
sock: socket | None = ...,
) -> tuple[BaseTransport, _ProtocolT]: ...
) -> tuple[DatagramTransport, _ProtocolT]: ...
# Pipes and subprocesses.
@abstractmethod
async def connect_read_pipe(