mirror of
https://github.com/davidhalter/typeshed.git
synced 2025-12-07 12:44:28 +08:00
Improve sqlite3 types (#7641)
Read through the code in CPython and made the types more precise where possible. Co-authored-by: Sebastian Rittau <srittau@rittau.biz>
This commit is contained in:
@@ -1,12 +1,23 @@
|
||||
import sqlite3
|
||||
import sys
|
||||
from _typeshed import ReadableBuffer, Self, StrOrBytesPath
|
||||
from _typeshed import ReadableBuffer, Self, StrOrBytesPath, SupportsLenAndGetItem
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator, Mapping
|
||||
from datetime import date, datetime, time
|
||||
from types import TracebackType
|
||||
from typing import Any, Callable, Generator, Iterable, Iterator, Protocol, TypeVar, overload
|
||||
from typing_extensions import Literal, final
|
||||
from typing import Any, Generic, Protocol, TypeVar, overload
|
||||
from typing_extensions import Literal, SupportsIndex, TypeAlias, final
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_SqliteData = str | bytes | int | float | None
|
||||
_T_co = TypeVar("_T_co", covariant=True)
|
||||
_CursorT = TypeVar("_CursorT", bound=Cursor)
|
||||
_SqliteData: TypeAlias = str | ReadableBuffer | int | float | None
|
||||
# Data that is passed through adapters can be of any type accepted by an adapter.
|
||||
_AdaptedInputData: TypeAlias = _SqliteData | Any
|
||||
# The Mapping must really be a dict, but making it invariant is too annoying.
|
||||
_Parameters: TypeAlias = SupportsLenAndGetItem[_AdaptedInputData] | Mapping[str, _AdaptedInputData]
|
||||
_SqliteOutputData: TypeAlias = str | bytes | int | float | None
|
||||
_Adapter: TypeAlias = Callable[[_T], _SqliteData]
|
||||
_Converter: TypeAlias = Callable[[bytes], Any]
|
||||
|
||||
paramstyle: str
|
||||
threadsafety: int
|
||||
@@ -81,43 +92,39 @@ if sys.version_info >= (3, 7):
|
||||
SQLITE_SELECT: int
|
||||
SQLITE_TRANSACTION: int
|
||||
SQLITE_UPDATE: int
|
||||
adapters: Any
|
||||
converters: Any
|
||||
adapters: dict[tuple[type[Any], type[Any]], _Adapter[Any]]
|
||||
converters: dict[str, _Converter]
|
||||
sqlite_version: str
|
||||
version: str
|
||||
|
||||
# TODO: adapt needs to get probed
|
||||
def adapt(obj, protocol, alternate): ...
|
||||
# Can take or return anything depending on what's in the registry.
|
||||
@overload
|
||||
def adapt(__obj: Any, __proto: Any) -> Any: ...
|
||||
@overload
|
||||
def adapt(__obj: Any, __proto: Any, __alt: _T) -> Any | _T: ...
|
||||
def complete_statement(statement: str) -> bool: ...
|
||||
|
||||
if sys.version_info >= (3, 7):
|
||||
def connect(
|
||||
database: StrOrBytesPath,
|
||||
timeout: float = ...,
|
||||
detect_types: int = ...,
|
||||
isolation_level: str | None = ...,
|
||||
check_same_thread: bool = ...,
|
||||
factory: type[Connection] | None = ...,
|
||||
cached_statements: int = ...,
|
||||
uri: bool = ...,
|
||||
) -> Connection: ...
|
||||
|
||||
_DatabaseArg: TypeAlias = StrOrBytesPath
|
||||
else:
|
||||
def connect(
|
||||
database: bytes | str,
|
||||
timeout: float = ...,
|
||||
detect_types: int = ...,
|
||||
isolation_level: str | None = ...,
|
||||
check_same_thread: bool = ...,
|
||||
factory: type[Connection] | None = ...,
|
||||
cached_statements: int = ...,
|
||||
uri: bool = ...,
|
||||
) -> Connection: ...
|
||||
_DatabaseArg: TypeAlias = bytes | str
|
||||
|
||||
def connect(
|
||||
database: _DatabaseArg,
|
||||
timeout: float = ...,
|
||||
detect_types: int = ...,
|
||||
isolation_level: str | None = ...,
|
||||
check_same_thread: bool = ...,
|
||||
factory: type[Connection] | None = ...,
|
||||
cached_statements: int = ...,
|
||||
uri: bool = ...,
|
||||
) -> Connection: ...
|
||||
def enable_callback_tracebacks(__enable: bool) -> None: ...
|
||||
|
||||
# takes a pos-or-keyword argument because there is a C wrapper
|
||||
def enable_shared_cache(enable: int) -> None: ...
|
||||
def register_adapter(__type: type[_T], __caster: Callable[[_T], int | float | str | bytes]) -> None: ...
|
||||
def register_converter(__name: str, __converter: Callable[[bytes], Any]) -> None: ...
|
||||
def register_adapter(__type: type[_T], __caster: _Adapter[_T]) -> None: ...
|
||||
def register_converter(__name: str, __converter: _Converter) -> None: ...
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
class Cache:
|
||||
@@ -126,7 +133,7 @@ if sys.version_info < (3, 8):
|
||||
def get(self, *args, **kwargs) -> None: ...
|
||||
|
||||
class _AggregateProtocol(Protocol):
|
||||
def step(self, value: int) -> object: ...
|
||||
def step(self, __value: int) -> object: ...
|
||||
def finalize(self) -> int: ...
|
||||
|
||||
class _SingleParamWindowAggregateClass(Protocol):
|
||||
@@ -148,22 +155,44 @@ class _WindowAggregateClass(Protocol):
|
||||
def finalize(self) -> _SqliteData: ...
|
||||
|
||||
class Connection:
|
||||
DataError: Any
|
||||
DatabaseError: Any
|
||||
Error: Any
|
||||
IntegrityError: Any
|
||||
InterfaceError: Any
|
||||
InternalError: Any
|
||||
NotSupportedError: Any
|
||||
OperationalError: Any
|
||||
ProgrammingError: Any
|
||||
Warning: Any
|
||||
in_transaction: Any
|
||||
isolation_level: Any
|
||||
@property
|
||||
def DataError(self) -> type[sqlite3.DataError]: ...
|
||||
@property
|
||||
def DatabaseError(self) -> type[sqlite3.DatabaseError]: ...
|
||||
@property
|
||||
def Error(self) -> type[sqlite3.Error]: ...
|
||||
@property
|
||||
def IntegrityError(self) -> type[sqlite3.IntegrityError]: ...
|
||||
@property
|
||||
def InterfaceError(self) -> type[sqlite3.InterfaceError]: ...
|
||||
@property
|
||||
def InternalError(self) -> type[sqlite3.InternalError]: ...
|
||||
@property
|
||||
def NotSupportedError(self) -> type[sqlite3.NotSupportedError]: ...
|
||||
@property
|
||||
def OperationalError(self) -> type[sqlite3.OperationalError]: ...
|
||||
@property
|
||||
def ProgrammingError(self) -> type[sqlite3.ProgrammingError]: ...
|
||||
@property
|
||||
def Warning(self) -> type[sqlite3.Warning]: ...
|
||||
@property
|
||||
def in_transaction(self) -> bool: ...
|
||||
isolation_level: str | None # one of '', 'DEFERRED', 'IMMEDIATE' or 'EXCLUSIVE'
|
||||
@property
|
||||
def total_changes(self) -> int: ...
|
||||
row_factory: Any
|
||||
text_factory: Any
|
||||
total_changes: Any
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
|
||||
def __init__(
|
||||
self,
|
||||
database: _DatabaseArg,
|
||||
timeout: float = ...,
|
||||
detect_types: int = ...,
|
||||
isolation_level: str | None = ...,
|
||||
check_same_thread: bool = ...,
|
||||
factory: type[Connection] | None = ...,
|
||||
cached_statements: int = ...,
|
||||
uri: bool = ...,
|
||||
) -> None: ...
|
||||
def close(self) -> None: ...
|
||||
if sys.version_info >= (3, 11):
|
||||
def blobopen(self, __table: str, __column: str, __row: int, *, readonly: bool = ..., name: str = ...) -> Blob: ...
|
||||
@@ -187,17 +216,21 @@ class Connection:
|
||||
self, __name: str, __num_params: int, __aggregate_class: Callable[[], _WindowAggregateClass] | None
|
||||
) -> None: ...
|
||||
|
||||
def create_collation(self, __name: str, __callback: Any) -> None: ...
|
||||
def create_collation(self, __name: str, __callback: Callable[[str, str], int | SupportsIndex] | None) -> None: ...
|
||||
if sys.version_info >= (3, 8):
|
||||
def create_function(self, name: str, narg: int, func: Any, *, deterministic: bool = ...) -> None: ...
|
||||
def create_function(
|
||||
self, name: str, narg: int, func: Callable[..., _SqliteData], *, deterministic: bool = ...
|
||||
) -> None: ...
|
||||
else:
|
||||
def create_function(self, name: str, num_params: int, func: Any) -> None: ...
|
||||
def create_function(self, name: str, num_params: int, func: Callable[..., _SqliteData]) -> None: ...
|
||||
|
||||
def cursor(self, cursorClass: type | None = ...) -> Cursor: ...
|
||||
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Cursor: ...
|
||||
# TODO: please check in executemany() if seq_of_parameters type is possible like this
|
||||
def executemany(self, __sql: str, __parameters: Iterable[Iterable[Any]]) -> Cursor: ...
|
||||
def executescript(self, __sql_script: bytes | str) -> Cursor: ...
|
||||
@overload
|
||||
def cursor(self, cursorClass: None = ...) -> Cursor: ...
|
||||
@overload
|
||||
def cursor(self, cursorClass: Callable[[], _CursorT]) -> _CursorT: ...
|
||||
def execute(self, sql: str, parameters: _Parameters = ...) -> Cursor: ...
|
||||
def executemany(self, __sql: str, __parameters: Iterable[_Parameters]) -> Cursor: ...
|
||||
def executescript(self, __sql_script: str) -> Cursor: ...
|
||||
def interrupt(self) -> None: ...
|
||||
def iterdump(self) -> Generator[str, None, None]: ...
|
||||
def rollback(self) -> None: ...
|
||||
@@ -208,8 +241,8 @@ class Connection:
|
||||
def set_trace_callback(self, trace_callback: Callable[[str], object] | None) -> None: ...
|
||||
# enable_load_extension and load_extension is not available on python distributions compiled
|
||||
# without sqlite3 loadable extension support. see footnotes https://docs.python.org/3/library/sqlite3.html#f1
|
||||
def enable_load_extension(self, enabled: bool) -> None: ...
|
||||
def load_extension(self, path: str) -> None: ...
|
||||
def enable_load_extension(self, __enabled: bool) -> None: ...
|
||||
def load_extension(self, __name: str) -> None: ...
|
||||
if sys.version_info >= (3, 7):
|
||||
def backup(
|
||||
self,
|
||||
@@ -226,29 +259,32 @@ class Connection:
|
||||
def serialize(self, *, name: str = ...) -> bytes: ...
|
||||
def deserialize(self, __data: ReadableBuffer, *, name: str = ...) -> None: ...
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
def __call__(self, __sql: str) -> _Statement: ...
|
||||
def __enter__(self: Self) -> Self: ...
|
||||
def __exit__(
|
||||
self, __type: type[BaseException] | None, __value: BaseException | None, __traceback: TracebackType | None
|
||||
) -> Literal[False]: ...
|
||||
|
||||
class Cursor(Iterator[Any]):
|
||||
arraysize: Any
|
||||
connection: Any
|
||||
description: Any
|
||||
lastrowid: Any
|
||||
row_factory: Any
|
||||
rowcount: int
|
||||
# TODO: Cursor class accepts exactly 1 argument
|
||||
# required type is sqlite3.Connection (which is imported as _Connection)
|
||||
# however, the name of the __init__ variable is unknown
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
|
||||
arraysize: int
|
||||
@property
|
||||
def connection(self) -> Connection: ...
|
||||
@property
|
||||
def description(self) -> tuple[tuple[str, None, None, None, None, None, None], ...] | None: ...
|
||||
@property
|
||||
def lastrowid(self) -> int | None: ...
|
||||
row_factory: Callable[[Cursor, Row[Any]], object] | None
|
||||
@property
|
||||
def rowcount(self) -> int: ...
|
||||
def __init__(self, __cursor: Connection) -> None: ...
|
||||
def close(self) -> None: ...
|
||||
def execute(self, __sql: str, __parameters: Iterable[Any] = ...) -> Cursor: ...
|
||||
def executemany(self, __sql: str, __seq_of_parameters: Iterable[Iterable[Any]]) -> Cursor: ...
|
||||
def executescript(self, __sql_script: bytes | str) -> Cursor: ...
|
||||
def execute(self: Self, __sql: str, __parameters: _Parameters = ...) -> Self: ...
|
||||
def executemany(self: Self, __sql: str, __seq_of_parameters: Iterable[_Parameters]) -> Self: ...
|
||||
def executescript(self, __sql_script: str) -> Cursor: ...
|
||||
def fetchall(self) -> list[Any]: ...
|
||||
def fetchmany(self, size: int | None = ...) -> list[Any]: ...
|
||||
# Returns either a row (as created by the row_factory) or None, but
|
||||
# putting None in the return annotation causes annoying false positives.
|
||||
def fetchone(self) -> Any: ...
|
||||
def setinputsizes(self, __sizes: object) -> None: ... # does nothing
|
||||
def setoutputsize(self, __size: object, __column: object = ...) -> None: ... # does nothing
|
||||
@@ -273,28 +309,37 @@ OptimizedUnicode = str
|
||||
|
||||
@final
|
||||
class PrepareProtocol:
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
|
||||
def __init__(self, *args: object, **kwargs: object) -> None: ...
|
||||
|
||||
class ProgrammingError(DatabaseError): ...
|
||||
|
||||
class Row:
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
|
||||
def keys(self): ...
|
||||
def __eq__(self, __other): ...
|
||||
def __ge__(self, __other): ...
|
||||
def __getitem__(self, __index): ...
|
||||
def __gt__(self, __other): ...
|
||||
def __hash__(self): ...
|
||||
def __iter__(self): ...
|
||||
def __le__(self, __other): ...
|
||||
def __len__(self): ...
|
||||
def __lt__(self, __other): ...
|
||||
def __ne__(self, __other): ...
|
||||
class Row(Generic[_T_co]):
|
||||
def __init__(self, __cursor: Cursor, __data: tuple[_T_co, ...]) -> None: ...
|
||||
def keys(self) -> list[str]: ...
|
||||
@overload
|
||||
def __getitem__(self, __index: int | str) -> _T_co: ...
|
||||
@overload
|
||||
def __getitem__(self, __index: slice) -> tuple[_T_co, ...]: ...
|
||||
def __hash__(self) -> int: ...
|
||||
def __iter__(self) -> Iterator[_T_co]: ...
|
||||
def __len__(self) -> int: ...
|
||||
# These return NotImplemented for anything that is not a Row.
|
||||
def __eq__(self, __other: object) -> bool: ...
|
||||
def __ge__(self, __other: object) -> bool: ...
|
||||
def __gt__(self, __other: object) -> bool: ...
|
||||
def __le__(self, __other: object) -> bool: ...
|
||||
def __lt__(self, __other: object) -> bool: ...
|
||||
def __ne__(self, __other: object) -> bool: ...
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
if sys.version_info >= (3, 8):
|
||||
@final
|
||||
class _Statement: ...
|
||||
|
||||
else:
|
||||
@final
|
||||
class Statement:
|
||||
def __init__(self, *args, **kwargs): ...
|
||||
_Statement: TypeAlias = Statement
|
||||
|
||||
class Warning(Exception): ...
|
||||
|
||||
|
||||
Reference in New Issue
Block a user