From cb0fbd891336df9410d297de863fb3dd9731067c Mon Sep 17 00:00:00 2001 From: Eli$ Date: Mon, 8 Sep 2025 16:38:07 +0200 Subject: [PATCH] [PyMySQL] Improve `Connection.__init__` overloads & add missing types (#14684) --- stubs/PyMySQL/pymysql/connections.pyi | 330 +++++++++++++++++--------- 1 file changed, 220 insertions(+), 110 deletions(-) diff --git a/stubs/PyMySQL/pymysql/connections.pyi b/stubs/PyMySQL/pymysql/connections.pyi index eab248848..2fc19397c 100644 --- a/stubs/PyMySQL/pymysql/connections.pyi +++ b/stubs/PyMySQL/pymysql/connections.pyi @@ -1,13 +1,25 @@ -from _typeshed import FileDescriptorOrPath, Incomplete -from collections.abc import Mapping -from socket import socket as _socket -from ssl import _PasswordType -from typing import AnyStr, Generic, TypeVar, overload +from _typeshed import FileDescriptorOrPath, Incomplete, Unused +from collections.abc import Callable, Mapping +from socket import _Address, socket as _socket +from ssl import SSLContext, _PasswordType +from typing import Any, AnyStr, Generic, TypeVar, overload from typing_extensions import Self, deprecated from .charset import charset_by_id as charset_by_id, charset_by_name as charset_by_name from .constants import CLIENT as CLIENT, COMMAND as COMMAND, FIELD_TYPE as FIELD_TYPE, SERVER_STATUS as SERVER_STATUS from .cursors import Cursor +from .err import ( + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, + ProgrammingError, + Warning, +) _C = TypeVar("_C", bound=Cursor) _C2 = TypeVar("_C2", bound=Cursor) @@ -23,121 +35,218 @@ def dump_packet(data): ... def _lenenc_int(i: int) -> bytes: ... class Connection(Generic[_C]): - ssl: Incomplete - host: Incomplete - port: Incomplete - user: Incomplete - password: Incomplete - db: Incomplete - unix_socket: Incomplete + ssl: bool + host: str + port: int + user: str | bytes | None + password: bytes + db: str | bytes | None + unix_socket: _Address | None charset: str collation: str | None - bind_address: Incomplete - use_unicode: Incomplete - client_flag: Incomplete - cursorclass: Incomplete - connect_timeout: Incomplete - messages: Incomplete - encoders: Incomplete - decoders: Incomplete - host_info: Incomplete - sql_mode: Incomplete - init_command: Incomplete + bind_address: str | None + use_unicode: bool + client_flag: int + cursorclass: type[_C] + connect_timeout: float | None + host_info: str + sql_mode: str | None + init_command: str | None max_allowed_packet: int - server_public_key: bytes + server_public_key: bytes | None + encoding: str + autocommit_mode: bool | None + encoders: dict[type[Any], Callable[[Any], str]] # argument type depends on the key + decoders: dict[int, Callable[[str], Any]] # return type depends on the key + @overload def __init__( self: Connection[Cursor], # different between overloads *, + user: str | bytes | None = None, + password: str | bytes = "", host: str | None = None, - user=None, - password: str = "", - database=None, + database: str | bytes | None = None, + unix_socket: _Address | None = None, port: int = 0, - unix_socket=None, charset: str = "", collation: str | None = None, - sql_mode=None, - read_default_file=None, - conv=None, - use_unicode: bool | None = True, + sql_mode: str | None = None, + read_default_file: str | None = None, + conv: dict[int | type[Any], Callable[[Any], str] | Callable[[str], Any]] | None = None, + use_unicode: bool = True, client_flag: int = 0, cursorclass: None = None, # different between overloads - init_command=None, - connect_timeout: int | None = 10, - ssl: Mapping[Incomplete, Incomplete] | None = None, - ssl_ca=None, - ssl_cert=None, - ssl_disabled=None, - ssl_key=None, - ssl_key_password: _PasswordType | None = None, - ssl_verify_cert=None, - ssl_verify_identity=None, - read_default_group=None, - compress=None, - named_pipe=None, + init_command: str | None = None, + connect_timeout: float = 10, + read_default_group: str | None = None, autocommit: bool | None = False, - db=None, - passwd=None, - local_infile: Incomplete | None = False, - max_allowed_packet: int = 16777216, - defer_connect: bool | None = False, - auth_plugin_map: Mapping[Incomplete, Incomplete] | None = None, + local_infile: bool = False, + max_allowed_packet: int = 16_777_216, + defer_connect: bool = False, + auth_plugin_map: dict[str, Callable[[Connection[Any]], Any]] | None = None, read_timeout: float | None = None, write_timeout: float | None = None, - bind_address=None, - binary_prefix: bool | None = False, - program_name=None, + bind_address: str | None = None, + binary_prefix: bool = False, + program_name: str | None = None, server_public_key: bytes | None = None, + ssl: dict[str, Incomplete] | SSLContext | None = None, + ssl_ca: str | None = None, + ssl_cert: str | None = None, + ssl_disabled: bool | None = None, + ssl_key: str | None = None, + ssl_key_password: _PasswordType | None = None, + ssl_verify_cert: bool | None = None, + ssl_verify_identity: bool | None = None, + compress: Unused = None, + named_pipe: Unused = None, + # different between overloads: + passwd: None = None, # deprecated + db: None = None, # deprecated ) -> None: ... @overload def __init__( - # different between overloads: + # different between overloads self: Connection[_C], # pyright: ignore[reportInvalidTypeVarUse] #11780 *, + user: str | bytes | None = None, + password: str | bytes = "", host: str | None = None, - user=None, - password: str = "", - database=None, + database: str | bytes | None = None, + unix_socket: _Address | None = None, port: int = 0, - unix_socket=None, charset: str = "", collation: str | None = None, - sql_mode=None, - read_default_file=None, - conv=None, - use_unicode: bool | None = True, + sql_mode: str | None = None, + read_default_file: str | None = None, + conv: dict[int | type[Any], Callable[[Any], str] | Callable[[str], Any]] | None = None, + use_unicode: bool = True, client_flag: int = 0, cursorclass: type[_C] = ..., # different between overloads - init_command=None, - connect_timeout: int | None = 10, - ssl: Mapping[Incomplete, Incomplete] | None = None, - ssl_ca=None, - ssl_cert=None, - ssl_disabled=None, - ssl_key=None, - ssl_verify_cert=None, - ssl_verify_identity=None, - read_default_group=None, - compress=None, - named_pipe=None, + init_command: str | None = None, + connect_timeout: float = 10, + read_default_group: str | None = None, autocommit: bool | None = False, - db=None, - passwd=None, - local_infile: Incomplete | None = False, - max_allowed_packet: int = 16777216, - defer_connect: bool | None = False, - auth_plugin_map: Mapping[Incomplete, Incomplete] | None = None, + local_infile: bool = False, + max_allowed_packet: int = 16_777_216, + defer_connect: bool = False, + auth_plugin_map: dict[str, Callable[[Connection[Any]], Any]] | None = None, read_timeout: float | None = None, write_timeout: float | None = None, - bind_address=None, - binary_prefix: bool | None = False, - program_name=None, + bind_address: str | None = None, + binary_prefix: bool = False, + program_name: str | None = None, server_public_key: bytes | None = None, + ssl: dict[str, Incomplete] | SSLContext | None = None, + ssl_ca: str | None = None, + ssl_cert: str | None = None, + ssl_disabled: bool | None = None, + ssl_key: str | None = None, + ssl_key_password: _PasswordType | None = None, + ssl_verify_cert: bool | None = None, + ssl_verify_identity: bool | None = None, + compress: Unused = None, + named_pipe: Unused = None, + # different between overloads: + passwd: None = None, # deprecated + db: None = None, # deprecated + ) -> None: ... + @overload + @deprecated("'passwd' and 'db' arguments are deprecated. Use 'password' and 'database' instead.") + def __init__( + self: Connection[Cursor], # different between overloads + *, + user: str | bytes | None = None, + password: str | bytes = "", + host: str | None = None, + database: str | bytes | None = None, + unix_socket: _Address | None = None, + port: int = 0, + charset: str = "", + collation: str | None = None, + sql_mode: str | None = None, + read_default_file: str | None = None, + conv: dict[int | type[Any], Callable[[Any], str] | Callable[[str], Any]] | None = None, + use_unicode: bool = True, + client_flag: int = 0, + cursorclass: None = None, # different between overloads + init_command: str | None = None, + connect_timeout: float = 10, + read_default_group: str | None = None, + autocommit: bool | None = False, + local_infile: bool = False, + max_allowed_packet: int = 16_777_216, + defer_connect: bool = False, + auth_plugin_map: dict[str, Callable[[Connection[Any]], Any]] | None = None, + read_timeout: float | None = None, + write_timeout: float | None = None, + bind_address: str | None = None, + binary_prefix: bool = False, + program_name: str | None = None, + server_public_key: bytes | None = None, + ssl: dict[str, Incomplete] | SSLContext | None = None, + ssl_ca: str | None = None, + ssl_cert: str | None = None, + ssl_disabled: bool | None = None, + ssl_key: str | None = None, + ssl_key_password: _PasswordType | None = None, + ssl_verify_cert: bool | None = None, + ssl_verify_identity: bool | None = None, + compress: Unused = None, + named_pipe: Unused = None, + # different between overloads: + passwd: str | bytes | None = None, # deprecated + db: str | bytes | None = None, # deprecated + ) -> None: ... + @overload + @deprecated("'passwd' and 'db' arguments are deprecated. Use 'password' and 'database' instead.") + def __init__( + # different between overloads + self: Connection[_C], # pyright: ignore[reportInvalidTypeVarUse] #11780 + *, + user: str | bytes | None = None, + password: str | bytes = "", + host: str | None = None, + database: str | bytes | None = None, + unix_socket: _Address | None = None, + port: int = 0, + charset: str = "", + collation: str | None = None, + sql_mode: str | None = None, + read_default_file: str | None = None, + conv: dict[int | type[Any], Callable[[Any], str] | Callable[[str], Any]] | None = None, + use_unicode: bool = True, + client_flag: int = 0, + cursorclass: type[_C] = ..., # different between overloads + init_command: str | None = None, + connect_timeout: float = 10, + read_default_group: str | None = None, + autocommit: bool | None = False, + local_infile: bool = False, + max_allowed_packet: int = 16_777_216, + defer_connect: bool = False, + auth_plugin_map: dict[str, Callable[[Connection[Any]], Any]] | None = None, + read_timeout: float | None = None, + write_timeout: float | None = None, + bind_address: str | None = None, + binary_prefix: bool = False, + program_name: str | None = None, + server_public_key: bytes | None = None, + ssl: dict[str, Incomplete] | SSLContext | None = None, + ssl_ca: str | None = None, + ssl_cert: str | None = None, + ssl_disabled: bool | None = None, + ssl_key: str | None = None, + ssl_key_password: _PasswordType | None = None, + ssl_verify_cert: bool | None = None, + ssl_verify_identity: bool | None = None, + compress: Unused = None, + named_pipe: Unused = None, + # different between overloads: + passwd: str | bytes | None = None, # deprecated + db: str | bytes | None = None, # deprecated ) -> None: ... - socket: Incomplete - rfile: Incomplete - wfile: Incomplete def close(self) -> None: ... @property def open(self) -> bool: ... @@ -148,7 +257,7 @@ class Connection(Generic[_C]): def begin(self) -> None: ... def rollback(self) -> None: ... def select_db(self, db) -> None: ... - def escape(self, obj, mapping: Mapping[Incomplete, Incomplete] | None = None): ... + def escape(self, obj, mapping: Mapping[str, Incomplete] | None = None): ... def literal(self, obj): ... def escape_string(self, s: AnyStr) -> AnyStr: ... @overload @@ -169,35 +278,36 @@ class Connection(Generic[_C]): def insert_id(self): ... def thread_id(self): ... def character_set_name(self): ... - def get_host_info(self): ... + def get_host_info(self) -> str: ... def get_proto_info(self): ... def get_server_info(self): ... def show_warnings(self): ... def __enter__(self) -> Self: ... def __exit__(self, *exc_info: object) -> None: ... - Warning: Incomplete - Error: Incomplete - InterfaceError: Incomplete - DatabaseError: Incomplete - DataError: Incomplete - OperationalError: Incomplete - IntegrityError: Incomplete - InternalError: Incomplete - ProgrammingError: Incomplete - NotSupportedError: Incomplete + Warning: type[Warning] + Error: type[Error] + InterfaceError: type[InterfaceError] + DatabaseError: type[DatabaseError] + DataError: type[DataError] + OperationalError: type[OperationalError] + IntegrityError: type[IntegrityError] + InternalError: type[InternalError] + ProgrammingError: type[ProgrammingError] + NotSupportedError: type[NotSupportedError] class MySQLResult: - connection: Incomplete - affected_rows: Incomplete - insert_id: Incomplete - server_status: Incomplete - warning_count: Incomplete - message: Incomplete - field_count: Incomplete + connection: Connection[Any] | None + affected_rows: int | None + insert_id: int | None + server_status: int | None + warning_count: int + message: str | None + field_count: int description: Incomplete rows: Incomplete - has_next: Incomplete - def __init__(self, connection: Connection[Incomplete]) -> None: ... + has_next: bool | None + unbuffered_active: bool + def __init__(self, connection: Connection[Any]) -> None: ... def __del__(self) -> None: ... first_packet: Incomplete def read(self) -> None: ... @@ -205,6 +315,6 @@ class MySQLResult: class LoadLocalFile: filename: FileDescriptorOrPath - connection: Connection[Incomplete] - def __init__(self, filename: FileDescriptorOrPath, connection: Connection[Incomplete]) -> None: ... + connection: Connection[Any] + def __init__(self, filename: FileDescriptorOrPath, connection: Connection[Any]) -> None: ... def send_data(self) -> None: ...