Various SQLalchemy type improvements (#7238)

* Make ColumnOperators and ColumnElement generic
* Overload Session.query() return type
* Annotate ColumnOperators methods
This commit is contained in:
Sebastian Rittau
2022-02-17 04:03:48 +01:00
committed by GitHub
parent 63c20e3ce7
commit 4a0dabda1b
13 changed files with 82 additions and 63 deletions

View File

@@ -11,7 +11,7 @@ class CoerceUnicode(TypeDecorator):
def process_bind_param(self, value, dialect): ...
def bind_expression(self, bindvalue): ...
class _cast_on_2005(expression.ColumnElement):
class _cast_on_2005(expression.ColumnElement[Any]):
bindvalue: Any
def __init__(self, bindvalue) -> None: ...

View File

@@ -7,7 +7,7 @@ from ...sql import expression
def Any(other, arrexpr, operator=...): ...
def All(other, arrexpr, operator=...): ...
class array(expression.ClauseList, expression.ColumnElement):
class array(expression.ClauseList, expression.ColumnElement[_Any]):
__visit_name__: str
stringify_dialect: str
inherit_cache: bool

View File

@@ -3,7 +3,7 @@ from typing import Any
from ...sql import expression
from ...sql.schema import ColumnCollectionConstraint
class aggregate_order_by(expression.ColumnElement):
class aggregate_order_by(expression.ColumnElement[Any]):
__visit_name__: str
stringify_dialect: str
inherit_cache: bool

View File

@@ -76,7 +76,7 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance):
def __eq__(self, obj): ...
def __ne__(self, obj): ...
class ColumnAssociationProxyInstance(ColumnOperators, AssociationProxyInstance):
class ColumnAssociationProxyInstance(ColumnOperators[Any], AssociationProxyInstance):
def __eq__(self, other): ...
def operate(self, op, *other, **kwargs): ...

View File

@@ -1,9 +1,14 @@
from collections.abc import Mapping
from typing import Any
from typing import Any, TypeVar, overload
from ..engine.base import Connection
from ..engine.util import TransactionalContext
from ..sql.elements import ColumnElement
from ..sql.schema import Table
from ..util import MemoizedSlots, memoized_property
from .query import Query
_T = TypeVar("_T")
class _SessionClassMethods:
@classmethod
@@ -146,7 +151,14 @@ class Session(_SessionClassMethods):
_sa_skip_events: Any | None = ...,
_sa_skip_for_implicit_returning: bool = ...,
): ...
def query(self, *entities, **kwargs): ...
@overload
def query(self, entities: Table, **kwargs: Any) -> Query[Any]: ...
@overload
def query(self, entities: ColumnElement[_T], **kwargs: Any) -> Query[tuple[_T]]: ... # type: ignore[misc]
@overload
def query(self, *entities: ColumnElement[_T], **kwargs: Any) -> Query[tuple[_T, ...]]: ...
@overload
def query(self, *entities: type[_T], **kwargs: Any) -> Query[_T]: ...
@property
def no_autoflush(self) -> None: ...
def refresh(self, instance, attribute_names: Any | None = ..., with_for_update: Any | None = ...) -> None: ...

View File

@@ -66,7 +66,7 @@ class TypeCompiler:
def process(self, type_, **kw): ...
def visit_unsupported_compilation(self, element, err, **kw) -> None: ...
class _CompileLabel(elements.ColumnElement):
class _CompileLabel(elements.ColumnElement[Any]):
__visit_name__: str
element: Any
name: Any

View File

@@ -4,7 +4,7 @@ from . import elements
REQUIRED: Any
class _multiparam_column(elements.ColumnElement):
class _multiparam_column(elements.ColumnElement[Any]):
index: Any
key: Any
original: Any

View File

