From 44bcf5eed0f882c20d5d09d510f191aef9e918a6 Mon Sep 17 00:00:00 2001 From: Sebastian Rittau Date: Wed, 16 Feb 2022 15:46:11 +0100 Subject: [PATCH] Various SQLalchemy fixes and improvements (#7237) --- .../sqlalchemy/ext/horizontal_shard.pyi | 6 +- stubs/SQLAlchemy/sqlalchemy/orm/dynamic.pyi | 6 +- stubs/SQLAlchemy/sqlalchemy/orm/query.pyi | 83 ++++++++++--------- stubs/SQLAlchemy/sqlalchemy/orm/session.pyi | 11 ++- stubs/SQLAlchemy/sqlalchemy/sql/base.pyi | 5 +- 5 files changed, 63 insertions(+), 48 deletions(-) diff --git a/stubs/SQLAlchemy/sqlalchemy/ext/horizontal_shard.pyi b/stubs/SQLAlchemy/sqlalchemy/ext/horizontal_shard.pyi index 0ccd17ef2..b217b31b9 100644 --- a/stubs/SQLAlchemy/sqlalchemy/ext/horizontal_shard.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/ext/horizontal_shard.pyi @@ -1,9 +1,11 @@ -from typing import Any +from typing import Any, Generic, TypeVar from ..orm.query import Query from ..orm.session import Session -class ShardedQuery(Query): +_T = TypeVar("_T") + +class ShardedQuery(Query[_T], Generic[_T]): id_chooser: Any query_chooser: Any execute_chooser: Any diff --git a/stubs/SQLAlchemy/sqlalchemy/orm/dynamic.pyi b/stubs/SQLAlchemy/sqlalchemy/orm/dynamic.pyi index e730ae41b..801fe6aa9 100644 --- a/stubs/SQLAlchemy/sqlalchemy/orm/dynamic.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/orm/dynamic.pyi @@ -1,8 +1,10 @@ -from typing import Any +from typing import Any, Generic, TypeVar from . import attributes, strategies from .query import Query +_T = TypeVar("_T") + class DynaLoader(strategies.AbstractRelationshipLoader): logger: Any is_class_level: bool @@ -63,7 +65,7 @@ class AppenderMixin: def append(self, item) -> None: ... def remove(self, item) -> None: ... -class AppenderQuery(AppenderMixin, Query): ... +class AppenderQuery(AppenderMixin, Query[_T], Generic[_T]): ... def mixin_user_query(cls): ... diff --git a/stubs/SQLAlchemy/sqlalchemy/orm/query.pyi b/stubs/SQLAlchemy/sqlalchemy/orm/query.pyi index cbaf8c8bf..ea594a3eb 100644 --- a/stubs/SQLAlchemy/sqlalchemy/orm/query.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/orm/query.pyi @@ -1,4 +1,5 @@ -from typing import Any +from _typeshed import Self +from typing import Any, Generic, TypeVar from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import Executable @@ -9,14 +10,16 @@ from .util import aliased as aliased __all__ = ["Query", "QueryContext", "aliased"] -class Query(_SelectFromElements, SupportsCloneAnnotations, HasPrefixes, HasSuffixes, HasHints, Executable): +_T = TypeVar("_T") + +class Query(_SelectFromElements, SupportsCloneAnnotations, HasPrefixes, HasSuffixes, HasHints, Executable, Generic[_T]): logger: Any load_options: Any session: Any def __init__(self, entities, session: Any | None = ...) -> None: ... @property def statement(self): ... - def subquery(self, name: Any | None = ..., with_labels: bool = ..., reduce_columns: bool = ...): ... + def subquery(self, name: str | None = ..., with_labels: bool = ..., reduce_columns: bool = ...): ... def cte(self, name: Any | None = ..., recursive: bool = ..., nesting: bool = ...): ... def label(self, name): ... def as_scalar(self): ... @@ -24,79 +27,79 @@ class Query(_SelectFromElements, SupportsCloneAnnotations, HasPrefixes, HasSuffi @property def selectable(self): ... def __clause_element__(self): ... - def only_return_tuples(self, value) -> None: ... + def only_return_tuples(self: Self, value) -> Self: ... @property def is_single_entity(self): ... - def enable_eagerloads(self, value) -> None: ... + def enable_eagerloads(self: Self, value) -> Self: ... def with_labels(self): ... apply_labels: Any @property def get_label_style(self): ... def set_label_style(self, style): ... - def enable_assertions(self, value) -> None: ... + def enable_assertions(self: Self, value) -> Self: ... @property def whereclause(self): ... - def with_polymorphic(self, cls_or_mappers, selectable: Any | None = ..., polymorphic_on: Any | None = ...) -> None: ... - def yield_per(self, count) -> None: ... + def with_polymorphic(self: Self, cls_or_mappers, selectable: Any | None = ..., polymorphic_on: Any | None = ...) -> Self: ... + def yield_per(self: Self, count) -> Self: ... def get(self, ident): ... @property def lazy_loaded_from(self): ... - def correlate(self, *fromclauses) -> None: ... - def autoflush(self, setting) -> None: ... - def populate_existing(self) -> None: ... + def correlate(self: Self, *fromclauses) -> Self: ... + def autoflush(self: Self, setting) -> Self: ... + def populate_existing(self: Self) -> Self: ... def with_parent(self, instance, property: Any | None = ..., from_entity: Any | None = ...): ... - def add_entity(self, entity, alias: Any | None = ...) -> None: ... - def with_session(self, session) -> None: ... + def add_entity(self: Self, entity, alias: Any | None = ...) -> Self: ... + def with_session(self: Self, session) -> Self: ... def from_self(self, *entities): ... def values(self, *columns): ... def value(self, column): ... - def with_entities(self, *entities) -> None: ... - def add_columns(self, *column) -> None: ... + def with_entities(self: Self, *entities) -> Self: ... + def add_columns(self: Self, *column) -> Self: ... def add_column(self, column): ... - def options(self, *args) -> None: ... + def options(self: Self, *args) -> Self: ... def with_transformation(self, fn): ... def get_execution_options(self): ... - def execution_options(self, **kwargs) -> None: ... + def execution_options(self: Self, **kwargs) -> Self: ... def with_for_update( - self, read: bool = ..., nowait: bool = ..., of: Any | None = ..., skip_locked: bool = ..., key_share: bool = ... - ) -> None: ... - def params(self, *args, **kwargs) -> None: ... + self: Self, read: bool = ..., nowait: bool = ..., of: Any | None = ..., skip_locked: bool = ..., key_share: bool = ... + ) -> Self: ... + def params(self: Self, *args, **kwargs) -> Self: ... def where(self, *criterion): ... - def filter(self, *criterion) -> None: ... - def filter_by(self, **kwargs): ... - def order_by(self, *clauses) -> None: ... - def group_by(self, *clauses) -> None: ... - def having(self, criterion) -> None: ... + def filter(self: Self, *criterion) -> Self: ... + def filter_by(self: Self, **kwargs) -> Self: ... + def order_by(self: Self, *clauses) -> Self: ... + def group_by(self: Self, *clauses) -> Self: ... + def having(self: Self, criterion) -> Self: ... def union(self, *q): ... def union_all(self, *q): ... def intersect(self, *q): ... def intersect_all(self, *q): ... def except_(self, *q): ... def except_all(self, *q): ... - def join(self, target, *props, **kwargs) -> None: ... - def outerjoin(self, target, *props, **kwargs): ... - def reset_joinpoint(self) -> None: ... - def select_from(self, *from_obj) -> None: ... - def select_entity_from(self, from_obj) -> None: ... + def join(self: Self, target, *props, **kwargs) -> Self: ... + def outerjoin(self: Self, target, *props, **kwargs) -> Self: ... + def reset_joinpoint(self: Self) -> Self: ... + def select_from(self: Self, *from_obj) -> Self: ... + def select_entity_from(self: Self, from_obj) -> Self: ... def __getitem__(self, item): ... - def slice(self, start, stop) -> None: ... - def limit(self, limit) -> None: ... - def offset(self, offset) -> None: ... - def distinct(self, *expr) -> None: ... - def all(self): ... - def from_statement(self, statement) -> None: ... - def first(self): ... + def slice(self: Self, start, stop) -> Self: ... + def limit(self: Self, limit) -> Self: ... + def offset(self: Self, offset) -> Self: ... + def distinct(self: Self, *expr) -> Self: ... + def all(self) -> list[_T]: ... + def from_statement(self: Self, statement) -> Self: ... + def first(self) -> _T | None: ... def one_or_none(self): ... def one(self): ... - def scalar(self): ... + def scalar(self) -> Any: ... # type: ignore[override] def __iter__(self): ... @property def column_descriptions(self): ... def instances(self, result_proxy, context: Any | None = ...): ... def merge_result(self, iterator, load: bool = ...): ... def exists(self): ... - def count(self): ... - def delete(self, synchronize_session: str = ...): ... + def count(self) -> int: ... + def delete(self, synchronize_session: str = ...) -> int: ... def update(self, values, synchronize_session: str = ..., update_args: Any | None = ...): ... class FromStatement(GroupedElement, SelectBase, Executable): diff --git a/stubs/SQLAlchemy/sqlalchemy/orm/session.pyi b/stubs/SQLAlchemy/sqlalchemy/orm/session.pyi index bb7477e07..bf9538685 100644 --- a/stubs/SQLAlchemy/sqlalchemy/orm/session.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/orm/session.pyi @@ -1,5 +1,7 @@ +from collections.abc import Mapping from typing import Any +from ..engine.base import Connection from ..engine.util import TransactionalContext from ..util import MemoizedSlots, memoized_property @@ -111,9 +113,14 @@ class Session(_SessionClassMethods): def rollback(self) -> None: ... def commit(self) -> None: ... def prepare(self) -> None: ... + # TODO: bind_arguments could use a TypedDict def connection( - self, bind_arguments: Any | None = ..., close_with_result: bool = ..., execution_options: Any | None = ..., **kw - ): ... + self, + bind_arguments: Mapping[str, Any] | None = ..., + close_with_result: bool = ..., + execution_options: Mapping[str, Any] | None = ..., + **kw: Any, + ) -> Connection: ... def execute( self, statement, diff --git a/stubs/SQLAlchemy/sqlalchemy/sql/base.pyi b/stubs/SQLAlchemy/sqlalchemy/sql/base.pyi index 80036a800..fb006721d 100644 --- a/stubs/SQLAlchemy/sqlalchemy/sql/base.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/sql/base.pyi @@ -1,3 +1,4 @@ +from _typeshed import Self from collections.abc import MutableMapping from typing import Any @@ -96,8 +97,8 @@ class Executable(roles.StatementRole, Generative): is_text: bool is_delete: bool is_dml: bool - def options(self, *options) -> None: ... - def execution_options(self, **kw) -> None: ... + def options(self: Self, *options) -> Self: ... + def execution_options(self: Self, **kw) -> Self: ... def get_execution_options(self): ... def execute(self, *multiparams, **params): ... def scalar(self, *multiparams, **params): ...