diff --git a/stubs/SQLAlchemy/sqlalchemy/engine/base.pyi b/stubs/SQLAlchemy/sqlalchemy/engine/base.pyi index 2a43370b8..a4f152de5 100644 --- a/stubs/SQLAlchemy/sqlalchemy/engine/base.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/engine/base.pyi @@ -4,7 +4,7 @@ from abc import abstractmethod from collections.abc import Callable, Mapping from types import TracebackType from typing import Any, TypeVar, overload -from typing_extensions import TypeAlias +from typing_extensions import Concatenate, ParamSpec, TypeAlias from ..log import Identified, _EchoFlag, echo_property from ..pool import Pool @@ -19,6 +19,7 @@ from .url import URL from .util import TransactionalContext _T = TypeVar("_T") +_P = ParamSpec("_P") _Executable: TypeAlias = ClauseElement | FunctionElement | DDLElement | DefaultGenerator | Compiled @@ -80,12 +81,8 @@ class Connection(Connectable): @overload def execute(self, statement: str, *multiparams: Any | tuple[Any, ...] | Mapping[str, Any], **params) -> CursorResult: ... def exec_driver_sql(self, statement: str, parameters: Any | None = ..., execution_options: Any | None = ...): ... - # TODO: - # def transaction(self, callable_: Callable[Concatenate[Connection, _P], _T], *args: _P.args, **kwargs: _P.kwargs) -> _T: ... - def transaction(self, callable_: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: ... - # TODO: - # def run_callable(self, callable_: Callable[Concatenate[Connection, _P], _T], *args: _P.args, **kwargs: _P.kwargs) -> _T: ... - def run_callable(self, callable_: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: ... + def transaction(self, callable_: Callable[Concatenate[Connection, _P], _T], *args: _P.args, **kwargs: _P.kwargs) -> _T: ... + def run_callable(self, callable_: Callable[Concatenate[Connection, _P], _T], *args: _P.args, **kwargs: _P.kwargs) -> _T: ... class ExceptionContextImpl(ExceptionContext): engine: Any