diff --git a/stubs/SQLAlchemy/@tests/stubtest_allowlist.txt b/stubs/SQLAlchemy/@tests/stubtest_allowlist.txt index 7234e1eb1..24dedb6a5 100644 --- a/stubs/SQLAlchemy/@tests/stubtest_allowlist.txt +++ b/stubs/SQLAlchemy/@tests/stubtest_allowlist.txt @@ -1,3 +1,6 @@ +# stub-only module +sqlalchemy.dbapi + # wrong argument name in implementation ("self" instead of "cls") sqlalchemy.engine.URL.__new__ sqlalchemy.engine.url.URL.__new__ @@ -25,6 +28,8 @@ sqlalchemy.testing.util.resolve_lambda sqlalchemy.util.WeakSequence.__init__ # not always present +sqlalchemy.engine.Engine.logging_name # initialized if not None +sqlalchemy.engine.base.Engine.logging_name # initialized if not None sqlalchemy.testing.util.non_refcount_gc_collect # replaced at runtime @@ -103,6 +108,20 @@ sqlalchemy.orm.strategy_options.Load.undefer sqlalchemy.orm.strategy_options.Load.undefer_group sqlalchemy.orm.strategy_options.Load.with_expression +# abstract fields not present at runtime +sqlalchemy.engine.Transaction.connection +sqlalchemy.engine.Transaction.is_active +sqlalchemy.engine.base.Transaction.connection +sqlalchemy.engine.base.Transaction.is_active + +# initialized to None during class construction, but overridden during __init__() +sqlalchemy.engine.Connection.engine +sqlalchemy.engine.base.Connection.engine + +# uses @memoized_property at runtime, but we use @property for compatibility +sqlalchemy.engine.URL.normalized_query +sqlalchemy.engine.url.URL.normalized_query + # unclear problems sqlalchemy.sql.elements.quoted_name.lower sqlalchemy.sql.elements.quoted_name.upper diff --git a/stubs/SQLAlchemy/sqlalchemy/cimmutabledict.pyi b/stubs/SQLAlchemy/sqlalchemy/cimmutabledict.pyi index 851872094..e3f87af28 100644 --- a/stubs/SQLAlchemy/sqlalchemy/cimmutabledict.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/cimmutabledict.pyi @@ -1,12 +1,17 @@ -class immutabledict: - def __len__(self) -> int: ... - def __getitem__(self, __item): ... - def __iter__(self): ... - def union(self, **kwargs): ... - def merge_with(self, *args): ... - def keys(self): ... - def __contains__(self, __item): ... - def items(self): ... - def values(self): ... - def get(self, __key, __default=...): ... - def __reduce__(self): ... +from _typeshed import SupportsKeysAndGetItem +from collections.abc import Iterable +from typing import Generic, TypeVar, overload + +_KT = TypeVar("_KT") +_KT2 = TypeVar("_KT2") +_VT = TypeVar("_VT") +_VT2 = TypeVar("_VT2") + +class immutabledict(dict[_KT, _VT], Generic[_KT, _VT]): + @overload + def union(self, __dict: dict[_KT2, _VT2]) -> immutabledict[_KT | _KT2, _VT | _VT2]: ... + @overload + def union(self, __dict: None = ..., **kw: SupportsKeysAndGetItem[_KT2, _VT2]) -> immutabledict[_KT | _KT2, _VT | _VT2]: ... + def merge_with( + self, *args: SupportsKeysAndGetItem[_KT | _KT2, _VT2] | Iterable[tuple[_KT2, _VT2]] | None + ) -> immutabledict[_KT | _KT2, _VT | _VT2]: ... diff --git a/stubs/SQLAlchemy/sqlalchemy/dbapi.pyi b/stubs/SQLAlchemy/sqlalchemy/dbapi.pyi new file mode 100644 index 000000000..432e5936a --- /dev/null +++ b/stubs/SQLAlchemy/sqlalchemy/dbapi.pyi @@ -0,0 +1,36 @@ +# TODO: Tempory copy of _typeshed.dbapi, until that file is available in all typecheckers. +# Does not exist at runtime. + +from collections.abc import Mapping, Sequence +from typing import Any, Protocol + +DBAPITypeCode = Any | None +# Strictly speaking, this should be a Sequence, but the type system does +# not support fixed-length sequences. +DBAPIColumnDescription = tuple[str, DBAPITypeCode, int | None, int | None, int | None, int | None, bool | None] + +class DBAPIConnection(Protocol): + def close(self) -> object: ... + def commit(self) -> object: ... + # optional: + # def rollback(self) -> Any: ... + def cursor(self) -> DBAPICursor: ... + +class DBAPICursor(Protocol): + @property + def description(self) -> Sequence[DBAPIColumnDescription] | None: ... + @property + def rowcount(self) -> int: ... + # optional: + # def callproc(self, __procname: str, __parameters: Sequence[Any] = ...) -> Sequence[Any]: ... + def close(self) -> object: ... + def execute(self, __operation: str, __parameters: Sequence[Any] | Mapping[str, Any] = ...) -> object: ... + def executemany(self, __operation: str, __seq_of_parameters: Sequence[Sequence[Any]]) -> object: ... + def fetchone(self) -> Sequence[Any] | None: ... + def fetchmany(self, __size: int = ...) -> Sequence[Sequence[Any]]: ... + def fetchall(self) -> Sequence[Sequence[Any]]: ... + # optional: + # def nextset(self) -> None | Literal[True]: ... + arraysize: int + def setinputsizes(self, __sizes: Sequence[DBAPITypeCode | int | None]) -> object: ... + def setoutputsize(self, __size: int, __column: int = ...) -> object: ... diff --git a/stubs/SQLAlchemy/sqlalchemy/dialects/mysql/enumerated.pyi b/stubs/SQLAlchemy/sqlalchemy/dialects/mysql/enumerated.pyi index ba9343513..e68dcdfdf 100644 --- a/stubs/SQLAlchemy/sqlalchemy/dialects/mysql/enumerated.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/dialects/mysql/enumerated.pyi @@ -3,7 +3,7 @@ from typing import Any from ...sql import sqltypes from .types import _StringType -class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType): # type: ignore[misc] +class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType): # type: ignore # incompatible with base class __visit_name__: str native_enum: bool def __init__(self, *enums, **kw) -> None: ... diff --git a/stubs/SQLAlchemy/sqlalchemy/dialects/mysql/types.pyi b/stubs/SQLAlchemy/sqlalchemy/dialects/mysql/types.pyi index 402a19aa5..fecd364f6 100644 --- a/stubs/SQLAlchemy/sqlalchemy/dialects/mysql/types.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/dialects/mysql/types.pyi @@ -32,7 +32,7 @@ class _StringType(sqltypes.String): **kw, ) -> None: ... -class _MatchType(sqltypes.Float, sqltypes.MatchType): # type: ignore[misc] +class _MatchType(sqltypes.Float, sqltypes.MatchType): # type: ignore # incompatible with base class def __init__(self, **kw) -> None: ... class NUMERIC(_NumericType, sqltypes.NUMERIC): diff --git a/stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/base.pyi b/stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/base.pyi index 047d14bdc..b3ba752bb 100644 --- a/stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/base.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/base.pyi @@ -96,7 +96,7 @@ PGUuid = UUID class TSVECTOR(sqltypes.TypeEngine): __visit_name__: str -class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): # type: ignore[misc] +class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): # type: ignore # base classes incompatible native_enum: bool create_type: Any def __init__(self, *enums, **kw) -> None: ... diff --git a/stubs/SQLAlchemy/sqlalchemy/engine/base.pyi b/stubs/SQLAlchemy/sqlalchemy/engine/base.pyi index 32640ff8c..2ef9be686 100644 --- a/stubs/SQLAlchemy/sqlalchemy/engine/base.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/engine/base.pyi @@ -1,18 +1,35 @@ -from typing import Any +from _typeshed import Self +from abc import abstractmethod +from collections.abc import Mapping +from types import TracebackType +from typing import Any, Callable, TypeVar, overload -from .. import log -from .interfaces import Connectable as Connectable, ExceptionContext +from ..dbapi import DBAPIConnection +from ..log import Identified, _EchoFlag, echo_property +from ..pool import Pool +from ..sql.compiler import Compiled +from ..sql.ddl import DDLElement +from ..sql.elements import ClauseElement +from ..sql.functions import FunctionElement +from ..sql.schema import DefaultGenerator +from .cursor import CursorResult +from .interfaces import Connectable as Connectable, Dialect, ExceptionContext +from .url import URL from .util import TransactionalContext +_T = TypeVar("_T") + +_Executable = ClauseElement | FunctionElement | DDLElement | DefaultGenerator | Compiled + class Connection(Connectable): - engine: Any - dialect: Any + engine: Engine + dialect: Dialect should_close_with_result: bool dispatch: Any def __init__( self, - engine, - connection: Any | None = ..., + engine: Engine, + connection: DBAPIConnection | None = ..., close_with_result: bool = ..., _branch_from: Any | None = ..., _execution_options: Any | None = ..., @@ -20,42 +37,54 @@ class Connection(Connectable): _has_events: Any | None = ..., _allow_revalidate: bool = ..., ) -> None: ... - def schema_for_object(self, obj): ... - def __enter__(self): ... - def __exit__(self, type_, value, traceback) -> None: ... + def schema_for_object(self, obj) -> str | None: ... + def __enter__(self: Self) -> Self: ... + def __exit__( + self, type_: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None + ) -> None: ... def execution_options(self, **opt): ... def get_execution_options(self): ... @property - def closed(self): ... + def closed(self) -> bool: ... @property - def invalidated(self): ... + def invalidated(self) -> bool: ... @property - def connection(self): ... + def connection(self) -> DBAPIConnection: ... def get_isolation_level(self): ... @property def default_isolation_level(self): ... @property def info(self): ... def connect(self, close_with_result: bool = ...): ... # type: ignore[override] - def invalidate(self, exception: Any | None = ...): ... + def invalidate(self, exception: Exception | None = ...) -> None: ... def detach(self) -> None: ... - def begin(self): ... - def begin_nested(self): ... - def begin_twophase(self, xid: Any | None = ...): ... + def begin(self) -> Transaction: ... + def begin_nested(self) -> Transaction | None: ... + def begin_twophase(self, xid: Any | None = ...) -> TwoPhaseTransaction: ... def recover_twophase(self): ... def rollback_prepared(self, xid, recover: bool = ...) -> None: ... def commit_prepared(self, xid, recover: bool = ...) -> None: ... - def in_transaction(self): ... - def in_nested_transaction(self): ... - def get_transaction(self): ... - def get_nested_transaction(self): ... + def in_transaction(self) -> bool: ... + def in_nested_transaction(self) -> bool: ... + def get_transaction(self) -> Transaction | None: ... + def get_nested_transaction(self) -> Transaction | None: ... def close(self) -> None: ... - def scalar(self, object_, *multiparams, **params): ... + @overload + def scalar(self, object_: _Executable, *multiparams: Mapping[str, Any], **params: Any) -> Any: ... + @overload + def scalar(self, object_: str, *multiparams: Any | tuple[Any, ...] | Mapping[str, Any], **params: Any) -> Any: ... def scalars(self, object_, *multiparams, **params): ... - def execute(self, statement, *multiparams, **params): ... - def exec_driver_sql(self, statement, parameters: Any | None = ..., execution_options: Any | None = ...): ... - def transaction(self, callable_, *args, **kwargs): ... - def run_callable(self, callable_, *args, **kwargs): ... + @overload # type: ignore[override] + def execute(self, statement: _Executable, *multiparams: Mapping[str, Any], **params) -> CursorResult: ... + @overload + def execute(self, statement: str, *multiparams: Any | tuple[Any, ...] | Mapping[str, Any], **params) -> CursorResult: ... + def exec_driver_sql(self, statement: str, parameters: Any | None = ..., execution_options: Any | None = ...): ... + # TODO: + # def transaction(self, callable_: Callable[Concatenate[Connection, _P], _T], *args: _P.args, **kwargs: _P.kwargs) -> _T: ... + def transaction(self, callable_: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: ... + # TODO: + # def run_callable(self, callable_: Callable[Concatenate[Connection, _P], _T], *args: _P.args, **kwargs: _P.kwargs) -> _T: ... + def run_callable(self, callable_: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: ... class ExceptionContextImpl(ExceptionContext): engine: Any @@ -82,88 +111,105 @@ class ExceptionContextImpl(ExceptionContext): ) -> None: ... class Transaction(TransactionalContext): - def __init__(self, connection) -> None: ... + def __init__(self, connection: Connection) -> None: ... @property - def is_valid(self): ... + def is_valid(self) -> bool: ... def close(self) -> None: ... def rollback(self) -> None: ... def commit(self) -> None: ... + # The following field are technically not defined on Transaction, but on + # all sub-classes. + @property + @abstractmethod + def connection(self) -> Connection: ... + @property + @abstractmethod + def is_active(self) -> bool: ... class MarkerTransaction(Transaction): - connection: Any - def __init__(self, connection) -> None: ... + connection: Connection @property - def is_active(self): ... + def is_active(self) -> bool: ... class RootTransaction(Transaction): - connection: Any + connection: Connection is_active: bool - def __init__(self, connection) -> None: ... class NestedTransaction(Transaction): - connection: Any + connection: Connection is_active: bool - def __init__(self, connection) -> None: ... class TwoPhaseTransaction(RootTransaction): xid: Any - def __init__(self, connection, xid) -> None: ... + def __init__(self, connection: Connection, xid) -> None: ... def prepare(self) -> None: ... -class Engine(Connectable, log.Identified): - pool: Any - url: Any - dialect: Any - logging_name: Any - echo: Any - hide_parameters: Any +class Engine(Connectable, Identified): + pool: Pool + url: URL + dialect: Dialect + logging_name: str # only exists if not None during initialization + echo: echo_property + hide_parameters: bool def __init__( self, - pool, - dialect, - url, - logging_name: Any | None = ..., - echo: Any | None = ..., + pool: Pool, + dialect: Dialect, + url: str | URL, + logging_name: str | None = ..., + echo: _EchoFlag = ..., query_cache_size: int = ..., - execution_options: Any | None = ..., + execution_options: Mapping[str, Any] | None = ..., hide_parameters: bool = ..., ) -> None: ... @property - def engine(self): ... + def engine(self) -> Engine: ... def clear_compiled_cache(self) -> None: ... def update_execution_options(self, **opt) -> None: ... def execution_options(self, **opt): ... def get_execution_options(self): ... @property - def name(self): ... + def name(self) -> str: ... @property def driver(self): ... def dispose(self) -> None: ... class _trans_ctx: - conn: Any - transaction: Any - close_with_result: Any - def __init__(self, conn, transaction, close_with_result) -> None: ... - def __enter__(self): ... - def __exit__(self, type_, value, traceback) -> None: ... - def begin(self, close_with_result: bool = ...): ... - def transaction(self, callable_, *args, **kwargs): ... - def run_callable(self, callable_, *args, **kwargs): ... - def execute(self, statement, *multiparams, **params): ... - def scalar(self, statement, *multiparams, **params): ... - def connect(self, close_with_result: bool = ...): ... # type: ignore[override] - def table_names(self, schema: Any | None = ..., connection: Any | None = ...): ... - def has_table(self, table_name, schema: Any | None = ...): ... - def raw_connection(self, _connection: Any | None = ...): ... + conn: Connection + transaction: Transaction + close_with_result: bool + def __init__(self, conn: Connection, transaction: Transaction, close_with_result: bool) -> None: ... + def __enter__(self) -> Connection: ... + def __exit__( + self, type_: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None + ) -> None: ... + def begin(self, close_with_result: bool = ...) -> _trans_ctx: ... + # TODO: + # def transaction(self, callable_: Callable[Concatenate[Connection, _P], _T], *args: _P.args, **kwargs: _P.kwargs) -> _T | None: ... + def transaction(self, callable_: Callable[..., _T], *args: Any, **kwargs: Any) -> _T | None: ... + # TODO: + # def run_callable(self, callable_: Callable[Concatenate[Connection, _P], _T], *args: _P.args, **kwargs: _P.kwargs) -> _T: ... + def run_callable(self, callable_: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: ... + @overload # type: ignore[override] + def execute(self, statement: _Executable, *multiparams: Mapping[str, Any], **params: Any) -> CursorResult: ... + @overload + def execute(self, statement: str, *multiparams: Any | tuple[Any, ...] | Mapping[str, Any], **params: Any) -> CursorResult: ... + @overload # type: ignore[override] + def scalar(self, statement: _Executable, *multiparams: Mapping[str, Any], **params: Any) -> Any: ... + @overload + def scalar(self, statement: str, *multiparams: Any | tuple[Any, ...] | Mapping[str, Any], **params: Any) -> Any: ... + def connect(self, close_with_result: bool = ...) -> Connection: ... # type: ignore[override] + def table_names(self, schema: Any | None = ..., connection: Connection | None = ...): ... + def has_table(self, table_name: str, schema: Any | None = ...) -> bool: ... + def raw_connection(self, _connection: Connection | None = ...) -> DBAPIConnection: ... class OptionEngineMixin: - url: Any - dialect: Any - logging_name: Any - echo: Any - hide_parameters: Any + url: URL + dialect: Dialect + logging_name: str + echo: bool + hide_parameters: bool dispatch: Any def __init__(self, proxied, execution_options) -> None: ... - pool: Any + pool: Pool -class OptionEngine(OptionEngineMixin, Engine): ... +class OptionEngine(OptionEngineMixin, Engine): ... # type: ignore[misc] diff --git a/stubs/SQLAlchemy/sqlalchemy/engine/create.pyi b/stubs/SQLAlchemy/sqlalchemy/engine/create.pyi index b3f86c1fb..40c6b29fd 100644 --- a/stubs/SQLAlchemy/sqlalchemy/engine/create.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/engine/create.pyi @@ -1,2 +1,21 @@ -def create_engine(url, **kwargs): ... -def engine_from_config(configuration, prefix: str = ..., **kwargs): ... +from collections.abc import Mapping +from typing import Any, overload +from typing_extensions import Literal + +from ..future.engine import Engine as FutureEngine +from .base import Engine +from .mock import MockConnection +from .url import URL + +# Further kwargs are forwarded to the engine, dialect, or pool. +@overload +def create_engine(url: URL | str, *, strategy: Literal["mock"], **kwargs) -> MockConnection: ... # type: ignore[misc] +@overload +def create_engine( + url: URL | str, *, module: Any | None = ..., enable_from_linting: bool = ..., future: Literal[True], **kwargs +) -> FutureEngine: ... +@overload +def create_engine( + url: URL | str, *, module: Any | None = ..., enable_from_linting: bool = ..., future: Literal[False] = ..., **kwargs +) -> Engine: ... +def engine_from_config(configuration: Mapping[str, Any], prefix: str = ..., **kwargs) -> Engine: ... diff --git a/stubs/SQLAlchemy/sqlalchemy/engine/default.pyi b/stubs/SQLAlchemy/sqlalchemy/engine/default.pyi index 4f65bd5a2..9e521bbb6 100644 --- a/stubs/SQLAlchemy/sqlalchemy/engine/default.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/engine/default.pyi @@ -12,7 +12,7 @@ CACHING_DISABLED: Any NO_CACHE_KEY: Any NO_DIALECT_SUPPORT: Any -class DefaultDialect(interfaces.Dialect): +class DefaultDialect(interfaces.Dialect): # type: ignore[misc] execution_ctx_cls: ClassVar[type[interfaces.ExecutionContext]] statement_compiler: Any ddl_compiler: Any @@ -146,7 +146,7 @@ class _StrDateTime(_RendersLiteral, sqltypes.DateTime): ... class _StrDate(_RendersLiteral, sqltypes.Date): ... class _StrTime(_RendersLiteral, sqltypes.Time): ... -class StrCompileDialect(DefaultDialect): +class StrCompileDialect(DefaultDialect): # type: ignore[misc] statement_compiler: Any ddl_compiler: Any type_compiler: Any diff --git a/stubs/SQLAlchemy/sqlalchemy/engine/interfaces.pyi b/stubs/SQLAlchemy/sqlalchemy/engine/interfaces.pyi index 24c6ae829..cdf04f617 100644 --- a/stubs/SQLAlchemy/sqlalchemy/engine/interfaces.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/engine/interfaces.pyi @@ -1,13 +1,58 @@ -from typing import Any +from abc import abstractmethod +from collections.abc import Callable, Collection, Mapping +from typing import Any, ClassVar, overload -from ..sql.compiler import Compiled as Compiled, TypeCompiler as TypeCompiler +from ..dbapi import DBAPIConnection, DBAPICursor +from ..exc import StatementError +from ..sql.compiler import Compiled as Compiled, IdentifierPreparer, TypeCompiler as TypeCompiler +from ..sql.ddl import DDLElement +from ..sql.elements import ClauseElement +from ..sql.functions import FunctionElement +from ..sql.schema import DefaultGenerator +from .base import Connection, Engine +from .cursor import CursorResult +from .url import URL class Dialect: + # Sub-classes are required to have the following attributes: + name: str + driver: str + positional: bool + paramstyle: str + encoding: str + statement_compiler: Compiled + ddl_compiler: Compiled + server_version_info: tuple[Any, ...] + # Only available on supporting dialects: + # default_schema_name: str + execution_ctx_cls: ClassVar[type[ExecutionContext]] + execute_sequence_format: type[tuple[Any] | list[Any]] + preparer: IdentifierPreparer + supports_alter: bool + max_identifier_length: int + supports_sane_rowcount: bool + supports_sane_multi_rowcount: bool + preexecute_autoincrement_sequences: bool + implicit_returning: bool + colspecs: dict[Any, Any] + supports_default_values: bool + supports_sequences: bool + sequences_optional: bool + supports_native_enum: bool + supports_native_boolean: bool + dbapi_exception_translation_map: dict[Any, Any] + supports_statement_cache: bool - def create_connect_args(self, url) -> None: ... + @abstractmethod + def create_connect_args(self, url: URL) -> None: ... + def initialize(self, connection) -> None: ... + def on_connect_url(self, url) -> Callable[[DBAPIConnection], object] | None: ... + def on_connect(self) -> Callable[[DBAPIConnection], object] | None: ... + # The following methods all raise NotImplementedError, but not all + # dialects implement all methods, which is why they can't be marked + # as abstract. @classmethod def type_descriptor(cls, typeobj) -> None: ... - def initialize(self, connection) -> None: ... def get_columns(self, connection, table_name, schema: Any | None = ..., **kw) -> None: ... def get_pk_constraint(self, connection, table_name, schema: Any | None = ..., **kw) -> None: ... def get_foreign_keys(self, connection, table_name, schema: Any | None = ..., **kw) -> None: ... @@ -44,9 +89,7 @@ class Dialect: def do_execute(self, cursor, statement, parameters, context: Any | None = ...) -> None: ... def do_execute_no_params(self, cursor, statement, parameters, context: Any | None = ...) -> None: ... def is_disconnect(self, e, connection, cursor) -> None: ... - def connect(self, *cargs, **cparams) -> None: ... - def on_connect_url(self, url): ... - def on_connect(self) -> None: ... + def connect(self, *cargs, **cparams) -> DBAPIConnection: ... def reset_isolation_level(self, dbapi_conn) -> None: ... def set_isolation_level(self, dbapi_conn, level) -> None: ... def get_isolation_level(self, dbapi_conn) -> None: ... @@ -60,8 +103,8 @@ class Dialect: def get_driver_connection(self, connection) -> None: ... class CreateEnginePlugin: - url: Any - def __init__(self, url, kwargs) -> None: ... + url: URL + def __init__(self, url: URL, kwargs) -> None: ... def update_url(self, url) -> None: ... def handle_dialect_kwargs(self, dialect_cls, dialect_args) -> None: ... def handle_pool_kwargs(self, pool_cls, pool_args) -> None: ... @@ -79,22 +122,44 @@ class ExecutionContext: def get_rowcount(self) -> None: ... class Connectable: - def connect(self, **kwargs) -> None: ... - engine: Any - def execute(self, object_, *multiparams, **params) -> None: ... - def scalar(self, object_, *multiparams, **params) -> None: ... + @abstractmethod + def connect(self, **kwargs) -> Connection: ... + @property + def engine(self) -> Engine | None: ... + @abstractmethod + @overload + def execute( + self, + object_: ClauseElement | FunctionElement | DDLElement | DefaultGenerator | Compiled, + *multiparams: Mapping[str, Any], + **params: Any, + ) -> CursorResult: ... + @abstractmethod + @overload + def execute(self, object_: str, *multiparams: Any | tuple[Any, ...] | Mapping[str, Any], **params: Any) -> CursorResult: ... + @abstractmethod + @overload + def scalar( + self, + object_: ClauseElement | FunctionElement | DDLElement | DefaultGenerator | Compiled, + *multiparams: Mapping[str, Any], + **params: Any, + ) -> Any: ... + @abstractmethod + @overload + def scalar(self, object_: str, *multiparams: Any | tuple[Any, ...] | Mapping[str, Any], **params: Any) -> Any: ... class ExceptionContext: - connection: Any - engine: Any - cursor: Any - statement: Any - parameters: Any - original_exception: Any - sqlalchemy_exception: Any - chained_exception: Any - execution_context: Any - is_disconnect: Any + connection: Connection | None + engine: Engine | None + cursor: DBAPICursor | None + statement: str | None + parameters: Collection[Any] | None + original_exception: BaseException | None + sqlalchemy_exception: StatementError | None + chained_exception: BaseException | None + execution_context: ExecutionContext | None + is_disconnect: bool | None invalidate_pool_on_disconnect: bool class AdaptedConnection: diff --git a/stubs/SQLAlchemy/sqlalchemy/engine/mock.pyi b/stubs/SQLAlchemy/sqlalchemy/engine/mock.pyi index 47bb1b0fc..dc685760d 100644 --- a/stubs/SQLAlchemy/sqlalchemy/engine/mock.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/engine/mock.pyi @@ -1,18 +1,32 @@ -from typing import Any +from _typeshed import Self +from abc import abstractmethod +from collections.abc import Mapping +from typing import Any, overload -from . import base +from .base import _Executable +from .cursor import CursorResult +from .interfaces import Connectable, Dialect +from .url import URL -class MockConnection(base.Connectable): - def __init__(self, dialect, execute) -> None: ... - engine: Any - dialect: Any - name: Any +class MockConnection(Connectable): + def __init__(self, dialect: Dialect, execute) -> None: ... + @property + def engine(self: Self) -> Self: ... # type: ignore[override] + @property + def dialect(self) -> Dialect: ... + @property + def name(self) -> str: ... def schema_for_object(self, obj): ... def connect(self, **kwargs): ... def execution_options(self, **kw): ... def compiler(self, statement, parameters, **kwargs): ... def create(self, entity, **kwargs) -> None: ... def drop(self, entity, **kwargs) -> None: ... - def execute(self, object_, *multiparams, **params) -> None: ... + @abstractmethod + @overload + def execute(self, object_: _Executable, *multiparams: Mapping[str, Any], **params: Any) -> CursorResult: ... + @abstractmethod + @overload + def execute(self, object_: str, *multiparams: Any | tuple[Any, ...] | Mapping[str, Any], **params: Any) -> CursorResult: ... -def create_mock_engine(url, executor, **kw): ... +def create_mock_engine(url: URL | str, executor, **kw) -> MockConnection: ... diff --git a/stubs/SQLAlchemy/sqlalchemy/engine/result.pyi b/stubs/SQLAlchemy/sqlalchemy/engine/result.pyi index 027a86dd8..8aab5bc9d 100644 --- a/stubs/SQLAlchemy/sqlalchemy/engine/result.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/engine/result.pyi @@ -1,4 +1,4 @@ -from collections.abc import KeysView +from collections.abc import Generator, KeysView from typing import Any from ..sql.base import InPlaceGenerative @@ -36,26 +36,26 @@ class _WithKeys: class Result(_WithKeys, ResultInternal): def __init__(self, cursor_metadata) -> None: ... def close(self) -> None: ... - def yield_per(self, num) -> None: ... + def yield_per(self, num: int) -> None: ... def unique(self, strategy: Any | None = ...) -> None: ... def columns(self, *col_expressions): ... - def scalars(self, index: int = ...): ... - def mappings(self): ... + def scalars(self, index: int = ...) -> ScalarResult: ... + def mappings(self) -> MappingResult: ... def __iter__(self): ... def __next__(self): ... - def partitions(self, size: Any | None = ...) -> None: ... - def fetchall(self): ... - def fetchone(self): ... - def fetchmany(self, size: Any | None = ...): ... - def all(self): ... + def partitions(self, size: int | None = ...) -> Generator[Any, None, None]: ... + def fetchall(self) -> list[Any]: ... + def fetchone(self) -> Any | None: ... + def fetchmany(self, size: int | None = ...) -> list[Any]: ... + def all(self) -> list[Any]: ... def first(self): ... def one_or_none(self): ... def scalar_one(self): ... def scalar_one_or_none(self): ... def one(self): ... def scalar(self): ... - def freeze(self): ... - def merge(self, *others): ... + def freeze(self) -> FrozenResult: ... + def merge(self, *others) -> MergedResult: ... class FilterResult(ResultInternal): ... @@ -100,7 +100,7 @@ class IteratorResult(Result): raw: Any def __init__(self, cursor_metadata, iterator, raw: Any | None = ..., _source_supports_scalars: bool = ...) -> None: ... -def null_result(): ... +def null_result() -> IteratorResult: ... class ChunkedIteratorResult(IteratorResult): chunks: Any diff --git a/stubs/SQLAlchemy/sqlalchemy/engine/url.pyi b/stubs/SQLAlchemy/sqlalchemy/engine/url.pyi index 3ea48464f..0e0f535fd 100644 --- a/stubs/SQLAlchemy/sqlalchemy/engine/url.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/engine/url.pyi @@ -1,46 +1,60 @@ -from typing import Any +from _typeshed import Self, SupportsItems +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, NamedTuple -from ..util import memoized_property +from ..util import immutabledict +from .interfaces import Dialect -class URL: - def __new__(cls, *arg, **kw): ... +# stub-only helper class +class _URLTuple(NamedTuple): + drivername: str + username: str | None + password: str | object | None # object that produces a password when called with str() + host: str | None + port: int | None + database: str | None + query: immutabledict[str, str | tuple[str, ...]] + +_Query = Mapping[str, str | Sequence[str]] | Sequence[tuple[str, str | Sequence[str]]] + +class URL(_URLTuple): @classmethod def create( cls, - drivername, - username: Any | None = ..., - password: Any | None = ..., - host: Any | None = ..., - port: Any | None = ..., - database: Any | None = ..., - query=..., - ): ... + drivername: str, + username: str | None = ..., + password: str | object | None = ..., # object that produces a password when called with str() + host: str | None = ..., + port: int | None = ..., + database: str | None = ..., + query: _Query | None = ..., + ) -> URL: ... def set( - self, - drivername: Any | None = ..., - username: Any | None = ..., - password: Any | None = ..., - host: Any | None = ..., - port: Any | None = ..., - database: Any | None = ..., - query: Any | None = ..., - ): ... - def update_query_string(self, query_string, append: bool = ...): ... - def update_query_pairs(self, key_value_pairs, append: bool = ...): ... - def update_query_dict(self, query_parameters, append: bool = ...): ... - def difference_update_query(self, names): ... - @memoized_property - def normalized_query(self): ... - def __to_string__(self, hide_password: bool = ...): ... - def render_as_string(self, hide_password: bool = ...): ... - def __copy__(self): ... - def __deepcopy__(self, memo): ... - def __hash__(self): ... - def __eq__(self, other): ... - def __ne__(self, other): ... - def get_backend_name(self): ... - def get_driver_name(self): ... - def get_dialect(self): ... - def translate_connect_args(self, names: Any | None = ..., **kw): ... + self: Self, + drivername: str | None = ..., + username: str | None = ..., + password: str | object | None = ..., + host: str | None = ..., + port: int | None = ..., + database: str | None = ..., + query: _Query | None = ..., + ) -> Self: ... + def update_query_string(self: Self, query_string: str, append: bool = ...) -> Self: ... + def update_query_pairs(self: Self, key_value_pairs: Iterable[tuple[str, str]], append: bool = ...) -> Self: ... + def update_query_dict(self: Self, query_parameters: SupportsItems[str, str | Sequence[str]], append: bool = ...) -> Self: ... + def difference_update_query(self, names: Iterable[str]) -> URL: ... + @property + def normalized_query(self) -> immutabledict[str, tuple[str, ...]]: ... + def __to_string__(self, hide_password: bool = ...) -> str: ... + def render_as_string(self, hide_password: bool = ...) -> str: ... + def __copy__(self: Self) -> Self: ... + def __deepcopy__(self: Self, memo: object) -> Self: ... + def __hash__(self) -> int: ... + def __eq__(self, other: object) -> bool: ... + def __ne__(self, other: object) -> bool: ... + def get_backend_name(self) -> str: ... + def get_driver_name(self) -> str: ... + def get_dialect(self) -> type[Dialect]: ... + def translate_connect_args(self, names: list[str] | None = ..., **kw: str) -> dict[str, Any]: ... -def make_url(name_or_url): ... +def make_url(name_or_url: str | URL) -> URL: ... diff --git a/stubs/SQLAlchemy/sqlalchemy/engine/util.pyi b/stubs/SQLAlchemy/sqlalchemy/engine/util.pyi index 32aea3338..f711f0c83 100644 --- a/stubs/SQLAlchemy/sqlalchemy/engine/util.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/engine/util.pyi @@ -1,5 +1,12 @@ -def connection_memoize(key): ... +from _typeshed import Self +from collections.abc import Callable +from types import TracebackType +from typing import Any + +def connection_memoize(key: str) -> Callable[..., Any]: ... class TransactionalContext: - def __enter__(self): ... - def __exit__(self, type_, value, traceback) -> None: ... + def __enter__(self: Self) -> Self: ... + def __exit__( + self, type_: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None + ) -> None: ... diff --git a/stubs/SQLAlchemy/sqlalchemy/future/engine.pyi b/stubs/SQLAlchemy/sqlalchemy/future/engine.pyi index 1af2b62ea..bd4ff1dfa 100644 --- a/stubs/SQLAlchemy/sqlalchemy/future/engine.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/future/engine.pyi @@ -1,11 +1,19 @@ -from typing import Any +from typing import Any, overload +from typing_extensions import Literal from ..engine import Connection as _LegacyConnection, Engine as _LegacyEngine from ..engine.base import OptionEngineMixin +from ..engine.mock import MockConnection +from ..engine.url import URL NO_OPTIONS: Any -def create_engine(*arg, **kw): ... +@overload +def create_engine(url: URL | str, *, strategy: Literal["mock"], **kwargs) -> MockConnection: ... # type: ignore[misc] +@overload +def create_engine( + url: URL | str, *, module: Any | None = ..., enable_from_linting: bool = ..., future: bool = ..., **kwargs +) -> Engine: ... class Connection(_LegacyConnection): def begin(self): ... @@ -26,4 +34,4 @@ class Engine(_LegacyEngine): def begin(self) -> None: ... # type: ignore[override] def connect(self): ... -class OptionEngine(OptionEngineMixin, Engine): ... +class OptionEngine(OptionEngineMixin, Engine): ... # type: ignore[misc] diff --git a/stubs/SQLAlchemy/sqlalchemy/util/_collections.pyi b/stubs/SQLAlchemy/sqlalchemy/util/_collections.pyi index 49d369ea1..3642158c1 100644 --- a/stubs/SQLAlchemy/sqlalchemy/util/_collections.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/util/_collections.pyi @@ -1,12 +1,14 @@ import collections.abc import sys from _typeshed import Self, SupportsKeysAndGetItem -from collections.abc import Callable, Iterable, Iterator +from collections.abc import Callable, Iterable, Iterator, Mapping from typing import Any, Generic, NoReturn, TypeVar, overload from ..cimmutabledict import immutabledict as immutabledict from ..sql.elements import ColumnElement +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") _S = TypeVar("_S") _T = TypeVar("_T") @@ -19,9 +21,12 @@ class ImmutableContainer: def __setitem__(self, *arg: object, **kw: object) -> NoReturn: ... def __setattr__(self, *arg: object, **kw: object) -> NoReturn: ... -def coerce_to_immutabledict(d) -> immutabledict: ... +@overload +def coerce_to_immutabledict(d: None) -> immutabledict[Any, Any]: ... +@overload +def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]: ... -EMPTY_DICT: immutabledict +EMPTY_DICT: immutabledict[Any, Any] class FacadeDict(ImmutableContainer, dict[Any, Any]): clear: Any