Pymysql cursor types and fetching return types (#5652)

version update: stubs made based on pymysql 1.0.2
This commit is contained in:
Akuli
2021-06-17 21:53:22 +03:00
committed by GitHub
parent 64f463172b
commit c5b59b9e41
5 changed files with 118 additions and 60 deletions

View File

@@ -1,7 +1,7 @@
import sys
from typing import Callable, FrozenSet, Tuple
from typing import FrozenSet, Tuple
from .connections import Connection as _Connection
from .connections import Connection as Connection
from .constants import FIELD_TYPE as FIELD_TYPE
from .converters import escape_dict as escape_dict, escape_sequence as escape_sequence, escape_string as escape_string
from .err import (
@@ -50,14 +50,15 @@ if sys.version_info >= (3, 0):
else:
def Binary(x) -> bytearray: ...
def Connect(*args, **kwargs) -> _Connection: ...
def get_client_info() -> str: ...
connect: Callable[..., _Connection]
Connection: Callable[..., _Connection]
__version__: str
version_info: Tuple[int, int, int, str, int]
NULL: str
# pymysql/__init__.py says "Connect = connect = Connection = connections.Connection"
Connect = Connection
connect = Connection
def thread_safe() -> bool: ...
def install_as_MySQLdb() -> None: ...

View File

@@ -1,5 +1,5 @@
from socket import socket as _socket
from typing import Any, AnyStr, Mapping, Optional, Tuple, Type
from typing import Any, AnyStr, Generic, Mapping, Tuple, Type, TypeVar, overload
from .charset import charset_by_id as charset_by_id, charset_by_name as charset_by_name
from .constants import CLIENT as CLIENT, COMMAND as COMMAND, FIELD_TYPE as FIELD_TYPE, SERVER_STATUS as SERVER_STATUS
@@ -11,6 +11,9 @@ DEFAULT_USER: Any
DEBUG: Any
DEFAULT_CHARSET: Any
_C = TypeVar("_C", bound=Cursor)
_C2 = TypeVar("_C2", bound=Cursor)
def dump_packet(data): ...
def pack_int24(n): ...
def lenenc_int(i: int) -> bytes: ...
@@ -49,7 +52,7 @@ class FieldDescriptorPacket(MysqlPacket):
def description(self): ...
def get_column_length(self): ...
class Connection:
class Connection(Generic[_C]):
ssl: Any
host: Any
port: Any
@@ -71,40 +74,91 @@ class Connection:
init_command: Any
max_allowed_packet: int
server_public_key: bytes
@overload
def __init__(
self,
host: Optional[str] = ...,
user: Optional[Any] = ...,
self: Connection[Cursor], # different between overloads
*,
host: str | None = ...,
user: Any | None = ...,
password: str = ...,
database: Optional[Any] = ...,
database: Any | None = ...,
port: int = ...,
unix_socket: Optional[Any] = ...,
unix_socket: Any | None = ...,
charset: str = ...,
sql_mode: Optional[Any] = ...,
read_default_file: Optional[Any] = ...,
sql_mode: Any | None = ...,
read_default_file: Any | None = ...,
conv=...,
use_unicode: Optional[bool] = ...,
use_unicode: bool | None = ...,
client_flag: int = ...,
cursorclass: Optional[Type[Cursor]] = ...,
init_command: Optional[Any] = ...,
connect_timeout: Optional[int] = ...,
cursorclass: None = ..., # different between overloads
init_command: Any | None = ...,
connect_timeout: int | None = ...,
ssl: Mapping[Any, Any] | None = ...,
read_default_group: Optional[Any] = ...,
compress: Optional[Any] = ...,
named_pipe: Optional[Any] = ...,
autocommit: Optional[bool] = ...,
db: Optional[Any] = ...,
passwd: Optional[Any] = ...,
local_infile: Optional[Any] = ...,
ssl_ca=...,
ssl_cert=...,
ssl_disabled=...,
ssl_key=...,
ssl_verify_cert=...,
ssl_verify_identity=...,
read_default_group: Any | None = ...,
compress: Any | None = ...,
named_pipe: Any | None = ...,
autocommit: bool | None = ...,
db: Any | None = ...,
passwd: Any | None = ...,
local_infile: Any | None = ...,
max_allowed_packet: int = ...,
defer_connect: Optional[bool] = ...,
defer_connect: bool | None = ...,
auth_plugin_map: Mapping[Any, Any] | None = ...,
read_timeout: Optional[float] = ...,
write_timeout: Optional[float] = ...,
bind_address: Optional[Any] = ...,
binary_prefix: Optional[bool] = ...,
program_name: Optional[Any] = ...,
server_public_key: Optional[bytes] = ...,
read_timeout: float | None = ...,
write_timeout: float | None = ...,
bind_address: Any | None = ...,
binary_prefix: bool | None = ...,
program_name: Any | None = ...,
server_public_key: bytes | None = ...,
): ...
@overload
def __init__(
self: Connection[_C], # different between overloads
*,
host: str | None = ...,
user: Any | None = ...,
password: str = ...,
database: Any | None = ...,
port: int = ...,
unix_socket: Any | None = ...,
charset: str = ...,
sql_mode: Any | None = ...,
read_default_file: Any | None = ...,
conv=...,
use_unicode: bool | None = ...,
client_flag: int = ...,
cursorclass: Type[_C] = ..., # different between overloads
init_command: Any | None = ...,
connect_timeout: int | None = ...,
ssl: Mapping[Any, Any] | None = ...,
ssl_ca=...,
ssl_cert=...,
ssl_disabled=...,
ssl_key=...,
ssl_verify_cert=...,
ssl_verify_identity=...,
read_default_group: Any | None = ...,
compress: Any | None = ...,
named_pipe: Any | None = ...,
autocommit: bool | None = ...,
db: Any | None = ...,
passwd: Any | None = ...,
local_infile: Any | None = ...,
max_allowed_packet: int = ...,
defer_connect: bool | None = ...,
auth_plugin_map: Mapping[Any, Any] | None = ...,
read_timeout: float | None = ...,
write_timeout: float | None = ...,
bind_address: Any | None = ...,
binary_prefix: bool | None = ...,
program_name: Any | None = ...,
server_public_key: bytes | None = ...,
): ...
socket: Any
rfile: Any
@@ -121,14 +175,17 @@ class Connection:
def escape(self, obj, mapping: Mapping[Any, Any] | None = ...): ...
def literal(self, obj): ...
def escape_string(self, s: AnyStr) -> AnyStr: ...
def cursor(self, cursor: Optional[Type[Cursor]] = ...) -> Cursor: ...
@overload
def cursor(self, cursor: None = ...) -> _C: ...
@overload
def cursor(self, cursor: Type[_C2]) -> _C2: ...
def query(self, sql, unbuffered: bool = ...) -> int: ...
def next_result(self, unbuffered: bool = ...) -> int: ...
def affected_rows(self): ...
def kill(self, thread_id): ...
def ping(self, reconnect: bool = ...) -> None: ...
def set_charset(self, charset) -> None: ...
def connect(self, sock: Optional[_socket] = ...) -> None: ...
def connect(self, sock: _socket | None = ...) -> None: ...
def write_packet(self, payload) -> None: ...
def _read_packet(self, packet_type=...): ...
def insert_id(self): ...
@@ -160,13 +217,13 @@ class MySQLResult:
description: Any
rows: Any
has_next: Any
def __init__(self, connection: Connection) -> None: ...
def __init__(self, connection: Connection[Any]) -> None: ...
first_packet: Any
def read(self) -> None: ...
def init_unbuffered_query(self) -> None: ...
class LoadLocalFile:
filename: Any
connection: Connection
def __init__(self, filename: Any, connection: Connection) -> None: ...
connection: Connection[Any]
def __init__(self, filename: Any, connection: Connection[Any]) -> None: ...
def send_data(self) -> None: ...

View File

@@ -1,12 +1,11 @@
from typing import Any, Dict, Iterable, Iterator, List, Optional, Text, Tuple, TypeVar, Union
from typing import Any, Iterable, Iterator, Optional, Text, Tuple, TypeVar
from .connections import Connection
_Gen = Union[Tuple[Any, ...], Dict[Text, Any]]
_SelfT = TypeVar("_SelfT")
class Cursor:
connection: Connection
connection: Connection[Any]
description: Tuple[Text, ...]
rownumber: int
rowcount: int
@@ -14,7 +13,7 @@ class Cursor:
messages: Any
errorhandler: Any
lastrowid: int
def __init__(self, connection: Connection) -> None: ...
def __init__(self, connection: Connection[Any]) -> None: ...
def __del__(self) -> None: ...
def close(self) -> None: ...
def setinputsizes(self, *args) -> None: ...
@@ -24,28 +23,28 @@ class Cursor:
def execute(self, query: Text, args: object = ...) -> int: ...
def executemany(self, query: Text, args: Iterable[object]) -> Optional[int]: ...
def callproc(self, procname: Text, args: Iterable[Any] = ...) -> Any: ...
def fetchone(self) -> Optional[_Gen]: ...
def fetchmany(self, size: Optional[int] = ...) -> Union[Optional[_Gen], List[_Gen]]: ...
def fetchall(self) -> Optional[Tuple[_Gen, ...]]: ...
def scroll(self, value: int, mode: Text = ...) -> None: ...
def __iter__(self) -> Iterator[_Gen]: ...
def __enter__(self: _SelfT) -> _SelfT: ...
def __exit__(self, *exc_info: Any) -> None: ...
class DictCursor(Cursor):
def fetchone(self) -> Optional[Dict[Text, Any]]: ...
def fetchmany(self, size: Optional[int] = ...) -> Optional[Tuple[Dict[Text, Any], ...]]: ...
def fetchall(self) -> Optional[Tuple[Dict[Text, Any], ...]]: ...
# Methods returning result tuples are below.
def fetchone(self) -> Tuple[Any, ...] | None: ...
def fetchmany(self, size: int | None = ...) -> Tuple[Tuple[Any, ...], ...]: ...
def fetchall(self) -> Tuple[Tuple[Any, ...], ...]: ...
def __iter__(self) -> Iterator[Tuple[Any, ...]]: ...
class DictCursorMixin:
dict_type: Any
dict_type: Any # TODO: add support if someone needs this
def fetchone(self) -> dict[Text, Any] | None: ...
def fetchmany(self, size: int | None = ...) -> Tuple[dict[Text, Any], ...]: ...
def fetchall(self) -> Tuple[dict[Text, Any], ...]: ...
def __iter__(self) -> Iterator[dict[Text, Any]]: ...
class SSCursor(Cursor):
# fetchall return type is incompatible with the supertype.
def fetchall(self) -> List[_Gen]: ... # type: ignore
def fetchall_unbuffered(self) -> Iterator[_Gen]: ...
def __iter__(self) -> Iterator[_Gen]: ...
def fetchmany(self, size: Optional[int] = ...) -> List[_Gen]: ...
def fetchall(self) -> list[Tuple[Any, ...]]: ... # type: ignore
def fetchall_unbuffered(self) -> Iterator[Tuple[Any, ...]]: ...
def scroll(self, value: int, mode: Text = ...) -> None: ...
class SSDictCursor(DictCursorMixin, SSCursor): ...
class DictCursor(DictCursorMixin, Cursor): ... # type: ignore
class SSDictCursor(DictCursorMixin, SSCursor): # type: ignore
def fetchall_unbuffered(self) -> Iterator[dict[Text, Any]]: ... # type: ignore