diff --git a/stubs/PyMySQL/@tests/stubtest_allowlist.txt b/stubs/PyMySQL/@tests/stubtest_allowlist.txt index 7025f9aee..1bb567319 100644 --- a/stubs/PyMySQL/@tests/stubtest_allowlist.txt +++ b/stubs/PyMySQL/@tests/stubtest_allowlist.txt @@ -1,5 +1,3 @@ -pymysql.Connect -pymysql.connections.Connection.__init__ pymysql.connections.byte2int pymysql.connections.int2byte pymysql.connections.lenenc_int @@ -39,6 +37,9 @@ pymysql.converters.escape_time pymysql.converters.escape_timedelta pymysql.converters.escape_unicode pymysql.cursors.Cursor.__del__ +# DictCursorMixin changes method types of inherited classes, but doesn't contain much at runtime +pymysql.cursors.DictCursorMixin.__iter__ +pymysql.cursors.DictCursorMixin.fetch[a-z]* pymysql.err.ER pymysql.escape_dict pymysql.escape_sequence diff --git a/stubs/PyMySQL/METADATA.toml b/stubs/PyMySQL/METADATA.toml index 31f638bf3..f339dc7fd 100644 --- a/stubs/PyMySQL/METADATA.toml +++ b/stubs/PyMySQL/METADATA.toml @@ -1,2 +1,2 @@ -version = "0.1" +version = "1.0" python2 = true diff --git a/stubs/PyMySQL/pymysql/__init__.pyi b/stubs/PyMySQL/pymysql/__init__.pyi index 35532caa8..6cfbf75b6 100644 --- a/stubs/PyMySQL/pymysql/__init__.pyi +++ b/stubs/PyMySQL/pymysql/__init__.pyi @@ -1,7 +1,7 @@ import sys -from typing import Callable, FrozenSet, Tuple +from typing import FrozenSet, Tuple -from .connections import Connection as _Connection +from .connections import Connection as Connection from .constants import FIELD_TYPE as FIELD_TYPE from .converters import escape_dict as escape_dict, escape_sequence as escape_sequence, escape_string as escape_string from .err import ( @@ -50,14 +50,15 @@ if sys.version_info >= (3, 0): else: def Binary(x) -> bytearray: ... -def Connect(*args, **kwargs) -> _Connection: ... def get_client_info() -> str: ... -connect: Callable[..., _Connection] -Connection: Callable[..., _Connection] __version__: str version_info: Tuple[int, int, int, str, int] NULL: str +# pymysql/__init__.py says "Connect = connect = Connection = connections.Connection" +Connect = Connection +connect = Connection + def thread_safe() -> bool: ... def install_as_MySQLdb() -> None: ... diff --git a/stubs/PyMySQL/pymysql/connections.pyi b/stubs/PyMySQL/pymysql/connections.pyi index 4f8f505d1..f30e03661 100644 --- a/stubs/PyMySQL/pymysql/connections.pyi +++ b/stubs/PyMySQL/pymysql/connections.pyi @@ -1,5 +1,5 @@ from socket import socket as _socket -from typing import Any, AnyStr, Mapping, Optional, Tuple, Type +from typing import Any, AnyStr, Generic, Mapping, Tuple, Type, TypeVar, overload from .charset import charset_by_id as charset_by_id, charset_by_name as charset_by_name from .constants import CLIENT as CLIENT, COMMAND as COMMAND, FIELD_TYPE as FIELD_TYPE, SERVER_STATUS as SERVER_STATUS @@ -11,6 +11,9 @@ DEFAULT_USER: Any DEBUG: Any DEFAULT_CHARSET: Any +_C = TypeVar("_C", bound=Cursor) +_C2 = TypeVar("_C2", bound=Cursor) + def dump_packet(data): ... def pack_int24(n): ... def lenenc_int(i: int) -> bytes: ... @@ -49,7 +52,7 @@ class FieldDescriptorPacket(MysqlPacket): def description(self): ... def get_column_length(self): ... -class Connection: +class Connection(Generic[_C]): ssl: Any host: Any port: Any @@ -71,40 +74,91 @@ class Connection: init_command: Any max_allowed_packet: int server_public_key: bytes + @overload def __init__( - self, - host: Optional[str] = ..., - user: Optional[Any] = ..., + self: Connection[Cursor], # different between overloads + *, + host: str | None = ..., + user: Any | None = ..., password: str = ..., - database: Optional[Any] = ..., + database: Any | None = ..., port: int = ..., - unix_socket: Optional[Any] = ..., + unix_socket: Any | None = ..., charset: str = ..., - sql_mode: Optional[Any] = ..., - read_default_file: Optional[Any] = ..., + sql_mode: Any | None = ..., + read_default_file: Any | None = ..., conv=..., - use_unicode: Optional[bool] = ..., + use_unicode: bool | None = ..., client_flag: int = ..., - cursorclass: Optional[Type[Cursor]] = ..., - init_command: Optional[Any] = ..., - connect_timeout: Optional[int] = ..., + cursorclass: None = ..., # different between overloads + init_command: Any | None = ..., + connect_timeout: int | None = ..., ssl: Mapping[Any, Any] | None = ..., - read_default_group: Optional[Any] = ..., - compress: Optional[Any] = ..., - named_pipe: Optional[Any] = ..., - autocommit: Optional[bool] = ..., - db: Optional[Any] = ..., - passwd: Optional[Any] = ..., - local_infile: Optional[Any] = ..., + ssl_ca=..., + ssl_cert=..., + ssl_disabled=..., + ssl_key=..., + ssl_verify_cert=..., + ssl_verify_identity=..., + read_default_group: Any | None = ..., + compress: Any | None = ..., + named_pipe: Any | None = ..., + autocommit: bool | None = ..., + db: Any | None = ..., + passwd: Any | None = ..., + local_infile: Any | None = ..., max_allowed_packet: int = ..., - defer_connect: Optional[bool] = ..., + defer_connect: bool | None = ..., auth_plugin_map: Mapping[Any, Any] | None = ..., - read_timeout: Optional[float] = ..., - write_timeout: Optional[float] = ..., - bind_address: Optional[Any] = ..., - binary_prefix: Optional[bool] = ..., - program_name: Optional[Any] = ..., - server_public_key: Optional[bytes] = ..., + read_timeout: float | None = ..., + write_timeout: float | None = ..., + bind_address: Any | None = ..., + binary_prefix: bool | None = ..., + program_name: Any | None = ..., + server_public_key: bytes | None = ..., + ): ... + @overload + def __init__( + self: Connection[_C], # different between overloads + *, + host: str | None = ..., + user: Any | None = ..., + password: str = ..., + database: Any | None = ..., + port: int = ..., + unix_socket: Any | None = ..., + charset: str = ..., + sql_mode: Any | None = ..., + read_default_file: Any | None = ..., + conv=..., + use_unicode: bool | None = ..., + client_flag: int = ..., + cursorclass: Type[_C] = ..., # different between overloads + init_command: Any | None = ..., + connect_timeout: int | None = ..., + ssl: Mapping[Any, Any] | None = ..., + ssl_ca=..., + ssl_cert=..., + ssl_disabled=..., + ssl_key=..., + ssl_verify_cert=..., + ssl_verify_identity=..., + read_default_group: Any | None = ..., + compress: Any | None = ..., + named_pipe: Any | None = ..., + autocommit: bool | None = ..., + db: Any | None = ..., + passwd: Any | None = ..., + local_infile: Any | None = ..., + max_allowed_packet: int = ..., + defer_connect: bool | None = ..., + auth_plugin_map: Mapping[Any, Any] | None = ..., + read_timeout: float | None = ..., + write_timeout: float | None = ..., + bind_address: Any | None = ..., + binary_prefix: bool | None = ..., + program_name: Any | None = ..., + server_public_key: bytes | None = ..., ): ... socket: Any rfile: Any @@ -121,14 +175,17 @@ class Connection: def escape(self, obj, mapping: Mapping[Any, Any] | None = ...): ... def literal(self, obj): ... def escape_string(self, s: AnyStr) -> AnyStr: ... - def cursor(self, cursor: Optional[Type[Cursor]] = ...) -> Cursor: ... + @overload + def cursor(self, cursor: None = ...) -> _C: ... + @overload + def cursor(self, cursor: Type[_C2]) -> _C2: ... def query(self, sql, unbuffered: bool = ...) -> int: ... def next_result(self, unbuffered: bool = ...) -> int: ... def affected_rows(self): ... def kill(self, thread_id): ... def ping(self, reconnect: bool = ...) -> None: ... def set_charset(self, charset) -> None: ... - def connect(self, sock: Optional[_socket] = ...) -> None: ... + def connect(self, sock: _socket | None = ...) -> None: ... def write_packet(self, payload) -> None: ... def _read_packet(self, packet_type=...): ... def insert_id(self): ... @@ -160,13 +217,13 @@ class MySQLResult: description: Any rows: Any has_next: Any - def __init__(self, connection: Connection) -> None: ... + def __init__(self, connection: Connection[Any]) -> None: ... first_packet: Any def read(self) -> None: ... def init_unbuffered_query(self) -> None: ... class LoadLocalFile: filename: Any - connection: Connection - def __init__(self, filename: Any, connection: Connection) -> None: ... + connection: Connection[Any] + def __init__(self, filename: Any, connection: Connection[Any]) -> None: ... def send_data(self) -> None: ... diff --git a/stubs/PyMySQL/pymysql/cursors.pyi b/stubs/PyMySQL/pymysql/cursors.pyi index 6ae260d2f..7b95e1738 100644 --- a/stubs/PyMySQL/pymysql/cursors.pyi +++ b/stubs/PyMySQL/pymysql/cursors.pyi @@ -1,12 +1,11 @@ -from typing import Any, Dict, Iterable, Iterator, List, Optional, Text, Tuple, TypeVar, Union +from typing import Any, Iterable, Iterator, Optional, Text, Tuple, TypeVar from .connections import Connection -_Gen = Union[Tuple[Any, ...], Dict[Text, Any]] _SelfT = TypeVar("_SelfT") class Cursor: - connection: Connection + connection: Connection[Any] description: Tuple[Text, ...] rownumber: int rowcount: int @@ -14,7 +13,7 @@ class Cursor: messages: Any errorhandler: Any lastrowid: int - def __init__(self, connection: Connection) -> None: ... + def __init__(self, connection: Connection[Any]) -> None: ... def __del__(self) -> None: ... def close(self) -> None: ... def setinputsizes(self, *args) -> None: ... @@ -24,28 +23,28 @@ class Cursor: def execute(self, query: Text, args: object = ...) -> int: ... def executemany(self, query: Text, args: Iterable[object]) -> Optional[int]: ... def callproc(self, procname: Text, args: Iterable[Any] = ...) -> Any: ... - def fetchone(self) -> Optional[_Gen]: ... - def fetchmany(self, size: Optional[int] = ...) -> Union[Optional[_Gen], List[_Gen]]: ... - def fetchall(self) -> Optional[Tuple[_Gen, ...]]: ... def scroll(self, value: int, mode: Text = ...) -> None: ... - def __iter__(self) -> Iterator[_Gen]: ... def __enter__(self: _SelfT) -> _SelfT: ... def __exit__(self, *exc_info: Any) -> None: ... - -class DictCursor(Cursor): - def fetchone(self) -> Optional[Dict[Text, Any]]: ... - def fetchmany(self, size: Optional[int] = ...) -> Optional[Tuple[Dict[Text, Any], ...]]: ... - def fetchall(self) -> Optional[Tuple[Dict[Text, Any], ...]]: ... + # Methods returning result tuples are below. + def fetchone(self) -> Tuple[Any, ...] | None: ... + def fetchmany(self, size: int | None = ...) -> Tuple[Tuple[Any, ...], ...]: ... + def fetchall(self) -> Tuple[Tuple[Any, ...], ...]: ... + def __iter__(self) -> Iterator[Tuple[Any, ...]]: ... class DictCursorMixin: - dict_type: Any + dict_type: Any # TODO: add support if someone needs this + def fetchone(self) -> dict[Text, Any] | None: ... + def fetchmany(self, size: int | None = ...) -> Tuple[dict[Text, Any], ...]: ... + def fetchall(self) -> Tuple[dict[Text, Any], ...]: ... + def __iter__(self) -> Iterator[dict[Text, Any]]: ... class SSCursor(Cursor): - # fetchall return type is incompatible with the supertype. - def fetchall(self) -> List[_Gen]: ... # type: ignore - def fetchall_unbuffered(self) -> Iterator[_Gen]: ... - def __iter__(self) -> Iterator[_Gen]: ... - def fetchmany(self, size: Optional[int] = ...) -> List[_Gen]: ... + def fetchall(self) -> list[Tuple[Any, ...]]: ... # type: ignore + def fetchall_unbuffered(self) -> Iterator[Tuple[Any, ...]]: ... def scroll(self, value: int, mode: Text = ...) -> None: ... -class SSDictCursor(DictCursorMixin, SSCursor): ... +class DictCursor(DictCursorMixin, Cursor): ... # type: ignore + +class SSDictCursor(DictCursorMixin, SSCursor): # type: ignore + def fetchall_unbuffered(self) -> Iterator[dict[Text, Any]]: ... # type: ignore