diff --git a/stubs/SQLAlchemy/sqlalchemy/dialects/mssql/information_schema.pyi b/stubs/SQLAlchemy/sqlalchemy/dialects/mssql/information_schema.pyi index ec1a2504c..0a97a197f 100644 --- a/stubs/SQLAlchemy/sqlalchemy/dialects/mssql/information_schema.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/dialects/mssql/information_schema.pyi @@ -11,7 +11,7 @@ class CoerceUnicode(TypeDecorator): def process_bind_param(self, value, dialect): ... def bind_expression(self, bindvalue): ... -class _cast_on_2005(expression.ColumnElement): +class _cast_on_2005(expression.ColumnElement[Any]): bindvalue: Any def __init__(self, bindvalue) -> None: ... diff --git a/stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/array.pyi b/stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/array.pyi index f8c96a4a5..ff186142a 100644 --- a/stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/array.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/array.pyi @@ -7,7 +7,7 @@ from ...sql import expression def Any(other, arrexpr, operator=...): ... def All(other, arrexpr, operator=...): ... -class array(expression.ClauseList, expression.ColumnElement): +class array(expression.ClauseList, expression.ColumnElement[_Any]): __visit_name__: str stringify_dialect: str inherit_cache: bool diff --git a/stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/ext.pyi b/stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/ext.pyi index b205c953c..66fd97542 100644 --- a/stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/ext.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/ext.pyi @@ -3,7 +3,7 @@ from typing import Any from ...sql import expression from ...sql.schema import ColumnCollectionConstraint -class aggregate_order_by(expression.ColumnElement): +class aggregate_order_by(expression.ColumnElement[Any]): __visit_name__: str stringify_dialect: str inherit_cache: bool diff --git a/stubs/SQLAlchemy/sqlalchemy/ext/associationproxy.pyi b/stubs/SQLAlchemy/sqlalchemy/ext/associationproxy.pyi index dd2e82904..bc6e384dd 100644 --- a/stubs/SQLAlchemy/sqlalchemy/ext/associationproxy.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/ext/associationproxy.pyi @@ -76,7 +76,7 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): def __eq__(self, obj): ... def __ne__(self, obj): ... -class ColumnAssociationProxyInstance(ColumnOperators, AssociationProxyInstance): +class ColumnAssociationProxyInstance(ColumnOperators[Any], AssociationProxyInstance): def __eq__(self, other): ... def operate(self, op, *other, **kwargs): ... diff --git a/stubs/SQLAlchemy/sqlalchemy/orm/session.pyi b/stubs/SQLAlchemy/sqlalchemy/orm/session.pyi index bf9538685..18aa5a589 100644 --- a/stubs/SQLAlchemy/sqlalchemy/orm/session.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/orm/session.pyi @@ -1,9 +1,14 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, TypeVar, overload from ..engine.base import Connection from ..engine.util import TransactionalContext +from ..sql.elements import ColumnElement +from ..sql.schema import Table from ..util import MemoizedSlots, memoized_property +from .query import Query + +_T = TypeVar("_T") class _SessionClassMethods: @classmethod @@ -146,7 +151,14 @@ class Session(_SessionClassMethods): _sa_skip_events: Any | None = ..., _sa_skip_for_implicit_returning: bool = ..., ): ... - def query(self, *entities, **kwargs): ... + @overload + def query(self, entities: Table, **kwargs: Any) -> Query[Any]: ... + @overload + def query(self, entities: ColumnElement[_T], **kwargs: Any) -> Query[tuple[_T]]: ... # type: ignore[misc] + @overload + def query(self, *entities: ColumnElement[_T], **kwargs: Any) -> Query[tuple[_T, ...]]: ... + @overload + def query(self, *entities: type[_T], **kwargs: Any) -> Query[_T]: ... @property def no_autoflush(self) -> None: ... def refresh(self, instance, attribute_names: Any | None = ..., with_for_update: Any | None = ...) -> None: ... diff --git a/stubs/SQLAlchemy/sqlalchemy/sql/compiler.pyi b/stubs/SQLAlchemy/sqlalchemy/sql/compiler.pyi index eb96eab18..66cac257c 100644 --- a/stubs/SQLAlchemy/sqlalchemy/sql/compiler.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/sql/compiler.pyi @@ -66,7 +66,7 @@ class TypeCompiler: def process(self, type_, **kw): ... def visit_unsupported_compilation(self, element, err, **kw) -> None: ... -class _CompileLabel(elements.ColumnElement): +class _CompileLabel(elements.ColumnElement[Any]): __visit_name__: str element: Any name: Any diff --git a/stubs/SQLAlchemy/sqlalchemy/sql/crud.pyi b/stubs/SQLAlchemy/sqlalchemy/sql/crud.pyi index 0e263aca4..a39d81741 100644 --- a/stubs/SQLAlchemy/sqlalchemy/sql/crud.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/sql/crud.pyi @@ -4,7 +4,7 @@ from . import elements REQUIRED: Any -class _multiparam_column(elements.ColumnElement): +class _multiparam_column(elements.ColumnElement[Any]): index: Any key: Any original: Any diff --git a/stubs/SQLAlchemy/sqlalchemy/sql/elements.pyi b/stubs/SQLAlchemy/sqlalchemy/sql/elements.pyi index ef7345e32..4d3f85442 100644 --- a/stubs/SQLAlchemy/sqlalchemy/sql/elements.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/sql/elements.pyi @@ -1,4 +1,5 @@ -from typing import Any +from typing import Any, Generic, TypeVar +from typing_extensions import Literal from .. import util from ..util import HasMemoized, memoized_property @@ -8,6 +9,8 @@ from .base import Executable, Immutable, SingletonConstant from .traversals import HasCopyInternals, MemoizedHasCacheKey from .visitors import Traversible +_T = TypeVar("_T") + def collate(expression, collation): ... def between(expr, lower_bound, upper_bound, symmetric: bool = ...): ... def literal(value, type_: Any | None = ...): ... @@ -44,8 +47,9 @@ class ColumnElement( roles.DMLColumnRole, roles.DDLConstraintColumnRole, roles.DDLExpressionRole, - operators.ColumnOperators, + operators.ColumnOperators[_T], ClauseElement, + Generic[_T], ): __visit_name__: str primary_key: bool @@ -77,7 +81,7 @@ class WrapsColumnExpression: @property def wrapped_column_expression(self) -> None: ... -class BindParameter(roles.InElementRole, ColumnElement): +class BindParameter(roles.InElementRole, ColumnElement[_T], Generic[_T]): __visit_name__: str inherit_cache: bool key: Any @@ -141,17 +145,17 @@ class TextClause( def comparator(self): ... def self_group(self, against: Any | None = ...): ... -class Null(SingletonConstant, roles.ConstExprRole, ColumnElement): +class Null(SingletonConstant, roles.ConstExprRole, ColumnElement[None]): __visit_name__: str @memoized_property def type(self): ... -class False_(SingletonConstant, roles.ConstExprRole, ColumnElement): +class False_(SingletonConstant, roles.ConstExprRole, ColumnElement[Literal[False]]): __visit_name__: str @memoized_property def type(self): ... -class True_(SingletonConstant, roles.ConstExprRole, ColumnElement): +class True_(SingletonConstant, roles.ConstExprRole, ColumnElement[Literal[True]]): __visit_name__: str @memoized_property def type(self): ... @@ -168,7 +172,7 @@ class ClauseList(roles.InElementRole, roles.OrderByRole, roles.ColumnsClauseRole def append(self, clause) -> None: ... def self_group(self, against: Any | None = ...): ... -class BooleanClauseList(ClauseList, ColumnElement): +class BooleanClauseList(ClauseList, ColumnElement[Any]): __visit_name__: str inherit_cache: bool def __init__(self, *arg, **kw) -> None: ... @@ -181,13 +185,13 @@ class BooleanClauseList(ClauseList, ColumnElement): and_: Any or_: Any -class Tuple(ClauseList, ColumnElement): +class Tuple(ClauseList, ColumnElement[Any]): __visit_name__: str type: Any def __init__(self, *clauses, **kw) -> None: ... def self_group(self, against: Any | None = ...): ... -class Case(ColumnElement): +class Case(ColumnElement[Any]): __visit_name__: str value: Any type: Any @@ -197,7 +201,7 @@ class Case(ColumnElement): def literal_column(text, type_: Any | None = ...): ... -class Cast(WrapsColumnExpression, ColumnElement): +class Cast(WrapsColumnExpression, ColumnElement[Any]): __visit_name__: str type: Any clause: Any @@ -206,7 +210,7 @@ class Cast(WrapsColumnExpression, ColumnElement): @property def wrapped_column_expression(self): ... -class TypeCoerce(WrapsColumnExpression, ColumnElement): +class TypeCoerce(WrapsColumnExpression, ColumnElement[Any]): __visit_name__: str type: Any clause: Any @@ -217,24 +221,24 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement): def wrapped_column_expression(self): ... def self_group(self, against: Any | None = ...): ... -class Extract(ColumnElement): +class Extract(ColumnElement[Any]): __visit_name__: str type: Any field: Any expr: Any def __init__(self, field, expr, **kwargs) -> None: ... -class _label_reference(ColumnElement): +class _label_reference(ColumnElement[Any]): __visit_name__: str element: Any def __init__(self, element) -> None: ... -class _textual_label_reference(ColumnElement): +class _textual_label_reference(ColumnElement[Any]): __visit_name__: str element: Any def __init__(self, element) -> None: ... -class UnaryExpression(ColumnElement): +class UnaryExpression(ColumnElement[Any]): __visit_name__: str operator: Any modifier: Any @@ -269,7 +273,7 @@ class AsBoolean(WrapsColumnExpression, UnaryExpression): def wrapped_column_expression(self): ... def self_group(self, against: Any | None = ...): ... -class BinaryExpression(ColumnElement): +class BinaryExpression(ColumnElement[Any]): __visit_name__: str left: Any right: Any @@ -286,7 +290,7 @@ class BinaryExpression(ColumnElement): def is_comparison(self): ... def self_group(self, against: Any | None = ...): ... -class Slice(ColumnElement): +class Slice(ColumnElement[Any]): __visit_name__: str start: Any stop: Any @@ -302,7 +306,7 @@ class GroupedElement(ClauseElement): __visit_name__: str def self_group(self, against: Any | None = ...): ... -class Grouping(GroupedElement, ColumnElement): +class Grouping(GroupedElement, ColumnElement[Any]): element: Any type: Any def __init__(self, element) -> None: ... @@ -311,7 +315,7 @@ class Grouping(GroupedElement, ColumnElement): RANGE_UNBOUNDED: Any RANGE_CURRENT: Any -class Over(ColumnElement): +class Over(ColumnElement[Any]): __visit_name__: str order_by: Any partition_by: Any @@ -330,7 +334,7 @@ class Over(ColumnElement): @memoized_property def type(self): ... -class WithinGroup(ColumnElement): +class WithinGroup(ColumnElement[Any]): __visit_name__: str order_by: Any element: Any @@ -342,7 +346,7 @@ class WithinGroup(ColumnElement): @memoized_property def type(self): ... -class FunctionFilter(ColumnElement): +class FunctionFilter(ColumnElement[Any]): __visit_name__: str criterion: Any func: Any @@ -355,7 +359,7 @@ class FunctionFilter(ColumnElement): @memoized_property def type(self): ... -class Label(roles.LabeledColumnExprRole, ColumnElement): +class Label(roles.LabeledColumnExprRole, ColumnElement[Any]): __visit_name__: str name: Any key: Any @@ -371,7 +375,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): @property def foreign_keys(self): ... -class NamedColumn(ColumnElement): +class NamedColumn(ColumnElement[Any]): is_literal: bool table: Any @memoized_property @@ -399,7 +403,7 @@ class TableValuedColumn(NamedColumn): type: Any def __init__(self, scalar_alias, type_) -> None: ... -class CollationClause(ColumnElement): +class CollationClause(ColumnElement[Any]): __visit_name__: str collation: Any def __init__(self, collation) -> None: ... diff --git a/stubs/SQLAlchemy/sqlalchemy/sql/functions.pyi b/stubs/SQLAlchemy/sqlalchemy/sql/functions.pyi index fe0c3ef16..873b414b4 100644 --- a/stubs/SQLAlchemy/sqlalchemy/sql/functions.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/sql/functions.pyi @@ -8,7 +8,7 @@ from .visitors import TraversibleType def register_function(identifier, fn, package: str = ...) -> None: ... -class FunctionElement(Executable, ColumnElement, FromClause, Generative): # type: ignore[misc] +class FunctionElement(Executable, ColumnElement[Any], FromClause, Generative): # type: ignore[misc] packagenames: Any clause_expr: Any def __init__(self, *clauses, **kwargs) -> None: ... diff --git a/stubs/SQLAlchemy/sqlalchemy/sql/lambdas.pyi b/stubs/SQLAlchemy/sqlalchemy/sql/lambdas.pyi index 6f65d391b..525111d0d 100644 --- a/stubs/SQLAlchemy/sqlalchemy/sql/lambdas.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/sql/lambdas.pyi @@ -94,7 +94,7 @@ class AnalyzedFunction: closure_bindparams: Any def __init__(self, analyzed_code, lambda_element, apply_propagate_attrs, fn) -> None: ... -class PyWrapper(ColumnOperators): +class PyWrapper(ColumnOperators[Any]): fn: Any track_bound_values: Any def __init__( diff --git a/stubs/SQLAlchemy/sqlalchemy/sql/operators.pyi b/stubs/SQLAlchemy/sqlalchemy/sql/operators.pyi index 07289f632..4e1042a21 100644 --- a/stubs/SQLAlchemy/sqlalchemy/sql/operators.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/sql/operators.pyi @@ -1,5 +1,8 @@ +from collections.abc import Container, Iterable from operator import truediv -from typing import Any +from typing import Any, Generic, TypeVar + +_T = TypeVar("_T") div = truediv @@ -33,40 +36,40 @@ class custom_op: def __hash__(self): ... def __call__(self, left, right, **kw): ... -class ColumnOperators(Operators): +class ColumnOperators(Operators, Generic[_T]): timetuple: Any - def __lt__(self, other): ... - def __le__(self, other): ... + def __lt__(self, other: _T | ColumnOperators[_T] | None): ... + def __le__(self, other: _T | ColumnOperators[_T] | None): ... __hash__: Any - def __eq__(self, other): ... - def __ne__(self, other): ... + def __eq__(self, other: _T | ColumnOperators[_T] | None): ... # type: ignore[override] + def __ne__(self, other: _T | ColumnOperators[_T] | None): ... # type: ignore[override] def is_distinct_from(self, other): ... def is_not_distinct_from(self, other): ... - isnot_distinct_from: Any - def __gt__(self, other): ... - def __ge__(self, other): ... + isnot_distinct_from = is_not_distinct_from + def __gt__(self, other: _T | ColumnOperators[_T] | None): ... + def __ge__(self, other: _T | ColumnOperators[_T] | None): ... def __neg__(self): ... def __contains__(self, other): ... - def __getitem__(self, index): ... + def __getitem__(self, index: int): ... def __lshift__(self, other): ... def __rshift__(self, other): ... - def concat(self, other): ... - def like(self, other, escape: Any | None = ...): ... - def ilike(self, other, escape: Any | None = ...): ... - def in_(self, other): ... - def not_in(self, other): ... - notin_: Any - def not_like(self, other, escape: Any | None = ...): ... - notlike: Any - def not_ilike(self, other, escape: Any | None = ...): ... - notilike: Any - def is_(self, other): ... - def is_not(self, other): ... - isnot: Any - def startswith(self, other, **kwargs): ... - def endswith(self, other, **kwargs): ... - def contains(self, other, **kwargs): ... - def match(self, other, **kwargs): ... + def concat(self, other: _T | ColumnOperators[_T] | None): ... + def like(self, other: _T, escape: str | None = ...): ... + def ilike(self, other: _T, escape: str | None = ...): ... + def in_(self, other: Container[_T] | Iterable[_T]): ... + def not_in(self, other: Container[_T] | Iterable[_T]): ... + notin_ = not_in + def not_like(self, other: _T, escape: str | None = ...): ... + notlike = not_like + def not_ilike(self, other: _T, escape: str | None = ...): ... + notilike = not_ilike + def is_(self, other: _T): ... + def is_not(self, other: _T): ... + isnot = is_not + def startswith(self, other: str, **kwargs): ... + def endswith(self, other: str, **kwargs): ... + def contains(self, other: str, **kwargs): ... + def match(self, other: str, **kwargs): ... def regexp_match(self, pattern, flags: Any | None = ...): ... def regexp_replace(self, pattern, replacement, flags: Any | None = ...): ... def desc(self): ... diff --git a/stubs/SQLAlchemy/sqlalchemy/sql/selectable.pyi b/stubs/SQLAlchemy/sqlalchemy/sql/selectable.pyi index 84ac9413a..c1a9d2e18 100644 --- a/stubs/SQLAlchemy/sqlalchemy/sql/selectable.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/sql/selectable.pyi @@ -18,7 +18,7 @@ from .elements import ( literal_column as literal_column, ) -class _OffsetLimitParam(BindParameter): +class _OffsetLimitParam(BindParameter[Any]): inherit_cache: bool def subquery(alias, *args, **kwargs): ... diff --git a/stubs/SQLAlchemy/sqlalchemy/util/_collections.pyi b/stubs/SQLAlchemy/sqlalchemy/util/_collections.pyi index 3642158c1..e764687d1 100644 --- a/stubs/SQLAlchemy/sqlalchemy/util/_collections.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/util/_collections.pyi @@ -178,7 +178,7 @@ class WeakPopulateDict(dict[Any, Any]): column_set = set column_dict = dict -ordered_column_set = OrderedSet[ColumnElement] +ordered_column_set = OrderedSet[ColumnElement[Any]] def unique_list(seq: Iterable[_T], hashfunc: Callable[[_T], Any] | None = ...) -> list[_T]: ...