@@ -1,4 +1,5 @@
from typing import Any
from typing import Any, Generic, TypeVar
from typing_extensions import Literal
from .. import util
from ..util import HasMemoized, memoized_property
@@ -8,6 +9,8 @@ from .base import Executable, Immutable, SingletonConstant
from .traversals import HasCopyInternals, MemoizedHasCacheKey
from .visitors import Traversible
_T = TypeVar("_T")
def collate(expression, collation): ...
def between(expr, lower_bound, upper_bound, symmetric: bool = ...): ...
def literal(value, type_: Any | None = ...): ...
@@ -44,8 +47,9 @@ class ColumnElement(
roles.DMLColumnRole,
roles.DDLConstraintColumnRole,
roles.DDLExpressionRole,
operators.ColumnOperators,
operators.ColumnOperators[_T],
ClauseElement,
Generic[_T],
):
__visit_name__: str
primary_key: bool
@@ -77,7 +81,7 @@ class WrapsColumnExpression:
@property
def wrapped_column_expression(self) -> None: ...
class BindParameter(roles.InElementRole, ColumnElement):
class BindParameter(roles.InElementRole, ColumnElement[_T], Generic[_T]):
__visit_name__: str
inherit_cache: bool
key: Any
@@ -141,17 +145,17 @@ class TextClause(
def comparator(self): ...
def self_group(self, against: Any | None = ...): ...
class Null(SingletonConstant, roles.ConstExprRole, ColumnElement):
class Null(SingletonConstant, roles.ConstExprRole, ColumnElement[None]):
__visit_name__: str
@memoized_property
def type(self): ...
class False_(SingletonConstant, roles.ConstExprRole, ColumnElement):
class False_(SingletonConstant, roles.ConstExprRole, ColumnElement[Literal[False]]):
__visit_name__: str
@memoized_property
def type(self): ...
class True_(SingletonConstant, roles.ConstExprRole, ColumnElement):
class True_(SingletonConstant, roles.ConstExprRole, ColumnElement[Literal[True]]):
__visit_name__: str
@memoized_property
def type(self): ...
@@ -168,7 +172,7 @@ class ClauseList(roles.InElementRole, roles.OrderByRole, roles.ColumnsClauseRole
def append(self, clause) -> None: ...
def self_group(self, against: Any | None = ...): ...
class BooleanClauseList(ClauseList, ColumnElement):
class BooleanClauseList(ClauseList, ColumnElement[Any]):
__visit_name__: str
inherit_cache: bool
def __init__(self, *arg, **kw) -> None: ...
@@ -181,13 +185,13 @@ class BooleanClauseList(ClauseList, ColumnElement):
and_: Any
or_: Any
class Tuple(ClauseList, ColumnElement):
class Tuple(ClauseList, ColumnElement[Any]):
__visit_name__: str
type: Any
def __init__(self, *clauses, **kw) -> None: ...
def self_group(self, against: Any | None = ...): ...
class Case(ColumnElement):
class Case(ColumnElement[Any]):
__visit_name__: str
value: Any
type: Any
@@ -197,7 +201,7 @@ class Case(ColumnElement):
def literal_column(text, type_: Any | None = ...): ...
class Cast(WrapsColumnExpression, ColumnElement):
class Cast(WrapsColumnExpression, ColumnElement[Any]):
__visit_name__: str
type: Any
clause: Any
@@ -206,7 +210,7 @@ class Cast(WrapsColumnExpression, ColumnElement):
@property
def wrapped_column_expression(self): ...
class TypeCoerce(WrapsColumnExpression, ColumnElement):
class TypeCoerce(WrapsColumnExpression, ColumnElement[Any]):
__visit_name__: str
type: Any
clause: Any
@@ -217,24 +221,24 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement):
def wrapped_column_expression(self): ...
def self_group(self, against: Any | None = ...): ...
class Extract(ColumnElement):
class Extract(ColumnElement[Any]):
__visit_name__: str
type: Any
field: Any
expr: Any
def __init__(self, field, expr, **kwargs) -> None: ...
class _label_reference(ColumnElement):
class _label_reference(ColumnElement[Any]):
__visit_name__: str
element: Any
def __init__(self, element) -> None: ...
class _textual_label_reference(ColumnElement):
class _textual_label_reference(ColumnElement[Any]):
__visit_name__: str
element: Any
def __init__(self, element) -> None: ...
class UnaryExpression(ColumnElement):
class UnaryExpression(ColumnElement[Any]):
__visit_name__: str
operator: Any
modifier: Any
@@ -269,7 +273,7 @@ class AsBoolean(WrapsColumnExpression, UnaryExpression):
def wrapped_column_expression(self): ...
def self_group(self, against: Any | None = ...): ...
class BinaryExpression(ColumnElement):
class BinaryExpression(ColumnElement[Any]):
__visit_name__: str
left: Any
right: Any
@@ -286,7 +290,7 @@ class BinaryExpression(ColumnElement):
def is_comparison(self): ...
def self_group(self, against: Any | None = ...): ...
class Slice(ColumnElement):
class Slice(ColumnElement[Any]):
__visit_name__: str
start: Any
stop: Any
@@ -302,7 +306,7 @@ class GroupedElement(ClauseElement):
__visit_name__: str
def self_group(self, against: Any | None = ...): ...
class Grouping(GroupedElement, ColumnElement):
class Grouping(GroupedElement, ColumnElement[Any]):
element: Any
type: Any
def __init__(self, element) -> None: ...
@@ -311,7 +315,7 @@ class Grouping(GroupedElement, ColumnElement):
RANGE_UNBOUNDED: Any
RANGE_CURRENT: Any
class Over(ColumnElement):
class Over(ColumnElement[Any]):
__visit_name__: str
order_by: Any
partition_by: Any
@@ -330,7 +334,7 @@ class Over(ColumnElement):
@memoized_property
def type(self): ...
class WithinGroup(ColumnElement):
class WithinGroup(ColumnElement[Any]):
__visit_name__: str
order_by: Any
element: Any
@@ -342,7 +346,7 @@ class WithinGroup(ColumnElement):
@memoized_property
def type(self): ...
class FunctionFilter(ColumnElement):
class FunctionFilter(ColumnElement[Any]):
__visit_name__: str
criterion: Any
func: Any
@@ -355,7 +359,7 @@ class FunctionFilter(ColumnElement):
@memoized_property
def type(self): ...
class Label(roles.LabeledColumnExprRole, ColumnElement):
class Label(roles.LabeledColumnExprRole, ColumnElement[Any]):
__visit_name__: str
name: Any
key: Any
@@ -371,7 +375,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
@property
def foreign_keys(self): ...
class NamedColumn(ColumnElement):
class NamedColumn(ColumnElement[Any]):
is_literal: bool
table: Any
@memoized_property
@@ -399,7 +403,7 @@ class TableValuedColumn(NamedColumn):
type: Any
def __init__(self, scalar_alias, type_) -> None: ...
class CollationClause(ColumnElement):
class CollationClause(ColumnElement[Any]):
__visit_name__: str
collation: Any
def __init__(self, collation) -> None: ...

View File

@@ -8,7 +8,7 @@ from .visitors import TraversibleType
def register_function(identifier, fn, package: str = ...) -> None: ...
class FunctionElement(Executable, ColumnElement, FromClause, Generative): # type: ignore[misc]
class FunctionElement(Executable, ColumnElement[Any], FromClause, Generative): # type: ignore[misc]
packagenames: Any
clause_expr: Any
def __init__(self, *clauses, **kwargs) -> None: ...

View File

@@ -94,7 +94,7 @@ class AnalyzedFunction:
closure_bindparams: Any
def __init__(self, analyzed_code, lambda_element, apply_propagate_attrs, fn) -> None: ...
class PyWrapper(ColumnOperators):
class PyWrapper(ColumnOperators[Any]):
fn: Any
track_bound_values: Any
def __init__(

View File

@@ -1,5 +1,8 @@
from collections.abc import Container, Iterable
from operator import truediv
from typing import Any
from typing import Any, Generic, TypeVar
_T = TypeVar("_T")
div = truediv
@@ -33,40 +36,40 @@ class custom_op:
def __hash__(self): ...
def __call__(self, left, right, **kw): ...
class ColumnOperators(Operators):
class ColumnOperators(Operators, Generic[_T]):
timetuple: Any
def __lt__(self, other): ...
def __le__(self, other): ...
def __lt__(self, other: _T | ColumnOperators[_T] | None): ...
def __le__(self, other: _T | ColumnOperators[_T] | None): ...
__hash__: Any
def __eq__(self, other): ...
def __ne__(self, other): ...
def __eq__(self, other: _T | ColumnOperators[_T] | None): ... # type: ignore[override]
def __ne__(self, other: _T | ColumnOperators[_T] | None): ... # type: ignore[override]
def is_distinct_from(self, other): ...
def is_not_distinct_from(self, other): ...
isnot_distinct_from: Any
def __gt__(self, other): ...
def __ge__(self, other): ...
isnot_distinct_from = is_not_distinct_from
def __gt__(self, other: _T | ColumnOperators[_T] | None): ...
def __ge__(self, other: _T | ColumnOperators[_T] | None): ...
def __neg__(self): ...
def __contains__(self, other): ...
def __getitem__(self, index): ...
def __getitem__(self, index: int): ...
def __lshift__(self, other): ...
def __rshift__(self, other): ...
def concat(self, other): ...
def like(self, other, escape: Any | None = ...): ...
def ilike(self, other, escape: Any | None = ...): ...
def in_(self, other): ...
def not_in(self, other): ...
notin_: Any
def not_like(self, other, escape: Any | None = ...): ...
notlike: Any
def not_ilike(self, other, escape: Any | None = ...): ...
notilike: Any
def is_(self, other): ...
def is_not(self, other): ...
isnot: Any
def startswith(self, other, **kwargs): ...
def endswith(self, other, **kwargs): ...
def contains(self, other, **kwargs): ...
def match(self, other, **kwargs): ...
def concat(self, other: _T | ColumnOperators[_T] | None): ...
def like(self, other: _T, escape: str | None = ...): ...
def ilike(self, other: _T, escape: str | None = ...): ...
def in_(self, other: Container[_T] | Iterable[_T]): ...
def not_in(self, other: Container[_T] | Iterable[_T]): ...
notin_ = not_in
def not_like(self, other: _T, escape: str | None = ...): ...
notlike = not_like
def not_ilike(self, other: _T, escape: str | None = ...): ...
notilike = not_ilike
def is_(self, other: _T): ...
def is_not(self, other: _T): ...
isnot = is_not
def startswith(self, other: str, **kwargs): ...
def endswith(self, other: str, **kwargs): ...
def contains(self, other: str, **kwargs): ...
def match(self, other: str, **kwargs): ...
def regexp_match(self, pattern, flags: Any | None = ...): ...
def regexp_replace(self, pattern, replacement, flags: Any | None = ...): ...
def desc(self): ...

View File

@@ -18,7 +18,7 @@ from .elements import (
literal_column as literal_column,
)
class _OffsetLimitParam(BindParameter):
class _OffsetLimitParam(BindParameter[Any]):
inherit_cache: bool
def subquery(alias, *args, **kwargs): ...

View File

@@ -178,7 +178,7 @@ class WeakPopulateDict(dict[Any, Any]):
column_set = set
column_dict = dict
ordered_column_set = OrderedSet[ColumnElement]
ordered_column_set = OrderedSet[ColumnElement[Any]]
def unique_list(seq: Iterable[_T], hashfunc: Callable[[_T], Any] | None = ...) -> list[_T]: ...