sqlite3: handle return-type with factory argument. (#11571)

This commit is contained in:
shawnbrown
2024-03-13 04:14:08 -04:00
committed by GitHub
parent 8558307188
commit ea6dac40d4
2 changed files with 93 additions and 16 deletions

View File

@@ -8,6 +8,7 @@ from typing import Any, Literal, Protocol, SupportsIndex, TypeVar, final, overlo
from typing_extensions import Self, TypeAlias
_T = TypeVar("_T")
_ConnectionT = TypeVar("_ConnectionT", bound=Connection)
_CursorT = TypeVar("_CursorT", bound=Cursor)
_SqliteData: TypeAlias = str | ReadableBuffer | int | float | None
# Data that is passed through adapters can be of any type accepted by an adapter.
@@ -223,29 +224,79 @@ def adapt(obj: Any, proto: Any, alt: _T, /) -> Any | _T: ...
def complete_statement(statement: str) -> bool: ...
if sys.version_info >= (3, 12):
@overload
def connect(
database: StrOrBytesPath,
timeout: float = ...,
detect_types: int = ...,
isolation_level: str | None = ...,
check_same_thread: bool = ...,
factory: type[Connection] | None = ...,
cached_statements: int = ...,
uri: bool = ...,
timeout: float = 5.0,
detect_types: int = 0,
isolation_level: Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | None = "DEFERRED",
check_same_thread: bool = True,
cached_statements: int = 128,
uri: bool = False,
*,
autocommit: bool = ...,
) -> Connection: ...
else:
@overload
def connect(
database: StrOrBytesPath,
timeout: float = ...,
detect_types: int = ...,
isolation_level: str | None = ...,
check_same_thread: bool = ...,
factory: type[Connection] | None = ...,
cached_statements: int = ...,
uri: bool = ...,
timeout: float,
detect_types: int,
isolation_level: Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | None,
check_same_thread: bool,
factory: type[_ConnectionT],
cached_statements: int = 128,
uri: bool = False,
*,
autocommit: bool = ...,
) -> _ConnectionT: ...
@overload
def connect(
database: StrOrBytesPath,
timeout: float = 5.0,
detect_types: int = 0,
isolation_level: Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | None = "DEFERRED",
check_same_thread: bool = True,
*,
factory: type[_ConnectionT],
cached_statements: int = 128,
uri: bool = False,
autocommit: bool = ...,
) -> _ConnectionT: ...
else:
@overload
def connect(
database: StrOrBytesPath,
timeout: float = 5.0,
detect_types: int = 0,
isolation_level: Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | None = "DEFERRED",
check_same_thread: bool = True,
cached_statements: int = 128,
uri: bool = False,
) -> Connection: ...
@overload
def connect(
database: StrOrBytesPath,
timeout: float,
detect_types: int,
isolation_level: Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | None,
check_same_thread: bool,
factory: type[_ConnectionT],
cached_statements: int = 128,
uri: bool = False,
) -> _ConnectionT: ...
@overload
def connect(
database: StrOrBytesPath,
timeout: float = 5.0,
detect_types: int = 0,
isolation_level: Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | None = "DEFERRED",
check_same_thread: bool = True,
*,
factory: type[_ConnectionT],
cached_statements: int = 128,
uri: bool = False,
) -> _ConnectionT: ...
def enable_callback_tracebacks(enable: bool, /) -> None: ...

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
import sqlite3
from typing_extensions import assert_type
class MyConnection(sqlite3.Connection):
pass
# Default return-type is Connection.
assert_type(sqlite3.connect(":memory:"), sqlite3.Connection)
# Providing an alternate factory changes the return-type.
assert_type(sqlite3.connect(":memory:", factory=MyConnection), MyConnection)
# Provides a true positive error. When checking the connect() function,
# mypy should report an arg-type error for the factory argument.
with sqlite3.connect(":memory:", factory=None) as con: # type: ignore
pass
# The Connection class also accepts a `factory` arg but it does not affect
# the return-type. This use case is not idiomatic--connections should be
# established using the `connect()` function, not directly (as shown here).
assert_type(sqlite3.Connection(":memory:", factory=None), sqlite3.Connection)
assert_type(sqlite3.Connection(":memory:", factory=MyConnection), sqlite3.Connection)