Resolve asyncio issues for 3.13 (#12327)

This commit is contained in:
Max Muoto
2024-07-13 03:12:15 -05:00
committed by GitHub
parent 6a9b53e719
commit 670ca3125b
11 changed files with 257 additions and 116 deletions

View File

@@ -49,6 +49,10 @@ class Server(AbstractServer):
ssl_handshake_timeout: float | None,
) -> None: ...
if sys.version_info >= (3, 13):
def close_clients(self) -> None: ...
def abort_clients(self) -> None: ...
def get_loop(self) -> AbstractEventLoop: ...
def is_serving(self) -> bool: ...
async def start_serving(self) -> None: ...
@@ -222,7 +226,48 @@ class BaseEventLoop(AbstractEventLoop):
happy_eyeballs_delay: float | None = None,
interleave: int | None = None,
) -> tuple[Transport, _ProtocolT]: ...
if sys.version_info >= (3, 11):
if sys.version_info >= (3, 13):
# 3.13 added `keep_alive`.
@overload
async def create_server(
self,
protocol_factory: _ProtocolFactory,
host: str | Sequence[str] | None = None,
port: int = ...,
*,
family: int = ...,
flags: int = ...,
sock: None = None,
backlog: int = 100,
ssl: _SSLContext = None,
reuse_address: bool | None = None,
reuse_port: bool | None = None,
keep_alive: bool | None = None,
ssl_handshake_timeout: float | None = None,
ssl_shutdown_timeout: float | None = None,
start_serving: bool = True,
) -> Server: ...
@overload
async def create_server(
self,
protocol_factory: _ProtocolFactory,
host: None = None,
port: None = None,
*,
family: int = ...,
flags: int = ...,
sock: socket = ...,
backlog: int = 100,
ssl: _SSLContext = None,
reuse_address: bool | None = None,
reuse_port: bool | None = None,
keep_alive: bool | None = None,
ssl_handshake_timeout: float | None = None,
ssl_shutdown_timeout: float | None = None,
start_serving: bool = True,
) -> Server: ...
elif sys.version_info >= (3, 11):
@overload
async def create_server(
self,
@@ -259,26 +304,6 @@ class BaseEventLoop(AbstractEventLoop):
ssl_shutdown_timeout: float | None = None,
start_serving: bool = True,
) -> Server: ...
async def start_tls(
self,
transport: BaseTransport,
protocol: BaseProtocol,
sslcontext: ssl.SSLContext,
*,
server_side: bool = False,
server_hostname: str | None = None,
ssl_handshake_timeout: float | None = None,
ssl_shutdown_timeout: float | None = None,
) -> Transport | None: ...
async def connect_accepted_socket(
self,
protocol_factory: Callable[[], _ProtocolT],
sock: socket,
*,
ssl: _SSLContext = None,
ssl_handshake_timeout: float | None = None,
ssl_shutdown_timeout: float | None = None,
) -> tuple[Transport, _ProtocolT]: ...
else:
@overload
async def create_server(
@@ -314,6 +339,29 @@ class BaseEventLoop(AbstractEventLoop):
ssl_handshake_timeout: float | None = None,
start_serving: bool = True,
) -> Server: ...
if sys.version_info >= (3, 11):
async def start_tls(
self,
transport: BaseTransport,
protocol: BaseProtocol,
sslcontext: ssl.SSLContext,
*,
server_side: bool = False,
server_hostname: str | None = None,
ssl_handshake_timeout: float | None = None,
ssl_shutdown_timeout: float | None = None,
) -> Transport | None: ...
async def connect_accepted_socket(
self,
protocol_factory: Callable[[], _ProtocolT],
sock: socket,
*,
ssl: _SSLContext = None,
ssl_handshake_timeout: float | None = None,
ssl_shutdown_timeout: float | None = None,
) -> tuple[Transport, _ProtocolT]: ...
else:
async def start_tls(
self,
transport: BaseTransport,

View File

@@ -94,6 +94,12 @@ class TimerHandle(Handle):
class AbstractServer:
@abstractmethod
def close(self) -> None: ...
if sys.version_info >= (3, 13):
@abstractmethod
def close_clients(self) -> None: ...
@abstractmethod
def abort_clients(self) -> None: ...
async def __aenter__(self) -> Self: ...
async def __aexit__(self, *exc: Unused) -> None: ...
@abstractmethod
@@ -272,7 +278,50 @@ class AbstractEventLoop:
happy_eyeballs_delay: float | None = None,
interleave: int | None = None,
) -> tuple[Transport, _ProtocolT]: ...
if sys.version_info >= (3, 11):
if sys.version_info >= (3, 13):
# 3.13 added `keep_alive`.
@overload
@abstractmethod
async def create_server(
self,
protocol_factory: _ProtocolFactory,
host: str | Sequence[str] | None = None,
port: int = ...,
*,
family: int = ...,
flags: int = ...,
sock: None = None,
backlog: int = 100,
ssl: _SSLContext = None,
reuse_address: bool | None = None,
reuse_port: bool | None = None,
keep_alive: bool | None = None,
ssl_handshake_timeout: float | None = None,
ssl_shutdown_timeout: float | None = None,
start_serving: bool = True,
) -> Server: ...
@overload
@abstractmethod
async def create_server(
self,
protocol_factory: _ProtocolFactory,
host: None = None,
port: None = None,
*,
family: int = ...,
flags: int = ...,
sock: socket = ...,
backlog: int = 100,
ssl: _SSLContext = None,
reuse_address: bool | None = None,
reuse_port: bool | None = None,
keep_alive: bool | None = None,
ssl_handshake_timeout: float | None = None,
ssl_shutdown_timeout: float | None = None,
start_serving: bool = True,
) -> Server: ...
elif sys.version_info >= (3, 11):
@overload
@abstractmethod
async def create_server(
@@ -311,30 +360,6 @@ class AbstractEventLoop:
ssl_shutdown_timeout: float | None = None,
start_serving: bool = True,
) -> Server: ...
@abstractmethod
async def start_tls(
self,
transport: WriteTransport,
protocol: BaseProtocol,
sslcontext: ssl.SSLContext,
*,
server_side: bool = False,
server_hostname: str | None = None,
ssl_handshake_timeout: float | None = None,
ssl_shutdown_timeout: float | None = None,
) -> Transport | None: ...
async def create_unix_server(
self,
protocol_factory: _ProtocolFactory,
path: StrPath | None = None,
*,
sock: socket | None = None,
backlog: int = 100,
ssl: _SSLContext = None,
ssl_handshake_timeout: float | None = None,
ssl_shutdown_timeout: float | None = None,
start_serving: bool = True,
) -> Server: ...
else:
@overload
@abstractmethod
@@ -372,6 +397,33 @@ class AbstractEventLoop:
ssl_handshake_timeout: float | None = None,
start_serving: bool = True,
) -> Server: ...
if sys.version_info >= (3, 11):
@abstractmethod
async def start_tls(
self,
transport: WriteTransport,
protocol: BaseProtocol,
sslcontext: ssl.SSLContext,
*,
server_side: bool = False,
server_hostname: str | None = None,
ssl_handshake_timeout: float | None = None,
ssl_shutdown_timeout: float | None = None,
) -> Transport | None: ...
async def create_unix_server(
self,
protocol_factory: _ProtocolFactory,
path: StrPath | None = None,
*,
sock: socket | None = None,
backlog: int = 100,
ssl: _SSLContext = None,
ssl_handshake_timeout: float | None = None,
ssl_shutdown_timeout: float | None = None,
start_serving: bool = True,
) -> Server: ...
else:
@abstractmethod
async def start_tls(
self,
@@ -394,6 +446,7 @@ class AbstractEventLoop:
ssl_handshake_timeout: float | None = None,
start_serving: bool = True,
) -> Server: ...
if sys.version_info >= (3, 11):
async def connect_accepted_socket(
self,

View File

@@ -1,4 +1,5 @@
import functools
import sys
import traceback
from collections.abc import Iterable
from types import FrameType, FunctionType
@@ -14,7 +15,17 @@ _FuncType: TypeAlias = FunctionType | _HasWrapper | functools.partial[Any] | fun
def _get_function_source(func: _FuncType) -> tuple[str, int]: ...
@overload
def _get_function_source(func: object) -> tuple[str, int] | None: ...
def _format_callback_source(func: object, args: Iterable[Any]) -> str: ...
def _format_args_and_kwargs(args: Iterable[Any], kwargs: dict[str, Any]) -> str: ...
def _format_callback(func: object, args: Iterable[Any], kwargs: dict[str, Any], suffix: str = "") -> str: ...
if sys.version_info >= (3, 13):
def _format_callback_source(func: object, args: Iterable[Any], *, debug: bool = False) -> str: ...
def _format_args_and_kwargs(args: Iterable[Any], kwargs: dict[str, Any], *, debug: bool = False) -> str: ...
def _format_callback(
func: object, args: Iterable[Any], kwargs: dict[str, Any], *, debug: bool = False, suffix: str = ""
) -> str: ...
else:
def _format_callback_source(func: object, args: Iterable[Any]) -> str: ...
def _format_args_and_kwargs(args: Iterable[Any], kwargs: dict[str, Any]) -> str: ...
def _format_callback(func: object, args: Iterable[Any], kwargs: dict[str, Any], suffix: str = "") -> str: ...
def extract_stack(f: FrameType | None = None, limit: int | None = None) -> traceback.StackSummary: ...

View File

@@ -10,13 +10,20 @@ if sys.version_info >= (3, 10):
else:
_LoopBoundMixin = object
__all__ = ("Queue", "PriorityQueue", "LifoQueue", "QueueFull", "QueueEmpty")
class QueueEmpty(Exception): ...
class QueueFull(Exception): ...
if sys.version_info >= (3, 13):
__all__ = ("Queue", "PriorityQueue", "LifoQueue", "QueueFull", "QueueEmpty", "QueueShutDown")
else:
__all__ = ("Queue", "PriorityQueue", "LifoQueue", "QueueFull", "QueueEmpty")
_T = TypeVar("_T")
if sys.version_info >= (3, 13):
class QueueShutDown(Exception): ...
# If Generic[_T] is last and _LoopBoundMixin is object, pyright is unhappy.
# We can remove the noqa pragma when dropping 3.9 support.
class Queue(Generic[_T], _LoopBoundMixin): # noqa: Y059
@@ -42,6 +49,8 @@ class Queue(Generic[_T], _LoopBoundMixin): # noqa: Y059
def task_done(self) -> None: ...
if sys.version_info >= (3, 9):
def __class_getitem__(cls, type: Any, /) -> GenericAlias: ...
if sys.version_info >= (3, 13):
def shutdown(self, immediate: bool = False) -> None: ...
class PriorityQueue(Queue[_T]): ...
class LifoQueue(Queue[_T]): ...

View File

@@ -2,6 +2,7 @@ import ssl
import sys
from _typeshed import ReadableBuffer, StrPath
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence, Sized
from types import ModuleType
from typing import Any, Protocol, SupportsIndex
from typing_extensions import Self, TypeAlias
@@ -130,7 +131,10 @@ class StreamWriter:
async def start_tls(
self, sslcontext: ssl.SSLContext, *, server_hostname: str | None = None, ssl_handshake_timeout: float | None = None
) -> None: ...
if sys.version_info >= (3, 11):
if sys.version_info >= (3, 13):
def __del__(self, warnings: ModuleType = ...) -> None: ...
elif sys.version_info >= (3, 11):
def __del__(self) -> None: ...
class StreamReader(AsyncIterator[bytes]):

View File

@@ -1,15 +1,55 @@
import sys
import types
from _typeshed import StrPath
from abc import ABCMeta, abstractmethod
from collections.abc import Callable
from socket import socket
from typing import Literal
from typing_extensions import Self, TypeVarTuple, Unpack, deprecated
from .base_events import Server, _ProtocolFactory, _SSLContext
from .events import AbstractEventLoop, BaseDefaultEventLoopPolicy
from .selector_events import BaseSelectorEventLoop
_Ts = TypeVarTuple("_Ts")
if sys.platform != "win32":
if sys.version_info >= (3, 14):
__all__ = ("SelectorEventLoop", "DefaultEventLoopPolicy", "EventLoop")
elif sys.version_info >= (3, 13):
__all__ = (
"SelectorEventLoop",
"AbstractChildWatcher",
"SafeChildWatcher",
"FastChildWatcher",
"PidfdChildWatcher",
"MultiLoopChildWatcher",
"ThreadedChildWatcher",
"DefaultEventLoopPolicy",
"EventLoop",
)
elif sys.version_info >= (3, 9):
__all__ = (
"SelectorEventLoop",
"AbstractChildWatcher",
"SafeChildWatcher",
"FastChildWatcher",
"PidfdChildWatcher",
"MultiLoopChildWatcher",
"ThreadedChildWatcher",
"DefaultEventLoopPolicy",
)
else:
__all__ = (
"SelectorEventLoop",
"AbstractChildWatcher",
"SafeChildWatcher",
"FastChildWatcher",
"MultiLoopChildWatcher",
"ThreadedChildWatcher",
"DefaultEventLoopPolicy",
)
# This is also technically not available on Win,
# but other parts of typeshed need this definition.
# So, it is special cased.
@@ -58,30 +98,6 @@ if sys.version_info < (3, 14):
def is_active(self) -> bool: ...
if sys.platform != "win32":
if sys.version_info >= (3, 14):
__all__ = ("SelectorEventLoop", "DefaultEventLoopPolicy")
elif sys.version_info >= (3, 9):
__all__ = (
"SelectorEventLoop",
"AbstractChildWatcher",
"SafeChildWatcher",
"FastChildWatcher",
"PidfdChildWatcher",
"MultiLoopChildWatcher",
"ThreadedChildWatcher",
"DefaultEventLoopPolicy",
)
else:
__all__ = (
"SelectorEventLoop",
"AbstractChildWatcher",
"SafeChildWatcher",
"FastChildWatcher",
"MultiLoopChildWatcher",
"ThreadedChildWatcher",
"DefaultEventLoopPolicy",
)
if sys.version_info < (3, 14):
if sys.version_info >= (3, 12):
# Doesn't actually have ABCMeta metaclass at runtime, but mypy complains if we don't have it in the stub.
@@ -141,7 +157,21 @@ if sys.platform != "win32":
) -> None: ...
def remove_child_handler(self, pid: int) -> bool: ...
class _UnixSelectorEventLoop(BaseSelectorEventLoop): ...
class _UnixSelectorEventLoop(BaseSelectorEventLoop):
if sys.version_info >= (3, 13):
async def create_unix_server( # type: ignore[override]
self,
protocol_factory: _ProtocolFactory,
path: StrPath | None = None,
*,
sock: socket | None = None,
backlog: int = 100,
ssl: _SSLContext = None,
ssl_handshake_timeout: float | None = None,
ssl_shutdown_timeout: float | None = None,
start_serving: bool = True,
cleanup_socket: bool = True,
) -> Server: ...
class _UnixDefaultEventLoopPolicy(BaseDefaultEventLoopPolicy):
if sys.version_info < (3, 14):
@@ -158,6 +188,9 @@ if sys.platform != "win32":
DefaultEventLoopPolicy = _UnixDefaultEventLoopPolicy
if sys.version_info >= (3, 13):
EventLoop = SelectorEventLoop
if sys.version_info < (3, 14):
if sys.version_info >= (3, 12):
@deprecated("Deprecated as of Python 3.12; will be removed in Python 3.14")

