diff --git a/stdlib/logging/__init__.pyi b/stdlib/logging/__init__.pyi index 59822ab3c..135d64a38 100644 --- a/stdlib/logging/__init__.pyi +++ b/stdlib/logging/__init__.pyi @@ -706,7 +706,7 @@ else: ) -> None: ... def shutdown(handlerList: Sequence[Any] = ...) -> None: ... # handlerList is undocumented -def setLoggerClass(klass: type) -> None: ... +def setLoggerClass(klass: Type[Logger]) -> None: ... def captureWarnings(capture: bool) -> None: ... def setLogRecordFactory(factory: Callable[..., LogRecord]) -> None: ... diff --git a/stdlib/sqlite3/dbapi2.pyi b/stdlib/sqlite3/dbapi2.pyi index 51def5406..0eaab5eff 100644 --- a/stdlib/sqlite3/dbapi2.pyi +++ b/stdlib/sqlite3/dbapi2.pyi @@ -1,7 +1,7 @@ import os import sys from datetime import date, datetime, time -from typing import Any, Callable, Generator, Iterable, Iterator, List, Optional, Text, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Generator, Iterable, Iterator, List, Optional, Protocol, Text, Tuple, Type, TypeVar, Union _T = TypeVar("_T") @@ -113,6 +113,10 @@ if sys.version_info < (3, 8): def display(self, *args, **kwargs) -> None: ... def get(self, *args, **kwargs) -> None: ... +class _AggregateProtocol(Protocol): + def step(self, value: int) -> None: ... + def finalize(self) -> int: ... + class Connection(object): DataError: Any DatabaseError: Any @@ -132,7 +136,7 @@ class Connection(object): def __init__(self, *args: Any, **kwargs: Any) -> None: ... def close(self) -> None: ... def commit(self) -> None: ... - def create_aggregate(self, name: str, num_params: int, aggregate_class: type) -> None: ... + def create_aggregate(self, name: str, num_params: int, aggregate_class: Callable[[], _AggregateProtocol]) -> None: ... def create_collation(self, name: str, callable: Any) -> None: ... if sys.version_info >= (3, 8): def create_function(self, name: str, num_params: int, func: Any, *, deterministic: bool = ...) -> None: ...