From 648ed7bc4eab517f8f1d2f8b0efad75b424a1548 Mon Sep 17 00:00:00 2001 From: Justine Krejcha Date: Sat, 12 Apr 2025 10:27:16 -0700 Subject: [PATCH] jwcrypto: type most of the rest of `JWT` and `JWKSet.generate` function (#13807) --- stubs/jwcrypto/jwcrypto/jwk.pyi | 34 +++++++++++++++++++++++++++------ stubs/jwcrypto/jwcrypto/jwt.pyi | 22 ++++++++++----------- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/stubs/jwcrypto/jwcrypto/jwk.pyi b/stubs/jwcrypto/jwcrypto/jwk.pyi index 9b3d912c3..9ec8c3866 100644 --- a/stubs/jwcrypto/jwcrypto/jwk.pyi +++ b/stubs/jwcrypto/jwcrypto/jwk.pyi @@ -1,7 +1,7 @@ from collections.abc import Callable, Sequence from enum import Enum from typing import Any, Literal, NamedTuple, TypeVar, overload -from typing_extensions import Self, deprecated +from typing_extensions import Self, TypeAlias, deprecated from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import ec, rsa @@ -46,7 +46,8 @@ class _X448_CURVE(NamedTuple): pubkey: UnimplementedOKPCurveKey privkey: UnimplementedOKPCurveKey -JWKTypesRegistry: dict[str, str] +_JWKKeyTypeSupported: TypeAlias = Literal["oct", "RSA", "EC", "OKP"] +JWKTypesRegistry: dict[_JWKKeyTypeSupported, str] class ParmType(Enum): name = "A string with a name" # pyright: ignore[reportAssignmentType] @@ -63,8 +64,12 @@ class JWKParameter(NamedTuple): JWKValuesRegistry: dict[str, dict[str, JWKParameter]] JWKParamsRegistry: dict[str, JWKParameter] JWKEllipticCurveRegistry: dict[str, str] -JWKUseRegistry: dict[str, str] -JWKOperationsRegistry: dict[str, str] +_JWKUseSupported: TypeAlias = Literal["sig", "enc"] +JWKUseRegistry: dict[_JWKUseSupported, str] +_JWKOperationSupported: TypeAlias = Literal[ + "sign", "verify", "encrypt", "decrypt", "wrapKey", "unwrapKey", "deriveKey", "deriveBits" +] +JWKOperationsRegistry: dict[_JWKOperationSupported, str] JWKpycaCurveMap: dict[str, str] IANANamedInformationHashAlgorithmRegistry: dict[ str, @@ -98,9 +103,26 @@ class InvalidJWKValue(JWException): ... class JWK(dict[str, Any]): def __init__(self, **kwargs) -> None: ... + # `kty` and the other keyword arguments are passed as `params` to the called generator + # function. The possible arguments depend on the value of `kty`. + # TODO: Add overloads for the individual `kty` values. @classmethod - def generate(cls, **kwargs) -> Self: ... - def generate_key(self, **params) -> None: ... + @overload + def generate( + cls, + *, + kty: Literal["RSA"], + public_exponent: int | None = None, + size: int | None = None, + kid: str | None = None, + alg: str | None = None, + use: _JWKUseSupported | None = None, + key_ops: list[_JWKOperationSupported] | None = None, + ) -> Self: ... + @classmethod + @overload + def generate(cls, *, kty: _JWKKeyTypeSupported, **kwargs) -> Self: ... + def generate_key(self, *, kty: _JWKKeyTypeSupported, **kwargs) -> None: ... def import_key(self, **kwargs) -> None: ... @classmethod def from_json(cls, key) -> Self: ... diff --git a/stubs/jwcrypto/jwcrypto/jwt.pyi b/stubs/jwcrypto/jwcrypto/jwt.pyi index 0ede9c5c6..eb3f062e2 100644 --- a/stubs/jwcrypto/jwcrypto/jwt.pyi +++ b/stubs/jwcrypto/jwcrypto/jwt.pyi @@ -1,6 +1,6 @@ from _typeshed import Incomplete from collections.abc import Mapping -from typing import Any +from typing import Any, SupportsInt from typing_extensions import deprecated from jwcrypto.common import JWException, JWKeyNotFound @@ -49,31 +49,31 @@ class JWT: @header.setter def header(self, h: dict[str, Any] | str) -> None: ... @property - def claims(self): ... + def claims(self) -> str: ... @claims.setter - def claims(self, data) -> None: ... + def claims(self, data: str) -> None: ... @property def token(self): ... @token.setter def token(self, t) -> None: ... @property - def leeway(self): ... + def leeway(self) -> int: ... @leeway.setter - def leeway(self, lwy) -> None: ... + def leeway(self, lwy: SupportsInt) -> None: ... @property - def validity(self): ... + def validity(self) -> int: ... @validity.setter - def validity(self, v) -> None: ... + def validity(self, v: SupportsInt) -> None: ... @property def expected_type(self): ... @expected_type.setter def expected_type(self, v) -> None: ... def norm_typ(self, val): ... - def make_signed_token(self, key) -> None: ... - def make_encrypted_token(self, key) -> None: ... - def validate(self, key) -> None: ... + def make_signed_token(self, key: JWK) -> None: ... + def make_encrypted_token(self, key: JWK) -> None: ... + def validate(self, key: JWK | JWKSet) -> None: ... def deserialize(self, jwt, key: Incomplete | None = None) -> None: ... - def serialize(self, compact: bool = True): ... + def serialize(self, compact: bool = True) -> str: ... @classmethod def from_jose_token(cls, token): ... def __eq__(self, other: object) -> bool: ...