From 8d67718c0cf7b16927092f4e5a5fcb8dfb5f0678 Mon Sep 17 00:00:00 2001 From: Semyon Moroz Date: Tue, 25 Mar 2025 17:16:33 +0400 Subject: [PATCH] Improve `jwcrypto` (#13715) --- stubs/jwcrypto/jwcrypto/common.pyi | 12 +- stubs/jwcrypto/jwcrypto/jwe.pyi | 32 +++-- stubs/jwcrypto/jwcrypto/jwk.pyi | 215 ++++++++++++++++++++++------- stubs/jwcrypto/jwcrypto/jws.pyi | 42 ++++-- stubs/jwcrypto/jwcrypto/jwt.pyi | 11 +- 5 files changed, 219 insertions(+), 93 deletions(-) diff --git a/stubs/jwcrypto/jwcrypto/common.pyi b/stubs/jwcrypto/jwcrypto/common.pyi index 10ee692c3..b1361916a 100644 --- a/stubs/jwcrypto/jwcrypto/common.pyi +++ b/stubs/jwcrypto/jwcrypto/common.pyi @@ -1,7 +1,9 @@ -from _typeshed import Incomplete -from collections.abc import Iterator, MutableMapping +from collections.abc import Callable, Iterator, MutableMapping from typing import Any, NamedTuple +from jwcrypto.jwe import JWE +from jwcrypto.jws import JWS + def base64url_encode(payload: str | bytes) -> str: ... def base64url_decode(payload: str) -> bytes: ... def json_encode(string: str | bytes) -> str: ... @@ -36,11 +38,11 @@ class JWSEHeaderParameter(NamedTuple): description: str mustprotect: bool supported: bool - check_fn: Incomplete | None + check_fn: Callable[[JWS | JWE], bool] | None class JWSEHeaderRegistry(MutableMapping[str, JWSEHeaderParameter]): - def __init__(self, init_registry: Incomplete | None = None) -> None: ... - def check_header(self, h: str, value) -> bool: ... + def __init__(self, init_registry: dict[str, JWSEHeaderParameter] | None = None) -> None: ... + def check_header(self, h: str, value: JWS | JWE) -> bool: ... def __getitem__(self, key: str) -> JWSEHeaderParameter: ... def __iter__(self) -> Iterator[str]: ... def __delitem__(self, key: str) -> None: ... diff --git a/stubs/jwcrypto/jwcrypto/jwe.pyi b/stubs/jwcrypto/jwcrypto/jwe.pyi index 408b8850d..69d48ab2d 100644 --- a/stubs/jwcrypto/jwcrypto/jwe.pyi +++ b/stubs/jwcrypto/jwcrypto/jwe.pyi @@ -1,8 +1,10 @@ from _typeshed import Incomplete from collections.abc import Mapping, Sequence +from typing import Any +from typing_extensions import Self from jwcrypto import common -from jwcrypto.common import JWException, JWSEHeaderParameter +from jwcrypto.common import JWException, JWSEHeaderParameter, JWSEHeaderRegistry from jwcrypto.jwk import JWK, JWKSet default_max_compressed_size: int @@ -18,34 +20,34 @@ InvalidJWEKeyType = common.InvalidJWEKeyType InvalidJWEOperation = common.InvalidJWEOperation class JWE: - objects: Incomplete - plaintext: Incomplete - header_registry: Incomplete + objects: dict[str, Any] + plaintext: bytes | None + header_registry: JWSEHeaderRegistry cek: Incomplete - decryptlog: Incomplete + decryptlog: list[str] | None def __init__( self, - plaintext: bytes | None = None, + plaintext: str | bytes | None = None, protected: str | None = None, unprotected: str | None = None, aad: bytes | None = None, - algs: Incomplete | None = None, + algs: list[str] | None = None, recipient: str | None = None, - header: Incomplete | None = None, - header_registry: Incomplete | None = None, + header: str | None = None, + header_registry: Mapping[str, JWSEHeaderParameter] | None = None, ) -> None: ... @property - def allowed_algs(self): ... + def allowed_algs(self) -> list[str]: ... @allowed_algs.setter - def allowed_algs(self, algs) -> None: ... - def add_recipient(self, key, header: Incomplete | None = None) -> None: ... - def serialize(self, compact: bool = False): ... + def allowed_algs(self, algs: list[str]) -> None: ... + def add_recipient(self, key: JWK, header: dict[str, Any] | str | None = None) -> None: ... + def serialize(self, compact: bool = False) -> str: ... def decrypt(self, key: JWK | JWKSet) -> None: ... def deserialize(self, raw_jwe: str | bytes, key: JWK | JWKSet | None = None) -> None: ... @property - def payload(self): ... + def payload(self) -> bytes: ... @property def jose_header(self) -> dict[Incomplete, Incomplete]: ... @classmethod - def from_jose_token(cls, token: str | bytes) -> JWE: ... + def from_jose_token(cls, token: str | bytes) -> Self: ... def __eq__(self, other: object) -> bool: ... diff --git a/stubs/jwcrypto/jwcrypto/jwk.pyi b/stubs/jwcrypto/jwcrypto/jwk.pyi index f5638cddf..9b3d912c3 100644 --- a/stubs/jwcrypto/jwcrypto/jwk.pyi +++ b/stubs/jwcrypto/jwcrypto/jwk.pyi @@ -1,8 +1,10 @@ -from _typeshed import Incomplete -from collections.abc import Sequence +from collections.abc import Callable, Sequence from enum import Enum -from typing import Any, NamedTuple +from typing import Any, Literal, NamedTuple, TypeVar, overload +from typing_extensions import Self, deprecated +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec, rsa from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey as Ed448PrivateKey, Ed448PublicKey as Ed448PublicKey from cryptography.hazmat.primitives.asymmetric.ed25519 import ( Ed25519PrivateKey as Ed25519PrivateKey, @@ -15,6 +17,8 @@ from cryptography.hazmat.primitives.asymmetric.x25519 import ( ) from jwcrypto.common import JWException +_T = TypeVar("_T") + class UnimplementedOKPCurveKey: @classmethod def generate(cls) -> None: ... @@ -24,9 +28,25 @@ class UnimplementedOKPCurveKey: def from_private_bytes(cls, *args) -> None: ... ImplementedOkpCurves: Sequence[str] -priv_bytes: Incomplete +priv_bytes: Callable[[bytes], X25519PrivateKey] | None -JWKTypesRegistry: Incomplete +class _Ed25519_CURVE(NamedTuple): + pubkey: UnimplementedOKPCurveKey + privkey: UnimplementedOKPCurveKey + +class _Ed448_CURVE(NamedTuple): + pubkey: UnimplementedOKPCurveKey + privkey: UnimplementedOKPCurveKey + +class _X25519_CURVE(NamedTuple): + pubkey: UnimplementedOKPCurveKey + privkey: UnimplementedOKPCurveKey + +class _X448_CURVE(NamedTuple): + pubkey: UnimplementedOKPCurveKey + privkey: UnimplementedOKPCurveKey + +JWKTypesRegistry: dict[str, str] class ParmType(Enum): name = "A string with a name" # pyright: ignore[reportAssignmentType] @@ -35,48 +55,80 @@ class ParmType(Enum): unsupported = "Unsupported Parameter" class JWKParameter(NamedTuple): - description: Incomplete - public: Incomplete - required: Incomplete - type: Incomplete + description: str + public: bool + required: bool | None + type: ParmType | None -JWKValuesRegistry: Incomplete -JWKParamsRegistry: Incomplete -JWKEllipticCurveRegistry: Incomplete -JWKUseRegistry: Incomplete -JWKOperationsRegistry: Incomplete -JWKpycaCurveMap: Incomplete -IANANamedInformationHashAlgorithmRegistry: Incomplete +JWKValuesRegistry: dict[str, dict[str, JWKParameter]] +JWKParamsRegistry: dict[str, JWKParameter] +JWKEllipticCurveRegistry: dict[str, str] +JWKUseRegistry: dict[str, str] +JWKOperationsRegistry: dict[str, str] +JWKpycaCurveMap: dict[str, str] +IANANamedInformationHashAlgorithmRegistry: dict[ + str, + hashes.SHA256 + | hashes.SHA384 + | hashes.SHA512 + | hashes.SHA3_224 + | hashes.SHA3_256 + | hashes.SHA3_384 + | hashes.SHA3_512 + | hashes.BLAKE2s + | hashes.BLAKE2b + | None, +] class InvalidJWKType(JWException): - value: Incomplete - def __init__(self, value: Incomplete | None = None) -> None: ... + value: str | None + def __init__(self, value: str | None = None) -> None: ... class InvalidJWKUsage(JWException): - value: Incomplete - use: Incomplete - def __init__(self, use, value) -> None: ... + value: str + use: str + def __init__(self, use: str, value: str) -> None: ... class InvalidJWKOperation(JWException): - op: Incomplete - values: Incomplete - def __init__(self, operation, values) -> None: ... + op: str + values: Sequence[str] + def __init__(self, operation: str, values: Sequence[str]) -> None: ... class InvalidJWKValue(JWException): ... class JWK(dict[str, Any]): def __init__(self, **kwargs) -> None: ... @classmethod - def generate(cls, **kwargs): ... + def generate(cls, **kwargs) -> Self: ... def generate_key(self, **params) -> None: ... def import_key(self, **kwargs) -> None: ... @classmethod - def from_json(cls, key): ... - def export(self, private_key: bool = True, as_dict: bool = False): ... - def export_public(self, as_dict: bool = False): ... - def export_private(self, as_dict: bool = False): ... - def export_symmetric(self, as_dict: bool = False): ... - def public(self): ... + def from_json(cls, key) -> Self: ... + @overload + def export(self, private_key: bool = True, as_dict: Literal[False] = False) -> str: ... + @overload + def export(self, private_key: bool, as_dict: Literal[True]) -> dict[str, Any]: ... + @overload + def export(self, *, as_dict: Literal[True]) -> dict[str, Any]: ... + @overload + def export_public(self, as_dict: Literal[False] = False) -> str: ... + @overload + def export_public(self, as_dict: Literal[True]) -> dict[str, Any]: ... + @overload + def export_public(self, as_dict: bool = False) -> str | dict[str, Any]: ... + @overload + def export_private(self, as_dict: Literal[False] = False) -> str: ... + @overload + def export_private(self, as_dict: Literal[True]) -> dict[str, Any]: ... + @overload + def export_private(self, as_dict: bool = False) -> str | dict[str, Any]: ... + @overload + def export_symmetric(self, as_dict: Literal[False] = False) -> str: ... + @overload + def export_symmetric(self, as_dict: Literal[True]) -> dict[str, Any]: ... + @overload + def export_symmetric(self, as_dict: bool = False) -> str | dict[str, Any]: ... + def public(self) -> Self: ... @property def has_public(self) -> bool: ... @property @@ -84,32 +136,89 @@ class JWK(dict[str, Any]): @property def is_symmetric(self) -> bool: ... @property - def key_type(self): ... + @deprecated("") + def key_type(self) -> str | None: ... @property - def key_id(self): ... + @deprecated("") + def key_id(self) -> str | None: ... @property - def key_curve(self): ... - def get_curve(self, arg): ... - def get_op_key(self, operation: Incomplete | None = None, arg: Incomplete | None = None): ... - def import_from_pyca(self, key) -> None: ... - def import_from_pem(self, data, password: Incomplete | None = None, kid: Incomplete | None = None) -> None: ... - def export_to_pem(self, private_key: bool = False, password: bool = False): ... + @deprecated("") + def key_curve(self) -> str | None: ... + @deprecated("") + def get_curve( + self, arg: str + ) -> ( + ec.SECP256R1 + | ec.SECP384R1 + | ec.SECP521R1 + | ec.SECP256K1 + | ec.BrainpoolP256R1 + | ec.BrainpoolP384R1 + | ec.BrainpoolP512R1 + | _Ed25519_CURVE + | _Ed448_CURVE + | _X25519_CURVE + | _X448_CURVE + ): ... + def get_op_key( + self, operation: str | None = None, arg: str | None = None + ) -> str | rsa.RSAPrivateKey | rsa.RSAPublicKey | ec.EllipticCurvePrivateKey | ec.EllipticCurvePublicKey | None: ... + def import_from_pyca( + self, + key: ( + rsa.RSAPrivateKey + | rsa.RSAPublicKey + | ec.EllipticCurvePrivateKey + | ec.EllipticCurvePublicKey + | Ed25519PrivateKey + | Ed448PrivateKey + | X25519PrivateKey + | Ed25519PublicKey + | Ed448PublicKey + | X25519PublicKey + ), + ) -> None: ... + def import_from_pem(self, data: bytes, password: bytes | None = None, kid: str | None = None) -> None: ... + def export_to_pem(self, private_key: bool = False, password: bool = False) -> bytes: ... @classmethod - def from_pyca(cls, key): ... + def from_pyca( + cls, + key: ( + rsa.RSAPrivateKey + | rsa.RSAPublicKey + | ec.EllipticCurvePrivateKey + | ec.EllipticCurvePublicKey + | Ed25519PrivateKey + | Ed448PrivateKey + | X25519PrivateKey + | Ed25519PublicKey + | Ed448PublicKey + | X25519PublicKey + ), + ) -> Self: ... @classmethod - def from_pem(cls, data, password: Incomplete | None = None): ... - def thumbprint(self, hashalg=...): ... - def thumbprint_uri(self, hname: str = "sha-256"): ... + def from_pem(cls, data: bytes, password: bytes | None = None) -> Self: ... + def thumbprint(self, hashalg: hashes.HashAlgorithm = ...) -> str: ... + def thumbprint_uri(self, hname: str = "sha-256") -> str: ... @classmethod - def from_password(cls, password): ... - def setdefault(self, key: str, default: Incomplete | None = None): ... + def from_password(cls, password: str) -> Self: ... + def setdefault(self, key: str, default: _T | None = None) -> _T: ... -class JWKSet(dict[str, Any]): - def add(self, elem) -> None: ... - def export(self, private_keys: bool = True, as_dict: bool = False): ... - def import_keyset(self, keyset) -> None: ... +class JWKSet(dict[Literal["keys"], set[JWK]]): + @overload + def __setitem__(self, key: Literal["keys"], val: JWK) -> None: ... + @overload + def __setitem__(self, key: str, val: Any) -> None: ... + def add(self, elem: JWK) -> None: ... + @overload + def export(self, private_keys: bool = True, as_dict: Literal[False] = False) -> str: ... + @overload + def export(self, private_keys: bool, as_dict: Literal[True]) -> dict[str, Any]: ... + @overload + def export(self, *, as_dict: Literal[True]) -> dict[str, Any]: ... + def import_keyset(self, keyset: str | bytes) -> None: ... @classmethod - def from_json(cls, keyset): ... - def get_key(self, kid): ... - def get_keys(self, kid): ... - def setdefault(self, key: str, default: Incomplete | None = None): ... + def from_json(cls, keyset: str | bytes) -> Self: ... + def get_key(self, kid: str) -> JWK | None: ... + def get_keys(self, kid: str) -> set[JWK]: ... + def setdefault(self, key: str, default: _T | None = None) -> _T: ... diff --git a/stubs/jwcrypto/jwcrypto/jws.pyi b/stubs/jwcrypto/jwcrypto/jws.pyi index d1bb2f8a2..f5d3fab33 100644 --- a/stubs/jwcrypto/jwcrypto/jws.pyi +++ b/stubs/jwcrypto/jwcrypto/jws.pyi @@ -1,9 +1,14 @@ from _typeshed import Incomplete +from collections.abc import Mapping +from typing import Any, Literal +from typing_extensions import Self -from jwcrypto.common import JWException +from jwcrypto.common import JWException, JWSEHeaderParameter +from jwcrypto.jwa import JWAAlgorithm +from jwcrypto.jwk import JWK, JWKSet -JWSHeaderRegistry: Incomplete -default_allowed_algs: Incomplete +JWSHeaderRegistry: Mapping[str, JWSEHeaderParameter] +default_allowed_algs: list[str] class InvalidJWSSignature(JWException): def __init__(self, message: str | None = None, exception: BaseException | None = None) -> None: ... @@ -15,19 +20,26 @@ class InvalidJWSOperation(JWException): def __init__(self, message: str | None = None, exception: BaseException | None = None) -> None: ... class JWSCore: - alg: Incomplete - engine: Incomplete - key: Incomplete - header: Incomplete - protected: Incomplete - payload: Incomplete - def __init__(self, alg, key, header, payload, algs: Incomplete | None = None) -> None: ... - def sign(self): ... - def verify(self, signature): ... + alg: str + engine: JWAAlgorithm + key: JWK | JWKSet + header: dict[str, Any] + protected: str + payload: bytes + def __init__( + self, + alg: str, + key: JWK | JWKSet, + header: dict[str, Any] | str | None, + payload: str | bytes, + algs: list[str] | None = None, + ) -> None: ... + def sign(self) -> dict[str, str | bytes]: ... + def verify(self, signature: bytes) -> Literal[True]: ... class JWS: objects: Incomplete - verifylog: Incomplete + verifylog: list[str] | None header_registry: Incomplete def __init__(self, payload: Incomplete | None = None, header_registry: Incomplete | None = None) -> None: ... @property @@ -41,12 +53,12 @@ class JWS: def add_signature( self, key, alg: Incomplete | None = None, protected: Incomplete | None = None, header: Incomplete | None = None ) -> None: ... - def serialize(self, compact: bool = False): ... + def serialize(self, compact: bool = False) -> str: ... @property def payload(self): ... def detach_payload(self) -> None: ... @property def jose_header(self): ... @classmethod - def from_jose_token(cls, token): ... + def from_jose_token(cls, token: str | bytes) -> Self: ... def __eq__(self, other: object) -> bool: ... diff --git a/stubs/jwcrypto/jwcrypto/jwt.pyi b/stubs/jwcrypto/jwcrypto/jwt.pyi index 4bd2a5b79..0ede9c5c6 100644 --- a/stubs/jwcrypto/jwcrypto/jwt.pyi +++ b/stubs/jwcrypto/jwcrypto/jwt.pyi @@ -1,5 +1,6 @@ from _typeshed import Incomplete from collections.abc import Mapping +from typing import Any from typing_extensions import deprecated from jwcrypto.common import JWException, JWKeyNotFound @@ -31,11 +32,11 @@ class JWTMissingKey(JWKeyNotFound): def __init__(self, message: str | None = None, exception: BaseException | None = None) -> None: ... class JWT: - deserializelog: Incomplete + deserializelog: list[str] | None def __init__( self, - header: dict[Incomplete, Incomplete] | str | None = None, - claims: dict[Incomplete, Incomplete] | str | None = None, + header: dict[str, Any] | str | None = None, + claims: dict[str, Any] | str | None = None, jwt: Incomplete | None = None, key: JWK | JWKSet | None = None, algs: Incomplete | None = None, @@ -44,9 +45,9 @@ class JWT: expected_type: Incomplete | None = None, ) -> None: ... @property - def header(self): ... + def header(self) -> str: ... @header.setter - def header(self, h) -> None: ... + def header(self, h: dict[str, Any] | str) -> None: ... @property def claims(self): ... @claims.setter