From 4f37d8fff818b0663b3f1cabb57948b6950a78fd Mon Sep 17 00:00:00 2001 From: Stephen Morton Date: Tue, 1 Oct 2024 20:10:51 -0700 Subject: [PATCH] add _ssl module (#11155) Really all I needed for fixing the inheritance was _ssl._SSLContext. But then I needed all the other stuff in _ssl, and if I was doing that I wanted to do a thorough job of it. Motivation was originally related to https://github.com/python/typeshed/issues/3968 , but we're well beyond that now, really. Co-authored-by: Jelle Zijlstra --- pyproject.toml | 1 + stdlib/@tests/stubtest_allowlists/common.txt | 1 - stdlib/@tests/stubtest_allowlists/py38.txt | 4 + stdlib/@tests/stubtest_allowlists/py39.txt | 4 + stdlib/VERSIONS | 1 + stdlib/_ssl.pyi | 292 +++++++++++++++++++ stdlib/ssl.pyi | 102 +++---- 7 files changed, 342 insertions(+), 63 deletions(-) create mode 100644 stdlib/_ssl.pyi diff --git a/pyproject.toml b/pyproject.toml index 53f08ea1c..1f2e6d6b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -146,6 +146,7 @@ extra-standard-library = [ "_random", "_sitebuiltins", "_socket", + "_ssl", "_stat", "_thread", "_threading_local", diff --git a/stdlib/@tests/stubtest_allowlists/common.txt b/stdlib/@tests/stubtest_allowlists/common.txt index 942af553b..2b4cc0b1c 100644 --- a/stdlib/@tests/stubtest_allowlists/common.txt +++ b/stdlib/@tests/stubtest_allowlists/common.txt @@ -413,7 +413,6 @@ queue.SimpleQueue.__init__ imaplib.IMAP4_SSL.ssl ssl.PROTOCOL_SSLv2 ssl.PROTOCOL_SSLv3 -ssl.RAND_egd pickle._Pickler\..* # Best effort typing for undocumented internals pickle._Unpickler\..* # Best effort typing for undocumented internals diff --git a/stdlib/@tests/stubtest_allowlists/py38.txt b/stdlib/@tests/stubtest_allowlists/py38.txt index e66b49971..8a8cebab1 100644 --- a/stdlib/@tests/stubtest_allowlists/py38.txt +++ b/stdlib/@tests/stubtest_allowlists/py38.txt @@ -246,6 +246,10 @@ unittest\.test\..+ pstats.SortKey.__new__ tkinter.EventType.__new__ +# Items that depend on the existence and flags of SSL +ssl.RAND_egd +_ssl.RAND_egd + # Incorrectly star import. ctypes._endian.DEFAULT_MODE ctypes._endian.RTLD_GLOBAL diff --git a/stdlib/@tests/stubtest_allowlists/py39.txt b/stdlib/@tests/stubtest_allowlists/py39.txt index e7217a040..9c646178b 100644 --- a/stdlib/@tests/stubtest_allowlists/py39.txt +++ b/stdlib/@tests/stubtest_allowlists/py39.txt @@ -229,6 +229,10 @@ types.SimpleNamespace.__init__ # class doesn't accept positional arguments but pstats.SortKey.__new__ tkinter.EventType.__new__ +# Items that depend on the existence and flags of SSL +ssl.RAND_egd +_ssl.RAND_egd + # Incorrectly star import. ctypes._endian.DEFAULT_MODE ctypes._endian.RTLD_GLOBAL diff --git a/stdlib/VERSIONS b/stdlib/VERSIONS index dfed62f69..9f9b9bd10 100644 --- a/stdlib/VERSIONS +++ b/stdlib/VERSIONS @@ -50,6 +50,7 @@ _pydecimal: 3.5- _random: 3.0- _sitebuiltins: 3.4- _socket: 3.0- # present in 3.0 at runtime, but not in typeshed +_ssl: 3.0- _stat: 3.4- _thread: 3.0- _threading_local: 3.0- diff --git a/stdlib/_ssl.pyi b/stdlib/_ssl.pyi new file mode 100644 index 000000000..3e8887414 --- /dev/null +++ b/stdlib/_ssl.pyi @@ -0,0 +1,292 @@ +import sys +from _typeshed import ReadableBuffer, StrOrBytesPath +from collections.abc import Callable +from ssl import ( + SSLCertVerificationError as SSLCertVerificationError, + SSLContext, + SSLEOFError as SSLEOFError, + SSLError as SSLError, + SSLObject, + SSLSyscallError as SSLSyscallError, + SSLWantReadError as SSLWantReadError, + SSLWantWriteError as SSLWantWriteError, + SSLZeroReturnError as SSLZeroReturnError, +) +from typing import Any, Literal, TypedDict, final, overload +from typing_extensions import NotRequired, Self, TypeAlias + +_PasswordType: TypeAlias = Callable[[], str | bytes | bytearray] | str | bytes | bytearray +_PCTRTT: TypeAlias = tuple[tuple[str, str], ...] +_PCTRTTT: TypeAlias = tuple[_PCTRTT, ...] +_PeerCertRetDictType: TypeAlias = dict[str, str | _PCTRTTT | _PCTRTT] + +class _Cipher(TypedDict): + aead: bool + alg_bits: int + auth: str + description: str + digest: str | None + id: int + kea: str + name: str + protocol: str + strength_bits: int + symmetric: str + +class _CertInfo(TypedDict): + subject: tuple[tuple[tuple[str, str], ...], ...] + issuer: tuple[tuple[tuple[str, str], ...], ...] + version: int + serialNumber: str + notBefore: str + notAfter: str + subjectAltName: NotRequired[tuple[tuple[str, str], ...] | None] + OCSP: NotRequired[tuple[str, ...] | None] + caIssuers: NotRequired[tuple[str, ...] | None] + crlDistributionPoints: NotRequired[tuple[str, ...] | None] + +def RAND_add(string: str | ReadableBuffer, entropy: float, /) -> None: ... +def RAND_bytes(n: int, /) -> bytes: ... + +if sys.version_info < (3, 12): + def RAND_pseudo_bytes(n: int, /) -> tuple[bytes, bool]: ... + +if sys.version_info < (3, 10): + def RAND_egd(path: str) -> None: ... + +def RAND_status() -> bool: ... +def get_default_verify_paths() -> tuple[str, str, str, str]: ... + +if sys.platform == "win32": + _EnumRetType: TypeAlias = list[tuple[bytes, str, set[str] | bool]] + def enum_certificates(store_name: str) -> _EnumRetType: ... + def enum_crls(store_name: str) -> _EnumRetType: ... + +def txt2obj(txt: str, name: bool = False) -> tuple[int, str, str, str]: ... +def nid2obj(nid: int, /) -> tuple[int, str, str, str]: ... + +class _SSLContext: + check_hostname: bool + keylog_filename: str | None + maximum_version: int + minimum_version: int + num_tickets: int + options: int + post_handshake_auth: bool + protocol: int + if sys.version_info >= (3, 10): + security_level: int + sni_callback: Callable[[SSLObject, str, SSLContext], None | int] | None + verify_flags: int + verify_mode: int + def __new__(cls, protocol: int, /) -> Self: ... + def cert_store_stats(self) -> dict[str, int]: ... + @overload + def get_ca_certs(self, binary_form: Literal[False] = False) -> list[_PeerCertRetDictType]: ... + @overload + def get_ca_certs(self, binary_form: Literal[True]) -> list[bytes]: ... + @overload + def get_ca_certs(self, binary_form: bool = False) -> Any: ... + def get_ciphers(self) -> list[_Cipher]: ... + def load_cert_chain( + self, certfile: StrOrBytesPath, keyfile: StrOrBytesPath | None = None, password: _PasswordType | None = None + ) -> None: ... + def load_dh_params(self, path: str, /) -> None: ... + def load_verify_locations( + self, + cafile: StrOrBytesPath | None = None, + capath: StrOrBytesPath | None = None, + cadata: str | ReadableBuffer | None = None, + ) -> None: ... + def session_stats(self) -> dict[str, int]: ... + def set_ciphers(self, cipherlist: str, /) -> None: ... + def set_default_verify_paths(self) -> None: ... + def set_ecdh_curve(self, name: str, /) -> None: ... + if sys.version_info >= (3, 13): + def set_psk_client_callback(self, callback: Callable[[str | None], tuple[str | None, bytes]] | None) -> None: ... + def set_psk_server_callback( + self, callback: Callable[[str | None], tuple[str | None, bytes]] | None, identity_hint: str | None = None + ) -> None: ... + +@final +class MemoryBIO: + eof: bool + pending: int + def __new__(self) -> Self: ... + def read(self, size: int = -1, /) -> bytes: ... + def write(self, b: ReadableBuffer, /) -> int: ... + def write_eof(self) -> None: ... + +@final +class SSLSession: + @property + def has_ticket(self) -> bool: ... + @property + def id(self) -> bytes: ... + @property + def ticket_lifetime_hint(self) -> int: ... + @property + def time(self) -> int: ... + @property + def timeout(self) -> int: ... + +# _ssl.Certificate is weird: it can't be instantiated or subclassed. +# Instances can only be created via methods of the private _ssl._SSLSocket class, +# for which the relevant method signatures are: +# +# class _SSLSocket: +# def get_unverified_chain(self) -> list[Certificate] | None: ... +# def get_verified_chain(self) -> list[Certificate] | None: ... +# +# You can find a _ssl._SSLSocket object as the _sslobj attribute of a ssl.SSLSocket object + +if sys.version_info >= (3, 10): + @final + class Certificate: + def get_info(self) -> _CertInfo: ... + @overload + def public_bytes(self) -> str: ... + @overload + def public_bytes(self, format: Literal[1] = 1, /) -> str: ... # ENCODING_PEM + @overload + def public_bytes(self, format: Literal[2], /) -> bytes: ... # ENCODING_DER + @overload + def public_bytes(self, format: int, /) -> str | bytes: ... + +if sys.version_info < (3, 12): + err_codes_to_names: dict[tuple[int, int], str] + err_names_to_codes: dict[str, tuple[int, int]] + lib_codes_to_names: dict[int, str] + +_DEFAULT_CIPHERS: str + +# SSL error numbers +SSL_ERROR_ZERO_RETURN: int +SSL_ERROR_WANT_READ: int +SSL_ERROR_WANT_WRITE: int +SSL_ERROR_WANT_X509_LOOKUP: int +SSL_ERROR_SYSCALL: int +SSL_ERROR_SSL: int +SSL_ERROR_WANT_CONNECT: int +SSL_ERROR_EOF: int +SSL_ERROR_INVALID_ERROR_CODE: int + +# verify modes +CERT_NONE: int +CERT_OPTIONAL: int +CERT_REQUIRED: int + +# verify flags +VERIFY_DEFAULT: int +VERIFY_CRL_CHECK_LEAF: int +VERIFY_CRL_CHECK_CHAIN: int +VERIFY_X509_STRICT: int +VERIFY_X509_TRUSTED_FIRST: int +if sys.version_info >= (3, 10): + VERIFY_ALLOW_PROXY_CERTS: int + VERIFY_X509_PARTIAL_CHAIN: int + +# alert descriptions +ALERT_DESCRIPTION_CLOSE_NOTIFY: int +ALERT_DESCRIPTION_UNEXPECTED_MESSAGE: int +ALERT_DESCRIPTION_BAD_RECORD_MAC: int +ALERT_DESCRIPTION_RECORD_OVERFLOW: int +ALERT_DESCRIPTION_DECOMPRESSION_FAILURE: int +ALERT_DESCRIPTION_HANDSHAKE_FAILURE: int +ALERT_DESCRIPTION_BAD_CERTIFICATE: int +ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE: int +ALERT_DESCRIPTION_CERTIFICATE_REVOKED: int +ALERT_DESCRIPTION_CERTIFICATE_EXPIRED: int +ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN: int +ALERT_DESCRIPTION_ILLEGAL_PARAMETER: int +ALERT_DESCRIPTION_UNKNOWN_CA: int +ALERT_DESCRIPTION_ACCESS_DENIED: int +ALERT_DESCRIPTION_DECODE_ERROR: int +ALERT_DESCRIPTION_DECRYPT_ERROR: int +ALERT_DESCRIPTION_PROTOCOL_VERSION: int +ALERT_DESCRIPTION_INSUFFICIENT_SECURITY: int +ALERT_DESCRIPTION_INTERNAL_ERROR: int +ALERT_DESCRIPTION_USER_CANCELLED: int +ALERT_DESCRIPTION_NO_RENEGOTIATION: int +ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION: int +ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE: int +ALERT_DESCRIPTION_UNRECOGNIZED_NAME: int +ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE: int +ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE: int +ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY: int + +# protocol versions +PROTOCOL_SSLv23: int +PROTOCOL_TLS: int +PROTOCOL_TLS_CLIENT: int +PROTOCOL_TLS_SERVER: int +PROTOCOL_TLSv1: int +PROTOCOL_TLSv1_1: int +PROTOCOL_TLSv1_2: int + +# protocol options +OP_ALL: int +OP_NO_SSLv2: int +OP_NO_SSLv3: int +OP_NO_TLSv1: int +OP_NO_TLSv1_1: int +OP_NO_TLSv1_2: int +OP_NO_TLSv1_3: int +OP_CIPHER_SERVER_PREFERENCE: int +OP_SINGLE_DH_USE: int +OP_NO_TICKET: int +OP_SINGLE_ECDH_USE: int +OP_NO_COMPRESSION: int +OP_ENABLE_MIDDLEBOX_COMPAT: int +OP_NO_RENEGOTIATION: int +if sys.version_info >= (3, 11): + OP_IGNORE_UNEXPECTED_EOF: int +elif sys.version_info >= (3, 8) and sys.platform == "linux": + OP_IGNORE_UNEXPECTED_EOF: int +if sys.version_info >= (3, 12): + OP_LEGACY_SERVER_CONNECT: int + OP_ENABLE_KTLS: int + +# host flags +HOSTFLAG_ALWAYS_CHECK_SUBJECT: int +HOSTFLAG_NEVER_CHECK_SUBJECT: int +HOSTFLAG_NO_WILDCARDS: int +HOSTFLAG_NO_PARTIAL_WILDCARDS: int +HOSTFLAG_MULTI_LABEL_WILDCARDS: int +HOSTFLAG_SINGLE_LABEL_SUBDOMAINS: int + +if sys.version_info >= (3, 10): + # certificate file types + # Typed as Literal so the overload on Certificate.public_bytes can work properly. + ENCODING_PEM: Literal[1] + ENCODING_DER: Literal[2] + +# protocol versions +PROTO_MINIMUM_SUPPORTED: int +PROTO_MAXIMUM_SUPPORTED: int +PROTO_SSLv3: int +PROTO_TLSv1: int +PROTO_TLSv1_1: int +PROTO_TLSv1_2: int +PROTO_TLSv1_3: int + +# feature support +HAS_SNI: bool +HAS_TLS_UNIQUE: bool +HAS_ECDH: bool +HAS_NPN: bool +if sys.version_info >= (3, 13): + HAS_PSK: bool +HAS_ALPN: bool +HAS_SSLv2: bool +HAS_SSLv3: bool +HAS_TLSv1: bool +HAS_TLSv1_1: bool +HAS_TLSv1_2: bool +HAS_TLSv1_3: bool + +# version info +OPENSSL_VERSION_NUMBER: int +OPENSSL_VERSION_INFO: tuple[int, int, int, int, int] +OPENSSL_VERSION: str +_OPENSSL_API_VERSION: tuple[int, int, int, int, int] diff --git a/stdlib/ssl.pyi b/stdlib/ssl.pyi index 81c68c69e..1d97c02ac 100644 --- a/stdlib/ssl.pyi +++ b/stdlib/ssl.pyi @@ -1,18 +1,51 @@ import enum import socket import sys +from _ssl import ( + _DEFAULT_CIPHERS as _DEFAULT_CIPHERS, + _OPENSSL_API_VERSION as _OPENSSL_API_VERSION, + HAS_ALPN as HAS_ALPN, + HAS_ECDH as HAS_ECDH, + HAS_NPN as HAS_NPN, + HAS_SNI as HAS_SNI, + OPENSSL_VERSION as OPENSSL_VERSION, + OPENSSL_VERSION_INFO as OPENSSL_VERSION_INFO, + OPENSSL_VERSION_NUMBER as OPENSSL_VERSION_NUMBER, + HAS_SSLv2 as HAS_SSLv2, + HAS_SSLv3 as HAS_SSLv3, + HAS_TLSv1 as HAS_TLSv1, + HAS_TLSv1_1 as HAS_TLSv1_1, + HAS_TLSv1_2 as HAS_TLSv1_2, + HAS_TLSv1_3 as HAS_TLSv1_3, + MemoryBIO as MemoryBIO, + RAND_add as RAND_add, + RAND_bytes as RAND_bytes, + RAND_status as RAND_status, + SSLSession as SSLSession, + _PasswordType as _PasswordType, # typeshed only, but re-export for other type stubs to use + _SSLContext, +) from _typeshed import ReadableBuffer, StrOrBytesPath, WriteableBuffer from collections.abc import Callable, Iterable -from typing import Any, Literal, NamedTuple, TypedDict, final, overload +from typing import Any, Literal, NamedTuple, TypedDict, overload from typing_extensions import Never, Self, TypeAlias +if sys.version_info >= (3, 13): + from _ssl import HAS_PSK as HAS_PSK + +if sys.version_info < (3, 12): + from _ssl import RAND_pseudo_bytes as RAND_pseudo_bytes + +if sys.version_info < (3, 10): + from _ssl import RAND_egd as RAND_egd + +if sys.platform == "win32": + from _ssl import enum_certificates as enum_certificates, enum_crls as enum_crls + _PCTRTT: TypeAlias = tuple[tuple[str, str], ...] _PCTRTTT: TypeAlias = tuple[_PCTRTT, ...] _PeerCertRetDictType: TypeAlias = dict[str, str | _PCTRTTT | _PCTRTT] _PeerCertRetType: TypeAlias = _PeerCertRetDictType | bytes | None -_EnumRetType: TypeAlias = list[tuple[bytes, str, set[str] | bool]] -_PasswordType: TypeAlias = Callable[[], str | bytes | bytearray] | str | bytes | bytearray - _SrvnmeCbType: TypeAlias = Callable[[SSLSocket | SSLObject, str | None, SSLSocket], int | None] socket_error = OSError @@ -98,15 +131,6 @@ else: _create_default_https_context: Callable[..., SSLContext] -def RAND_bytes(n: int, /) -> bytes: ... - -if sys.version_info < (3, 12): - def RAND_pseudo_bytes(n: int, /) -> tuple[bytes, bool]: ... - -def RAND_status() -> bool: ... -def RAND_egd(path: str) -> None: ... -def RAND_add(string: str | ReadableBuffer, entropy: float, /) -> None: ... - if sys.version_info < (3, 12): def match_hostname(cert: _PeerCertRetDictType, hostname: str) -> None: ... @@ -133,10 +157,6 @@ class DefaultVerifyPaths(NamedTuple): def get_default_verify_paths() -> DefaultVerifyPaths: ... -if sys.platform == "win32": - def enum_certificates(store_name: str) -> _EnumRetType: ... - def enum_crls(store_name: str) -> _EnumRetType: ... - class VerifyMode(enum.IntEnum): CERT_NONE = 0 CERT_OPTIONAL = 1 @@ -229,21 +249,8 @@ if sys.version_info >= (3, 11) or sys.platform == "linux": OP_IGNORE_UNEXPECTED_EOF: Options HAS_NEVER_CHECK_COMMON_NAME: bool -HAS_SSLv2: bool -HAS_SSLv3: bool -HAS_TLSv1: bool -HAS_TLSv1_1: bool -HAS_TLSv1_2: bool -HAS_TLSv1_3: bool -HAS_ALPN: bool -HAS_ECDH: bool -HAS_SNI: bool -HAS_NPN: bool -CHANNEL_BINDING_TYPES: list[str] -OPENSSL_VERSION: str -OPENSSL_VERSION_INFO: tuple[int, int, int, int, int] -OPENSSL_VERSION_NUMBER: int +CHANNEL_BINDING_TYPES: list[str] class AlertDescription(enum.IntEnum): ALERT_DESCRIPTION_ACCESS_DENIED = 49 @@ -379,17 +386,15 @@ class TLSVersion(enum.IntEnum): TLSv1_2 = 771 TLSv1_3 = 772 -class SSLContext: - check_hostname: bool +class SSLContext(_SSLContext): options: Options verify_flags: VerifyFlags verify_mode: VerifyMode @property - def protocol(self) -> _SSLMethod: ... + def protocol(self) -> _SSLMethod: ... # type: ignore[override] hostname_checks_common_name: bool maximum_version: TLSVersion minimum_version: TLSVersion - sni_callback: Callable[[SSLObject, str, SSLContext], None | int] | None # The following two attributes have class-level defaults. # However, the docs explicitly state that it's OK to override these attributes on instances, # so making these ClassVars wouldn't be appropriate @@ -406,10 +411,6 @@ class SSLContext: else: def __new__(cls, protocol: int = ..., *args: Any, **kwargs: Any) -> Self: ... - def cert_store_stats(self) -> dict[str, int]: ... - def load_cert_chain( - self, certfile: StrOrBytesPath, keyfile: StrOrBytesPath | None = None, password: _PasswordType | None = None - ) -> None: ... def load_default_certs(self, purpose: Purpose = ...) -> None: ... def load_verify_locations( self, @@ -448,7 +449,6 @@ class SSLContext: server_hostname: str | bytes | None = None, session: SSLSession | None = None, ) -> SSLObject: ... - def session_stats(self) -> dict[str, int]: ... class SSLObject: context: SSLContext @@ -483,28 +483,6 @@ class SSLObject: def get_verified_chain(self) -> list[bytes]: ... def get_unverified_chain(self) -> list[bytes]: ... -@final -class MemoryBIO: - pending: int - eof: bool - def read(self, size: int = -1, /) -> bytes: ... - def write(self, b: ReadableBuffer, /) -> int: ... - def write_eof(self) -> None: ... - -@final -class SSLSession: - @property - def has_ticket(self) -> bool: ... - @property - def id(self) -> bytes: ... - @property - def ticket_lifetime_hint(self) -> int: ... - @property - def time(self) -> int: ... - @property - def timeout(self) -> int: ... - def __eq__(self, value: object, /) -> bool: ... - class SSLErrorNumber(enum.IntEnum): SSL_ERROR_EOF = 8 SSL_ERROR_INVALID_ERROR_CODE = 10