ssl, socket, array: Improve bytes handling (#8997)

This commit is contained in:
Jelle Zijlstra
2022-10-28 06:35:51 -07:00
committed by GitHub
parent bcd876f77e
commit 287fce4872
4 changed files with 53 additions and 38 deletions

View File

@@ -15,10 +15,10 @@ _CMSG: TypeAlias = tuple[int, int, bytes]
_CMSGArg: TypeAlias = tuple[int, int, ReadableBuffer]
# Addresses can be either tuples of varying lengths (AF_INET, AF_INET6,
# AF_NETLINK, AF_TIPC) or strings (AF_UNIX).
_Address: TypeAlias = tuple[Any, ...] | str
# AF_NETLINK, AF_TIPC) or strings/buffers (AF_UNIX).
# See getsockaddrarg() in socketmodule.c.
_Address: TypeAlias = tuple[Any, ...] | str | ReadableBuffer
_RetAddress: TypeAlias = Any
# TODO Most methods allow bytes as address objects
# ----- Constants -----
# Some socket families are listed in the "Socket families" section of the docs,
@@ -584,10 +584,10 @@ class socket:
@property
def timeout(self) -> float | None: ...
def __init__(self, family: int = ..., type: int = ..., proto: int = ..., fileno: _FD | None = ...) -> None: ...
def bind(self, __address: _Address | bytes) -> None: ...
def bind(self, __address: _Address) -> None: ...
def close(self) -> None: ...
def connect(self, __address: _Address | bytes) -> None: ...
def connect_ex(self, __address: _Address | bytes) -> int: ...
def connect(self, __address: _Address) -> None: ...
def connect_ex(self, __address: _Address) -> int: ...
def detach(self) -> int: ...
def fileno(self) -> int: ...
def getpeername(self) -> _RetAddress: ...
@@ -634,7 +634,7 @@ class socket:
def setblocking(self, __flag: bool) -> None: ...
def settimeout(self, __value: float | None) -> None: ...
@overload
def setsockopt(self, __level: int, __optname: int, __value: int | bytes) -> None: ...
def setsockopt(self, __level: int, __optname: int, __value: int | ReadableBuffer) -> None: ...
@overload
def setsockopt(self, __level: int, __optname: int, __value: None, __optlen: int) -> None: ...
if sys.platform == "win32":
@@ -671,9 +671,9 @@ def ntohs(__x: int) -> int: ... # param & ret val are 16-bit ints
def htonl(__x: int) -> int: ... # param & ret val are 32-bit ints
def htons(__x: int) -> int: ... # param & ret val are 16-bit ints
def inet_aton(__ip_string: str) -> bytes: ... # ret val 4 bytes in length
def inet_ntoa(__packed_ip: bytes) -> str: ...
def inet_ntoa(__packed_ip: ReadableBuffer) -> str: ...
def inet_pton(__address_family: int, __ip_string: str) -> bytes: ...
def inet_ntop(__address_family: int, __packed_ip: bytes) -> str: ...
def inet_ntop(__address_family: int, __packed_ip: ReadableBuffer) -> str: ...
def getdefaulttimeout() -> float | None: ...
def setdefaulttimeout(__timeout: float | None) -> None: ...

View File

@@ -21,15 +21,19 @@ class array(MutableSequence[_T], Generic[_T]):
@property
def itemsize(self) -> int: ...
@overload
def __init__(self: array[int], __typecode: _IntTypeCode, __initializer: bytes | Iterable[int] = ...) -> None: ...
def __init__(self: array[int], __typecode: _IntTypeCode, __initializer: bytes | bytearray | Iterable[int] = ...) -> None: ...
@overload
def __init__(self: array[float], __typecode: _FloatTypeCode, __initializer: bytes | Iterable[float] = ...) -> None: ...
def __init__(
self: array[float], __typecode: _FloatTypeCode, __initializer: bytes | bytearray | Iterable[float] = ...
) -> None: ...
@overload
def __init__(self: array[str], __typecode: _UnicodeTypeCode, __initializer: bytes | Iterable[str] = ...) -> None: ...
def __init__(
self: array[str], __typecode: _UnicodeTypeCode, __initializer: bytes | bytearray | Iterable[str] = ...
) -> None: ...
@overload
def __init__(self, __typecode: str, __initializer: Iterable[_T]) -> None: ...
@overload
def __init__(self, __typecode: str, __initializer: bytes = ...) -> None: ...
def __init__(self, __typecode: str, __initializer: bytes | bytearray = ...) -> None: ...
def append(self, __v: _T) -> None: ...
def buffer_info(self) -> tuple[int, int]: ...
def byteswap(self) -> None: ...

View File

@@ -738,7 +738,7 @@ if sys.platform != "win32":
if sys.version_info >= (3, 9):
# flags and address appear to be unused in send_fds and recv_fds
def send_fds(
sock: socket, buffers: Iterable[bytes], fds: bytes | Iterable[int], flags: int = ..., address: None = ...
sock: socket, buffers: Iterable[ReadableBuffer], fds: Iterable[int], flags: int = ..., address: None = ...
) -> int: ...
def recv_fds(sock: socket, bufsize: int, maxfds: int, flags: int = ...) -> tuple[bytes, list[int], int, Any]: ...
@@ -768,16 +768,14 @@ if sys.version_info >= (3, 11):
def create_connection(
address: tuple[str | None, int],
timeout: float | None = ..., # noqa: F811
source_address: tuple[bytearray | bytes | str, int] | None = ...,
source_address: _Address | None = ...,
*,
all_errors: bool = ...,
) -> socket: ...
else:
def create_connection(
address: tuple[str | None, int],
timeout: float | None = ..., # noqa: F811
source_address: tuple[bytearray | bytes | str, int] | None = ...,
address: tuple[str | None, int], timeout: float | None = ..., source_address: _Address | None = ... # noqa: F811
) -> socket: ...
if sys.version_info >= (3, 8):
@@ -788,5 +786,10 @@ if sys.version_info >= (3, 8):
# the 5th tuple item is an address
def getaddrinfo(
host: bytes | str | None, port: str | int | None, family: int = ..., type: int = ..., proto: int = ..., flags: int = ...
host: bytes | str | None,
port: bytes | str | int | None,
family: int = ...,
type: int = ...,
proto: int = ...,
flags: int = ...,
) -> list[tuple[AddressFamily, SocketKind, int, str, tuple[str, int] | tuple[str, int, int, int]]]: ...

View File

@@ -11,7 +11,7 @@ _PCTRTTT: TypeAlias = tuple[_PCTRTT, ...]
_PeerCertRetDictType: TypeAlias = dict[str, str | _PCTRTTT | _PCTRTT]
_PeerCertRetType: TypeAlias = _PeerCertRetDictType | bytes | None
_EnumRetType: TypeAlias = list[tuple[bytes, str, set[str] | bool]]
_PasswordType: TypeAlias = Union[Callable[[], str | bytes], str, bytes]
_PasswordType: TypeAlias = Union[Callable[[], str | bytes | bytearray], str, bytes, bytearray]
_SrvnmeCbType: TypeAlias = Callable[[SSLSocket | SSLObject, str | None, SSLSocket], int | None]
@@ -61,7 +61,7 @@ def create_default_context(
*,
cafile: StrOrBytesPath | None = ...,
capath: StrOrBytesPath | None = ...,
cadata: str | bytes | None = ...,
cadata: str | ReadableBuffer | None = ...,
) -> SSLContext: ...
def _create_unverified_context(
protocol: int = ...,
@@ -73,7 +73,7 @@ def _create_unverified_context(
keyfile: StrOrBytesPath | None = ...,
cafile: StrOrBytesPath | None = ...,
capath: StrOrBytesPath | None = ...,
cadata: str | bytes | None = ...,
cadata: str | ReadableBuffer | None = ...,
) -> SSLContext: ...
_create_default_https_context: Callable[..., SSLContext]
@@ -82,8 +82,11 @@ def RAND_bytes(__num: int) -> bytes: ...
def RAND_pseudo_bytes(__num: int) -> tuple[bytes, bool]: ...
def RAND_status() -> bool: ...
def RAND_egd(path: str) -> None: ...
def RAND_add(__s: bytes, __entropy: float) -> None: ...
def match_hostname(cert: _PeerCertRetType, hostname: str) -> None: ...
def RAND_add(__string: str | ReadableBuffer, __entropy: float) -> None: ...
if sys.version_info < (3, 12):
def match_hostname(cert: _PeerCertRetDictType, hostname: str) -> None: ...
def cert_time_to_seconds(cert_time: str) -> int: ...
if sys.version_info >= (3, 10):
@@ -94,7 +97,7 @@ if sys.version_info >= (3, 10):
else:
def get_server_certificate(addr: tuple[str, int], ssl_version: int = ..., ca_certs: str | None = ...) -> str: ...
def DER_cert_to_PEM_cert(der_cert_bytes: bytes) -> str: ...
def DER_cert_to_PEM_cert(der_cert_bytes: ReadableBuffer) -> str: ...
def PEM_cert_to_DER_cert(pem_cert_string: str) -> bytes: ...
class DefaultVerifyPaths(NamedTuple):
@@ -290,8 +293,8 @@ class SSLSocket(socket.socket):
@property
def session_reused(self) -> bool | None: ...
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def connect(self, addr: socket._Address | bytes) -> None: ...
def connect_ex(self, addr: socket._Address | bytes) -> int: ...
def connect(self, addr: socket._Address) -> None: ...
def connect_ex(self, addr: socket._Address) -> int: ...
def recv(self, buflen: int = ..., flags: int = ...) -> bytes: ...
def recv_into(self, buffer: WriteableBuffer, nbytes: int | None = ..., flags: int = ...) -> int: ...
def recvfrom(self, buflen: int = ..., flags: int = ...) -> tuple[bytes, socket._RetAddress]: ...
@@ -301,12 +304,12 @@ class SSLSocket(socket.socket):
def send(self, data: ReadableBuffer, flags: int = ...) -> int: ...
def sendall(self, data: ReadableBuffer, flags: int = ...) -> None: ...
@overload
def sendto(self, data: ReadableBuffer, flags_or_addr: socket._Address) -> int: ...
def sendto(self, data: ReadableBuffer, flags_or_addr: socket._Address, addr: None = ...) -> int: ...
@overload
def sendto(self, data: ReadableBuffer, flags_or_addr: int | socket._Address, addr: socket._Address | None = ...) -> int: ...
def sendto(self, data: ReadableBuffer, flags_or_addr: int, addr: socket._Address) -> int: ...
def shutdown(self, how: int) -> None: ...
def read(self, len: int = ..., buffer: bytearray | None = ...) -> bytes: ...
def write(self, data: bytes) -> int: ...
def write(self, data: ReadableBuffer) -> int: ...
def do_handshake(self, block: bool = ...) -> None: ... # block is undocumented
@overload
def getpeercert(self, binary_form: Literal[False] = ...) -> _PeerCertRetDictType | None: ...
@@ -362,7 +365,7 @@ class SSLContext:
) -> None: ...
def load_default_certs(self, purpose: Purpose = ...) -> None: ...
def load_verify_locations(
self, cafile: StrOrBytesPath | None = ..., capath: StrOrBytesPath | None = ..., cadata: str | bytes | None = ...
self, cafile: StrOrBytesPath | None = ..., capath: StrOrBytesPath | None = ..., cadata: str | ReadableBuffer | None = ...
) -> None: ...
@overload
def get_ca_certs(self, binary_form: Literal[False] = ...) -> list[_PeerCertRetDictType]: ...
@@ -408,7 +411,7 @@ class SSLObject:
def session_reused(self) -> bool: ...
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def read(self, len: int = ..., buffer: bytearray | None = ...) -> bytes: ...
def write(self, data: bytes) -> int: ...
def write(self, data: ReadableBuffer) -> int: ...
@overload
def getpeercert(self, binary_form: Literal[False] = ...) -> _PeerCertRetDictType | None: ...
@overload
@@ -433,16 +436,21 @@ class MemoryBIO:
pending: int
eof: bool
def read(self, __size: int = ...) -> bytes: ...
def write(self, __buf: bytes) -> int: ...
def write(self, __buf: ReadableBuffer) -> int: ...
def write_eof(self) -> None: ...
@final
class SSLSession:
id: bytes
time: int
timeout: int
ticket_lifetime_hint: int
has_ticket: bool
@property
def has_ticket(self) -> bool: ...
@property
def id(self) -> bytes: ...
@property
def ticket_lifetime_hint(self) -> int: ...
@property
def time(self) -> int: ...
@property
def timeout(self) -> int: ...
class SSLErrorNumber(enum.IntEnum):
SSL_ERROR_EOF: int