View File

@@ -7,14 +7,26 @@ from typing import IO, Any, ClassVar, Literal, NoReturn
from . import events, futures, proactor_events, selector_events, streams, windows_utils
if sys.platform == "win32":
__all__ = (
"SelectorEventLoop",
"ProactorEventLoop",
"IocpProactor",
"DefaultEventLoopPolicy",
"WindowsSelectorEventLoopPolicy",
"WindowsProactorEventLoopPolicy",
)
if sys.version_info >= (3, 13):
# 3.13 added `EventLoop`.
__all__ = (
"SelectorEventLoop",
"ProactorEventLoop",
"IocpProactor",
"DefaultEventLoopPolicy",
"WindowsSelectorEventLoopPolicy",
"WindowsProactorEventLoopPolicy",
"EventLoop",
)
else:
__all__ = (
"SelectorEventLoop",
"ProactorEventLoop",
"IocpProactor",
"DefaultEventLoopPolicy",
"WindowsSelectorEventLoopPolicy",
"WindowsProactorEventLoopPolicy",
)
NULL: Literal[0]
INFINITE: Literal[0xFFFFFFFF]
@@ -84,3 +96,5 @@ if sys.platform == "win32":
def set_child_watcher(self, watcher: Any) -> NoReturn: ...
DefaultEventLoopPolicy = WindowsSelectorEventLoopPolicy
if sys.version_info >= (3, 13):
EventLoop = ProactorEventLoop