Improve sqlite3 types (#7641)

Read through the code in CPython and made the types more precise
where possible.

Co-authored-by: Sebastian Rittau <srittau@rittau.biz>
This commit is contained in:
Jelle Zijlstra
2022-04-16 11:20:11 -07:00
committed by GitHub
parent 2e98c8284c
commit b0611bc031
2 changed files with 130 additions and 90 deletions

View File

@@ -1,12 +1,23 @@
import sqlite3
import sys
from _typeshed import ReadableBuffer, Self, StrOrBytesPath
from _typeshed import ReadableBuffer, Self, StrOrBytesPath, SupportsLenAndGetItem
from collections.abc import Callable, Generator, Iterable, Iterator, Mapping
from datetime import date, datetime, time
from types import TracebackType
from typing import Any, Callable, Generator, Iterable, Iterator, Protocol, TypeVar, overload
from typing_extensions import Literal, final
from typing import Any, Generic, Protocol, TypeVar, overload
from typing_extensions import Literal, SupportsIndex, TypeAlias, final
_T = TypeVar("_T")
_SqliteData = str | bytes | int | float | None
_T_co = TypeVar("_T_co", covariant=True)
_CursorT = TypeVar("_CursorT", bound=Cursor)
_SqliteData: TypeAlias = str | ReadableBuffer | int | float | None
# Data that is passed through adapters can be of any type accepted by an adapter.
_AdaptedInputData: TypeAlias = _SqliteData | Any
# The Mapping must really be a dict, but making it invariant is too annoying.
_Parameters: TypeAlias = SupportsLenAndGetItem[_AdaptedInputData] | Mapping[str, _AdaptedInputData]
_SqliteOutputData: TypeAlias = str | bytes | int | float | None
_Adapter: TypeAlias = Callable[[_T], _SqliteData]
_Converter: TypeAlias = Callable[[bytes], Any]
paramstyle: str
threadsafety: int
@@ -81,43 +92,39 @@ if sys.version_info >= (3, 7):
SQLITE_SELECT: int
SQLITE_TRANSACTION: int
SQLITE_UPDATE: int
adapters: Any
converters: Any
adapters: dict[tuple[type[Any], type[Any]], _Adapter[Any]]
converters: dict[str, _Converter]
sqlite_version: str
version: str
# TODO: adapt needs to get probed
def adapt(obj, protocol, alternate): ...
# Can take or return anything depending on what's in the registry.
@overload
def adapt(__obj: Any, __proto: Any) -> Any: ...
@overload
def adapt(__obj: Any, __proto: Any, __alt: _T) -> Any | _T: ...
def complete_statement(statement: str) -> bool: ...
if sys.version_info >= (3, 7):
def connect(
database: StrOrBytesPath,
timeout: float = ...,
detect_types: int = ...,
isolation_level: str | None = ...,
check_same_thread: bool = ...,
factory: type[Connection] | None = ...,
cached_statements: int = ...,
uri: bool = ...,
) -> Connection: ...
_DatabaseArg: TypeAlias = StrOrBytesPath
else:
def connect(
database: bytes | str,
timeout: float = ...,
detect_types: int = ...,
isolation_level: str | None = ...,
check_same_thread: bool = ...,
factory: type[Connection] | None = ...,
cached_statements: int = ...,
uri: bool = ...,
) -> Connection: ...
_DatabaseArg: TypeAlias = bytes | str
def connect(
database: _DatabaseArg,
timeout: float = ...,
detect_types: int = ...,
isolation_level: str | None = ...,
check_same_thread: bool = ...,
factory: type[Connection] | None = ...,
cached_statements: int = ...,
uri: bool = ...,
) -> Connection: ...
def enable_callback_tracebacks(__enable: bool) -> None: ...
# takes a pos-or-keyword argument because there is a C wrapper
def enable_shared_cache(enable: int) -> None: ...
def register_adapter(__type: type[_T], __caster: Callable[[_T], int | float | str | bytes]) -> None: ...
def register_converter(__name: str, __converter: Callable[[bytes], Any]) -> None: ...
def register_adapter(__type: type[_T], __caster: _Adapter[_T]) -> None: ...
def register_converter(__name: str, __converter: _Converter) -> None: ...
if sys.version_info < (3, 8):
class Cache:
@@ -126,7 +133,7 @@ if sys.version_info < (3, 8):
def get(self, *args, **kwargs) -> None: ...
class _AggregateProtocol(Protocol):
def step(self, value: int) -> object: ...
def step(self, __value: int) -> object: ...
def finalize(self) -> int: ...
class _SingleParamWindowAggregateClass(Protocol):
@@ -148,22 +155,44 @@ class _WindowAggregateClass(Protocol):
def finalize(self) -> _SqliteData: ...
class Connection:
DataError: Any
DatabaseError: Any
Error: Any
IntegrityError: Any
InterfaceError: Any
InternalError: Any
NotSupportedError: Any
OperationalError: Any
ProgrammingError: Any
Warning: Any
in_transaction: Any
isolation_level: Any
@property
def DataError(self) -> type[sqlite3.DataError]: ...
@property
def DatabaseError(self) -> type[sqlite3.DatabaseError]: ...
@property
def Error(self) -> type[sqlite3.Error]: ...
@property
def IntegrityError(self) -> type[sqlite3.IntegrityError]: ...
@property
def InterfaceError(self) -> type[sqlite3.InterfaceError]: ...
@property
def InternalError(self) -> type[sqlite3.InternalError]: ...
@property
def NotSupportedError(self) -> type[sqlite3.NotSupportedError]: ...
@property
def OperationalError(self) -> type[sqlite3.OperationalError]: ...
@property
def ProgrammingError(self) -> type[sqlite3.ProgrammingError]: ...
@property
def Warning(self) -> type[sqlite3.Warning]: ...
@property
def in_transaction(self) -> bool: ...
isolation_level: str | None # one of '', 'DEFERRED', 'IMMEDIATE' or 'EXCLUSIVE'
@property
def total_changes(self) -> int: ...
row_factory: Any
text_factory: Any
total_changes: Any
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def __init__(
self,
database: _DatabaseArg,
timeout: float = ...,
detect_types: int = ...,
isolation_level: str | None = ...,
check_same_thread: bool = ...,
factory: type[Connection] | None = ...,
cached_statements: int = ...,
uri: bool = ...,
) -> None: ...
def close(self) -> None: ...
if sys.version_info >= (3, 11):
def blobopen(self, __table: str, __column: str, __row: int, *, readonly: bool = ..., name: str = ...) -> Blob: ...
@@ -187,17 +216,21 @@ class Connection:
self, __name: str, __num_params: int, __aggregate_class: Callable[[], _WindowAggregateClass] | None
) -> None: ...
def create_collation(self, __name: str, __callback: Any) -> None: ...
def create_collation(self, __name: str, __callback: Callable[[str, str], int | SupportsIndex] | None) -> None: ...
if sys.version_info >= (3, 8):
def create_function(self, name: str, narg: int, func: Any, *, deterministic: bool = ...) -> None: ...
def create_function(
self, name: str, narg: int, func: Callable[..., _SqliteData], *, deterministic: bool = ...
) -> None: ...
else:
def create_function(self, name: str, num_params: int, func: Any) -> None: ...
def create_function(self, name: str, num_params: int, func: Callable[..., _SqliteData]) -> None: ...
def cursor(self, cursorClass: type | None = ...) -> Cursor: ...
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Cursor: ...
# TODO: please check in executemany() if seq_of_parameters type is possible like this
def executemany(self, __sql: str, __parameters: Iterable[Iterable[Any]]) -> Cursor: ...
def executescript(self, __sql_script: bytes | str) -> Cursor: ...
@overload
def cursor(self, cursorClass: None = ...) -> Cursor: ...
@overload
def cursor(self, cursorClass: Callable[[], _CursorT]) -> _CursorT: ...
def execute(self, sql: str, parameters: _Parameters = ...) -> Cursor: ...
def executemany(self, __sql: str, __parameters: Iterable[_Parameters]) -> Cursor: ...
def executescript(self, __sql_script: str) -> Cursor: ...
def interrupt(self) -> None: ...
def iterdump(self) -> Generator[str, None, None]: ...
def rollback(self) -> None: ...
@@ -208,8 +241,8 @@ class Connection:
def set_trace_callback(self, trace_callback: Callable[[str], object] | None) -> None: ...
# enable_load_extension and load_extension is not available on python distributions compiled
# without sqlite3 loadable extension support. see footnotes https://docs.python.org/3/library/sqlite3.html#f1
def enable_load_extension(self, enabled: bool) -> None: ...
def load_extension(self, path: str) -> None: ...
def enable_load_extension(self, __enabled: bool) -> None: ...
def load_extension(self, __name: str) -> None: ...
if sys.version_info >= (3, 7):
def backup(
self,
@@ -226,29 +259,32 @@ class Connection:
def serialize(self, *, name: str = ...) -> bytes: ...
def deserialize(self, __data: ReadableBuffer, *, name: str = ...) -> None: ...
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
def __call__(self, __sql: str) -> _Statement: ...
def __enter__(self: Self) -> Self: ...
def __exit__(
self, __type: type[BaseException] | None, __value: BaseException | None, __traceback: TracebackType | None
) -> Literal[False]: ...
class Cursor(Iterator[Any]):
arraysize: Any
connection: Any
description: Any
lastrowid: Any
row_factory: Any
rowcount: int
# TODO: Cursor class accepts exactly 1 argument
# required type is sqlite3.Connection (which is imported as _Connection)
# however, the name of the __init__ variable is unknown
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
arraysize: int
@property
def connection(self) -> Connection: ...
@property
def description(self) -> tuple[tuple[str, None, None, None, None, None, None], ...] | None: ...
@property
def lastrowid(self) -> int | None: ...
row_factory: Callable[[Cursor, Row[Any]], object] | None
@property
def rowcount(self) -> int: ...
def __init__(self, __cursor: Connection) -> None: ...
def close(self) -> None: ...
def execute(self, __sql: str, __parameters: Iterable[Any] = ...) -> Cursor: ...
def executemany(self, __sql: str, __seq_of_parameters: Iterable[Iterable[Any]]) -> Cursor: ...
def executescript(self, __sql_script: bytes | str) -> Cursor: ...
def execute(self: Self, __sql: str, __parameters: _Parameters = ...) -> Self: ...
def executemany(self: Self, __sql: str, __seq_of_parameters: Iterable[_Parameters]) -> Self: ...
def executescript(self, __sql_script: str) -> Cursor: ...
def fetchall(self) -> list[Any]: ...
def fetchmany(self, size: int | None = ...) -> list[Any]: ...
# Returns either a row (as created by the row_factory) or None, but
# putting None in the return annotation causes annoying false positives.
def fetchone(self) -> Any: ...
def setinputsizes(self, __sizes: object) -> None: ... # does nothing
def setoutputsize(self, __size: object, __column: object = ...) -> None: ... # does nothing
@@ -273,28 +309,37 @@ OptimizedUnicode = str
@final
class PrepareProtocol:
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def __init__(self, *args: object, **kwargs: object) -> None: ...
class ProgrammingError(DatabaseError): ...
class Row:
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def keys(self): ...
def __eq__(self, __other): ...
def __ge__(self, __other): ...
def __getitem__(self, __index): ...
def __gt__(self, __other): ...
def __hash__(self): ...
def __iter__(self): ...
def __le__(self, __other): ...
def __len__(self): ...
def __lt__(self, __other): ...
def __ne__(self, __other): ...
class Row(Generic[_T_co]):
def __init__(self, __cursor: Cursor, __data: tuple[_T_co, ...]) -> None: ...
def keys(self) -> list[str]: ...
@overload
def __getitem__(self, __index: int | str) -> _T_co: ...
@overload
def __getitem__(self, __index: slice) -> tuple[_T_co, ...]: ...
def __hash__(self) -> int: ...
def __iter__(self) -> Iterator[_T_co]: ...
def __len__(self) -> int: ...
# These return NotImplemented for anything that is not a Row.
def __eq__(self, __other: object) -> bool: ...
def __ge__(self, __other: object) -> bool: ...
def __gt__(self, __other: object) -> bool: ...
def __le__(self, __other: object) -> bool: ...
def __lt__(self, __other: object) -> bool: ...
def __ne__(self, __other: object) -> bool: ...
if sys.version_info < (3, 8):
if sys.version_info >= (3, 8):
@final
class _Statement: ...
else:
@final
class Statement:
def __init__(self, *args, **kwargs): ...
_Statement: TypeAlias = Statement
class Warning(Exception): ...