From 426ce065b224fdc0ac53674b6d077079da95a5c2 Mon Sep 17 00:00:00 2001 From: Sebastian Rittau Date: Wed, 22 Dec 2021 16:49:20 +0100 Subject: [PATCH] Various small SQLAlchemy type improvements (#6623) --- .../SQLAlchemy/@tests/stubtest_allowlist.txt | 3 + stubs/SQLAlchemy/sqlalchemy/exc.pyi | 49 +++----- .../sqlalchemy/ext/declarative/__init__.pyi | 14 ++- stubs/SQLAlchemy/sqlalchemy/log.pyi | 29 +++-- .../sqlalchemy/util/_collections.pyi | 119 ++++++++++-------- tests/pytype_exclude_list.txt | 27 ++++ 6 files changed, 138 insertions(+), 103 deletions(-) diff --git a/stubs/SQLAlchemy/@tests/stubtest_allowlist.txt b/stubs/SQLAlchemy/@tests/stubtest_allowlist.txt index 5f354e544..9804f4f66 100644 --- a/stubs/SQLAlchemy/@tests/stubtest_allowlist.txt +++ b/stubs/SQLAlchemy/@tests/stubtest_allowlist.txt @@ -7,6 +7,9 @@ sqlalchemy.util.langhelpers._symbol.__new__ sqlalchemy.util._collections.* sqlalchemy.util.compat.* +# forwards arguments to another function +sqlalchemy.ext.declarative.as_declarative + # stdlib re-exports with stubtest issues sqlalchemy.orm.collections.InstrumentedList.* sqlalchemy.orm.collections.InstrumentedSet.* diff --git a/stubs/SQLAlchemy/sqlalchemy/exc.pyi b/stubs/SQLAlchemy/sqlalchemy/exc.pyi index 5f190cbaa..b8775e9ad 100644 --- a/stubs/SQLAlchemy/sqlalchemy/exc.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/exc.pyi @@ -1,11 +1,11 @@ -from typing import Any +from typing import Any, ClassVar class HasDescriptionCode: - code: Any - def __init__(self, *arg, **kw) -> None: ... + code: str | None + def __init__(self, *arg: Any, code: str | None = ..., **kw: Any) -> None: ... class SQLAlchemyError(HasDescriptionCode, Exception): - def __unicode__(self): ... + def __unicode__(self) -> str: ... class ArgumentError(SQLAlchemyError): ... @@ -30,8 +30,8 @@ class UnsupportedCompilationError(CompileError): code: str compiler: Any element_type: Any - message: Any - def __init__(self, compiler, element_type, message: Any | None = ...) -> None: ... + message: str | None + def __init__(self, compiler, element_type, message: str | None = ...) -> None: ... def __reduce__(self): ... class IdentifierError(SQLAlchemyError): ... @@ -114,41 +114,26 @@ class DBAPIError(StatementError): ismulti: Any | None = ..., ) -> None: ... -class InterfaceError(DBAPIError): - code: str - -class DatabaseError(DBAPIError): - code: str - -class DataError(DatabaseError): - code: str - -class OperationalError(DatabaseError): - code: str - -class IntegrityError(DatabaseError): - code: str - -class InternalError(DatabaseError): - code: str - -class ProgrammingError(DatabaseError): - code: str - -class NotSupportedError(DatabaseError): - code: str +class InterfaceError(DBAPIError): ... +class DatabaseError(DBAPIError): ... +class DataError(DatabaseError): ... +class OperationalError(DatabaseError): ... +class IntegrityError(DatabaseError): ... +class InternalError(DatabaseError): ... +class ProgrammingError(DatabaseError): ... +class NotSupportedError(DatabaseError): ... class SADeprecationWarning(HasDescriptionCode, DeprecationWarning): - deprecated_since: Any + deprecated_since: ClassVar[str | None] class Base20DeprecationWarning(SADeprecationWarning): - deprecated_since: str + deprecated_since: ClassVar[str] class LegacyAPIWarning(Base20DeprecationWarning): ... class RemovedIn20Warning(Base20DeprecationWarning): ... class MovedIn20Warning(RemovedIn20Warning): ... class SAPendingDeprecationWarning(PendingDeprecationWarning): - deprecated_since: Any + deprecated_since: ClassVar[str | None] class SAWarning(HasDescriptionCode, RuntimeWarning): ... diff --git a/stubs/SQLAlchemy/sqlalchemy/ext/declarative/__init__.pyi b/stubs/SQLAlchemy/sqlalchemy/ext/declarative/__init__.pyi index 488bd549d..7eddf8603 100644 --- a/stubs/SQLAlchemy/sqlalchemy/ext/declarative/__init__.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/ext/declarative/__init__.pyi @@ -1,4 +1,11 @@ -from ...orm.decl_api import DeclarativeMeta as DeclarativeMeta, declared_attr as declared_attr +from ...orm.decl_api import ( + DeclarativeMeta as DeclarativeMeta, + as_declarative as as_declarative, + declarative_base as declarative_base, + declared_attr as declared_attr, + has_inherited_table as has_inherited_table, + synonym_for as synonym_for, +) from .extensions import ( AbstractConcreteBase as AbstractConcreteBase, ConcreteBase as ConcreteBase, @@ -18,8 +25,3 @@ __all__ = [ "DeclarativeMeta", "DeferredReflection", ] - -def declarative_base(*arg, **kw): ... -def as_declarative(*arg, **kw): ... -def has_inherited_table(*arg, **kw): ... -def synonym_for(*arg, **kw): ... diff --git a/stubs/SQLAlchemy/sqlalchemy/log.pyi b/stubs/SQLAlchemy/sqlalchemy/log.pyi index 515fd8601..2f1115228 100644 --- a/stubs/SQLAlchemy/sqlalchemy/log.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/log.pyi @@ -1,20 +1,26 @@ -from typing import Any +from _typeshed import Self +from logging import Logger +from typing import Any, TypeVar, overload +from typing_extensions import Literal + +_ClsT = TypeVar("_ClsT", bound=type) +_EchoFlag = bool | Literal["debug"] | None rootlogger: Any -def class_logger(cls): ... +def class_logger(cls: _ClsT) -> _ClsT: ... class Identified: - logging_name: Any + logging_name: str | None class InstanceLogger: - echo: Any - logger: Any - def __init__(self, echo, name) -> None: ... + echo: _EchoFlag + logger: Logger + def __init__(self, echo: _EchoFlag, name: str | None) -> None: ... def debug(self, msg, *args, **kwargs) -> None: ... def info(self, msg, *args, **kwargs) -> None: ... def warning(self, msg, *args, **kwargs) -> None: ... - warn: Any + warn = warning def error(self, msg, *args, **kwargs) -> None: ... def exception(self, msg, *args, **kwargs) -> None: ... def critical(self, msg, *args, **kwargs) -> None: ... @@ -22,9 +28,12 @@ class InstanceLogger: def isEnabledFor(self, level): ... def getEffectiveLevel(self): ... -def instance_logger(instance, echoflag: Any | None = ...) -> None: ... +def instance_logger(instance: Identified, echoflag: _EchoFlag = ...) -> None: ... class echo_property: __doc__: str - def __get__(self, instance, owner): ... - def __set__(self, instance, value) -> None: ... + @overload + def __get__(self: Self, instance: None, owner: object) -> Self: ... + @overload + def __get__(self, instance: Identified, owner: object) -> _EchoFlag: ... + def __set__(self, instance: Identified, value: _EchoFlag) -> None: ... diff --git a/stubs/SQLAlchemy/sqlalchemy/util/_collections.pyi b/stubs/SQLAlchemy/sqlalchemy/util/_collections.pyi index 08b0fb8af..49d369ea1 100644 --- a/stubs/SQLAlchemy/sqlalchemy/util/_collections.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/util/_collections.pyi @@ -1,21 +1,27 @@ import collections.abc import sys -from typing import Any +from _typeshed import Self, SupportsKeysAndGetItem +from collections.abc import Callable, Iterable, Iterator +from typing import Any, Generic, NoReturn, TypeVar, overload from ..cimmutabledict import immutabledict as immutabledict +from ..sql.elements import ColumnElement + +_S = TypeVar("_S") +_T = TypeVar("_T") collections_abc = collections.abc -EMPTY_SET: Any +EMPTY_SET: frozenset[Any] class ImmutableContainer: - __delitem__: Any - __setitem__: Any - __setattr__: Any + def __delitem__(self, *arg: object, **kw: object) -> NoReturn: ... + def __setitem__(self, *arg: object, **kw: object) -> NoReturn: ... + def __setattr__(self, *arg: object, **kw: object) -> NoReturn: ... -def coerce_to_immutabledict(d): ... +def coerce_to_immutabledict(d) -> immutabledict: ... -EMPTY_DICT: Any +EMPTY_DICT: immutabledict class FacadeDict(ImmutableContainer, dict[Any, Any]): clear: Any @@ -27,31 +33,34 @@ class FacadeDict(ImmutableContainer, dict[Any, Any]): def copy(self) -> None: ... # type: ignore[override] def __reduce__(self): ... -class Properties: - def __init__(self, data) -> None: ... - def __len__(self): ... - def __iter__(self): ... - def __dir__(self): ... - def __add__(self, other): ... - def __setitem__(self, key, obj) -> None: ... - def __getitem__(self, key): ... - def __delitem__(self, key) -> None: ... - def __setattr__(self, key, obj) -> None: ... - def __getattr__(self, key): ... - def __contains__(self, key): ... - def as_immutable(self): ... - def update(self, value) -> None: ... - def get(self, key, default: Any | None = ...): ... - def keys(self): ... - def values(self): ... - def items(self): ... - def has_key(self, key): ... +class Properties(Generic[_T]): + def __init__(self, data: dict[str, _T]) -> None: ... + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[_T]: ... + def __dir__(self) -> list[str]: ... + def __add__(self, other: Iterable[_S]) -> list[_S | _T]: ... + def __setitem__(self, key: str, obj: _T) -> None: ... + def __getitem__(self, key: str) -> _T: ... + def __delitem__(self, key: str) -> None: ... + def __setattr__(self, key: str, obj: _T) -> None: ... + def __getattr__(self, key: str) -> _T: ... + def __contains__(self, key: str) -> bool: ... + def as_immutable(self) -> ImmutableProperties[_T]: ... + def update(self, value: Iterable[tuple[str, _T]] | SupportsKeysAndGetItem[str, _T]) -> None: ... + @overload + def get(self, key: str) -> _T | None: ... + @overload + def get(self, key: str, default: _S) -> _T | _S: ... + def keys(self) -> list[str]: ... + def values(self) -> list[_T]: ... + def items(self) -> list[tuple[str, _T]]: ... + def has_key(self, key: str) -> bool: ... def clear(self) -> None: ... -class OrderedProperties(Properties): +class OrderedProperties(Properties[_T], Generic[_T]): def __init__(self) -> None: ... -class ImmutableProperties(ImmutableContainer, Properties): ... +class ImmutableProperties(ImmutableContainer, Properties[_T], Generic[_T]): ... if sys.version_info >= (3, 7): OrderedDict = dict @@ -75,32 +84,32 @@ else: def sort_dictionary(d, key: Any | None = ...): ... -class OrderedSet(set[Any]): - def __init__(self, d: Any | None = ...) -> None: ... - def add(self, element) -> None: ... - def remove(self, element) -> None: ... - def insert(self, pos, element) -> None: ... - def discard(self, element) -> None: ... +class OrderedSet(set[_T], Generic[_T]): + def __init__(self, d: Iterable[_T] | None = ...) -> None: ... + def add(self, element: _T) -> None: ... + def remove(self, element: _T) -> None: ... + def insert(self, pos: int, element: _T) -> None: ... + def discard(self, element: _T) -> None: ... def clear(self) -> None: ... - def __getitem__(self, key): ... - def __iter__(self): ... - def __add__(self, other): ... - def update(self, iterable): ... - __ior__: Any - def union(self, other): ... - __or__: Any - def intersection(self, other): ... - __and__: Any - def symmetric_difference(self, other): ... - __xor__: Any - def difference(self, other): ... - __sub__: Any - def intersection_update(self, other): ... - __iand__: Any - def symmetric_difference_update(self, other): ... - __ixor__: Any - def difference_update(self, other): ... - __isub__: Any + def __getitem__(self, key: int) -> _T: ... + def __iter__(self) -> Iterator[_T]: ... + def __add__(self, other: Iterable[_S]) -> OrderedSet[_S | _T]: ... + def update(self: Self, iterable: Iterable[_T]) -> Self: ... # type: ignore[override] + __ior__ = update # type: ignore[assignment] + def union(self, other: Iterable[_S]) -> OrderedSet[_S | _T]: ... # type: ignore[override] + __or__ = union # type: ignore[assignment] + def intersection(self: Self, other: Iterable[Any]) -> Self: ... # type: ignore[override] + __and__ = intersection # type: ignore[assignment] + def symmetric_difference(self, other: Iterable[_S]) -> OrderedSet[_S | _T]: ... + __xor__ = symmetric_difference # type: ignore[assignment] + def difference(self: Self, other: Iterable[Any]) -> Self: ... # type: ignore[override] + __sub__ = difference # type: ignore[assignment] + def intersection_update(self: Self, other: Iterable[Any]) -> Self: ... # type: ignore[override] + __iand__ = intersection_update # type: ignore[assignment] + def symmetric_difference_update(self: Self, other: Iterable[_T]) -> Self: ... # type: ignore[override] + __ixor__ = symmetric_difference_update # type: ignore[assignment] + def difference_update(self: Self, other: Iterable[Any]) -> Self: ... # type: ignore[override] + __isub__ = difference_update # type: ignore[assignment] class IdentitySet: def __init__(self, iterable: Any | None = ...) -> None: ... @@ -164,9 +173,9 @@ class WeakPopulateDict(dict[Any, Any]): column_set = set column_dict = dict -ordered_column_set = OrderedSet +ordered_column_set = OrderedSet[ColumnElement] -def unique_list(seq, hashfunc: Any | None = ...): ... +def unique_list(seq: Iterable[_T], hashfunc: Callable[[_T], Any] | None = ...) -> list[_T]: ... class UniqueAppender: data: Any diff --git a/tests/pytype_exclude_list.txt b/tests/pytype_exclude_list.txt index d86fd9624..6d303bbc5 100644 --- a/tests/pytype_exclude_list.txt +++ b/tests/pytype_exclude_list.txt @@ -49,6 +49,7 @@ stubs/SQLAlchemy/sqlalchemy/databases/__init__.pyi stubs/SQLAlchemy/sqlalchemy/dialects/__init__.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mssql/__init__.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mssql/base.pyi +stubs/SQLAlchemy/sqlalchemy/dialects/mssql/information_schema.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mssql/json.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mssql/mxodbc.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mssql/pymssql.pyi @@ -58,6 +59,8 @@ stubs/SQLAlchemy/sqlalchemy/dialects/mysql/aiomysql.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mysql/asyncmy.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mysql/base.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mysql/cymysql.pyi +stubs/SQLAlchemy/sqlalchemy/dialects/mysql/dml.pyi +stubs/SQLAlchemy/sqlalchemy/dialects/mysql/expression.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mysql/json.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mysql/mariadb.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mysql/mariadbconnector.pyi @@ -66,9 +69,13 @@ stubs/SQLAlchemy/sqlalchemy/dialects/mysql/mysqldb.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mysql/oursql.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mysql/pymysql.pyi stubs/SQLAlchemy/sqlalchemy/dialects/mysql/pyodbc.pyi +stubs/SQLAlchemy/sqlalchemy/dialects/oracle/base.pyi stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/__init__.pyi stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/array.pyi stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/asyncpg.pyi +stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/base.pyi +stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/dml.pyi +stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/ext.pyi stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/hstore.pyi stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/json.pyi stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/pg8000.pyi @@ -79,7 +86,27 @@ stubs/SQLAlchemy/sqlalchemy/dialects/postgresql/ranges.pyi stubs/SQLAlchemy/sqlalchemy/dialects/sqlite/__init__.pyi stubs/SQLAlchemy/sqlalchemy/dialects/sqlite/aiosqlite.pyi stubs/SQLAlchemy/sqlalchemy/dialects/sqlite/base.pyi +stubs/SQLAlchemy/sqlalchemy/dialects/sqlite/dml.pyi stubs/SQLAlchemy/sqlalchemy/dialects/sqlite/json.pyi stubs/SQLAlchemy/sqlalchemy/dialects/sqlite/pysqlcipher.pyi stubs/SQLAlchemy/sqlalchemy/dialects/sqlite/pysqlite.pyi +stubs/SQLAlchemy/sqlalchemy/engine/cursor.pyi +stubs/SQLAlchemy/sqlalchemy/engine/result.pyi stubs/SQLAlchemy/sqlalchemy/engine/row.pyi +stubs/SQLAlchemy/sqlalchemy/ext/asyncio/result.pyi +stubs/SQLAlchemy/sqlalchemy/ext/horizontal_shard.pyi +stubs/SQLAlchemy/sqlalchemy/orm/attributes.pyi +stubs/SQLAlchemy/sqlalchemy/orm/descriptor_props.pyi +stubs/SQLAlchemy/sqlalchemy/orm/dynamic.pyi +stubs/SQLAlchemy/sqlalchemy/orm/mapper.pyi +stubs/SQLAlchemy/sqlalchemy/orm/query.pyi +stubs/SQLAlchemy/sqlalchemy/orm/strategy_options.pyi +stubs/SQLAlchemy/sqlalchemy/orm/util.pyi +stubs/SQLAlchemy/sqlalchemy/sql/compiler.pyi +stubs/SQLAlchemy/sqlalchemy/sql/crud.pyi +stubs/SQLAlchemy/sqlalchemy/sql/ddl.pyi +stubs/SQLAlchemy/sqlalchemy/sql/elements.pyi +stubs/SQLAlchemy/sqlalchemy/sql/functions.pyi +stubs/SQLAlchemy/sqlalchemy/sql/lambdas.pyi +stubs/SQLAlchemy/sqlalchemy/sql/schema.pyi +stubs/SQLAlchemy/sqlalchemy/sql/selectable.pyi