From e224b28f3295bf1c4f078dc071ee2c9d94a3b9b3 Mon Sep 17 00:00:00 2001 From: Sebastian Rittau Date: Tue, 13 May 2025 01:12:32 +0200 Subject: [PATCH] Fix variance of a few email-related classes (#13952) Closes #13919 --- stdlib/@tests/test_cases/email/check_mime.py | 4 +++ stdlib/email/_policybase.pyi | 19 +++++----- stdlib/email/message.pyi | 38 ++++++++++---------- stdlib/email/mime/text.pyi | 2 +- 4 files changed, 35 insertions(+), 28 deletions(-) create mode 100644 stdlib/@tests/test_cases/email/check_mime.py diff --git a/stdlib/@tests/test_cases/email/check_mime.py b/stdlib/@tests/test_cases/email/check_mime.py new file mode 100644 index 000000000..e49d2bfac --- /dev/null +++ b/stdlib/@tests/test_cases/email/check_mime.py @@ -0,0 +1,4 @@ +from email.mime.text import MIMEText +from email.policy import SMTP + +msg = MIMEText("", policy=SMTP) diff --git a/stdlib/email/_policybase.pyi b/stdlib/email/_policybase.pyi index b345c84a9..0fb890d42 100644 --- a/stdlib/email/_policybase.pyi +++ b/stdlib/email/_policybase.pyi @@ -8,6 +8,7 @@ from typing_extensions import Self __all__ = ["Policy", "Compat32", "compat32"] _MessageT = TypeVar("_MessageT", bound=Message[Any, Any], default=Message[str, str]) +_MessageT_co = TypeVar("_MessageT_co", covariant=True, bound=Message[Any, Any], default=Message[str, str]) @type_check_only class _MessageFactory(Protocol[_MessageT]): @@ -16,13 +17,13 @@ class _MessageFactory(Protocol[_MessageT]): # Policy below is the only known direct subclass of _PolicyBase. We therefore # assume that the __init__ arguments and attributes of _PolicyBase are # the same as those of Policy. -class _PolicyBase(Generic[_MessageT]): +class _PolicyBase(Generic[_MessageT_co]): max_line_length: int | None linesep: str cte_type: str raise_on_defect: bool mangle_from_: bool - message_factory: _MessageFactory[_MessageT] | None + message_factory: _MessageFactory[_MessageT_co] | None # Added in Python 3.9.20, 3.10.15, 3.11.10, 3.12.5 verify_generated_headers: bool @@ -34,7 +35,7 @@ class _PolicyBase(Generic[_MessageT]): cte_type: str = "8bit", raise_on_defect: bool = False, mangle_from_: bool = ..., # default depends on sub-class - message_factory: _MessageFactory[_MessageT] | None = None, + message_factory: _MessageFactory[_MessageT_co] | None = None, # Added in Python 3.9.20, 3.10.15, 3.11.10, 3.12.5 verify_generated_headers: bool = True, ) -> None: ... @@ -46,15 +47,17 @@ class _PolicyBase(Generic[_MessageT]): cte_type: str = ..., raise_on_defect: bool = ..., mangle_from_: bool = ..., - message_factory: _MessageFactory[_MessageT] | None = ..., + message_factory: _MessageFactory[_MessageT_co] | None = ..., # Added in Python 3.9.20, 3.10.15, 3.11.10, 3.12.5 verify_generated_headers: bool = ..., ) -> Self: ... def __add__(self, other: Policy) -> Self: ... -class Policy(_PolicyBase[_MessageT], metaclass=ABCMeta): - def handle_defect(self, obj: _MessageT, defect: MessageDefect) -> None: ... - def register_defect(self, obj: _MessageT, defect: MessageDefect) -> None: ... +class Policy(_PolicyBase[_MessageT_co], metaclass=ABCMeta): + # Every Message object has a `defects` attribute, so the following + # methods will work for any Message object. + def handle_defect(self, obj: Message[Any, Any], defect: MessageDefect) -> None: ... + def register_defect(self, obj: Message[Any, Any], defect: MessageDefect) -> None: ... def header_max_count(self, name: str) -> int | None: ... @abstractmethod def header_source_parse(self, sourcelines: list[str]) -> tuple[str, str]: ... @@ -67,7 +70,7 @@ class Policy(_PolicyBase[_MessageT], metaclass=ABCMeta): @abstractmethod def fold_binary(self, name: str, value: str) -> bytes: ... -class Compat32(Policy[_MessageT]): +class Compat32(Policy[_MessageT_co]): def header_source_parse(self, sourcelines: list[str]) -> tuple[str, str]: ... def header_store_parse(self, name: str, value: str) -> tuple[str, str]: ... def header_fetch_parse(self, name: str, value: str) -> str | Header: ... # type: ignore[override] diff --git a/stdlib/email/message.pyi b/stdlib/email/message.pyi index ebad05a1c..e4d149921 100644 --- a/stdlib/email/message.pyi +++ b/stdlib/email/message.pyi @@ -12,12 +12,12 @@ __all__ = ["Message", "EmailMessage"] _T = TypeVar("_T") # Type returned by Policy.header_fetch_parse, often str or Header. -_HeaderT = TypeVar("_HeaderT", default=str) -_HeaderParamT = TypeVar("_HeaderParamT", default=str) +_HeaderT_co = TypeVar("_HeaderT_co", covariant=True, default=str) +_HeaderParamT_contra = TypeVar("_HeaderParamT_contra", contravariant=True, default=str) # Represents headers constructed by HeaderRegistry. Those are sub-classes # of BaseHeader and another header type. -_HeaderRegistryT = TypeVar("_HeaderRegistryT", default=Any) -_HeaderRegistryParamT = TypeVar("_HeaderRegistryParamT", default=Any) +_HeaderRegistryT_co = TypeVar("_HeaderRegistryT_co", covariant=True, default=Any) +_HeaderRegistryParamT_contra = TypeVar("_HeaderRegistryParamT_contra", contravariant=True, default=Any) _PayloadType: TypeAlias = Message | str _EncodedPayloadType: TypeAlias = Message | bytes @@ -30,7 +30,7 @@ class _SupportsEncodeToPayload(Protocol): class _SupportsDecodeToPayload(Protocol): def decode(self, encoding: str, errors: str, /) -> _PayloadType | _MultipartPayloadType: ... -class Message(Generic[_HeaderT, _HeaderParamT]): +class Message(Generic[_HeaderT_co, _HeaderParamT_contra]): # The policy attributes and arguments in this class and its subclasses # would ideally use Policy[Self], but this is not possible. policy: Policy[Any] # undocumented @@ -76,22 +76,22 @@ class Message(Generic[_HeaderT, _HeaderParamT]): # This is important for protocols using __getitem__, like SupportsKeysAndGetItem # Morally, the return type should be `AnyOf[_HeaderType, None]`, # so using "the Any trick" instead. - def __getitem__(self, name: str) -> _HeaderT | MaybeNone: ... - def __setitem__(self, name: str, val: _HeaderParamT) -> None: ... + def __getitem__(self, name: str) -> _HeaderT_co | MaybeNone: ... + def __setitem__(self, name: str, val: _HeaderParamT_contra) -> None: ... def __delitem__(self, name: str) -> None: ... def keys(self) -> list[str]: ... - def values(self) -> list[_HeaderT]: ... - def items(self) -> list[tuple[str, _HeaderT]]: ... + def values(self) -> list[_HeaderT_co]: ... + def items(self) -> list[tuple[str, _HeaderT_co]]: ... @overload - def get(self, name: str, failobj: None = None) -> _HeaderT | None: ... + def get(self, name: str, failobj: None = None) -> _HeaderT_co | None: ... @overload - def get(self, name: str, failobj: _T) -> _HeaderT | _T: ... + def get(self, name: str, failobj: _T) -> _HeaderT_co | _T: ... @overload - def get_all(self, name: str, failobj: None = None) -> list[_HeaderT] | None: ... + def get_all(self, name: str, failobj: None = None) -> list[_HeaderT_co] | None: ... @overload - def get_all(self, name: str, failobj: _T) -> list[_HeaderT] | _T: ... + def get_all(self, name: str, failobj: _T) -> list[_HeaderT_co] | _T: ... def add_header(self, _name: str, _value: str, **_params: _ParamsType) -> None: ... - def replace_header(self, _name: str, _value: _HeaderParamT) -> None: ... + def replace_header(self, _name: str, _value: _HeaderParamT_contra) -> None: ... def get_content_type(self) -> str: ... def get_content_maintype(self) -> str: ... def get_content_subtype(self) -> str: ... @@ -144,18 +144,18 @@ class Message(Generic[_HeaderT, _HeaderParamT]): replace: bool = False, ) -> None: ... # The following two methods are undocumented, but a source code comment states that they are public API - def set_raw(self, name: str, value: _HeaderParamT) -> None: ... - def raw_items(self) -> Iterator[tuple[str, _HeaderT]]: ... + def set_raw(self, name: str, value: _HeaderParamT_contra) -> None: ... + def raw_items(self) -> Iterator[tuple[str, _HeaderT_co]]: ... -class MIMEPart(Message[_HeaderRegistryT, _HeaderRegistryParamT]): +class MIMEPart(Message[_HeaderRegistryT_co, _HeaderRegistryParamT_contra]): def __init__(self, policy: Policy[Any] | None = None) -> None: ... - def get_body(self, preferencelist: Sequence[str] = ("related", "html", "plain")) -> MIMEPart[_HeaderRegistryT] | None: ... + def get_body(self, preferencelist: Sequence[str] = ("related", "html", "plain")) -> MIMEPart[_HeaderRegistryT_co] | None: ... def attach(self, payload: Self) -> None: ... # type: ignore[override] # The attachments are created via type(self) in the attach method. It's theoretically # possible to sneak other attachment types into a MIMEPart instance, but could cause # cause unforseen consequences. def iter_attachments(self) -> Iterator[Self]: ... - def iter_parts(self) -> Iterator[MIMEPart[_HeaderRegistryT]]: ... + def iter_parts(self) -> Iterator[MIMEPart[_HeaderRegistryT_co]]: ... def get_content(self, *args: Any, content_manager: ContentManager | None = None, **kw: Any) -> Any: ... def set_content(self, *args: Any, content_manager: ContentManager | None = None, **kw: Any) -> None: ... def make_related(self, boundary: str | None = None) -> None: ... diff --git a/stdlib/email/mime/text.pyi b/stdlib/email/mime/text.pyi index 74d5ef4c5..edfa67a09 100644 --- a/stdlib/email/mime/text.pyi +++ b/stdlib/email/mime/text.pyi @@ -1,5 +1,5 @@ +from email._policybase import Policy from email.mime.nonmultipart import MIMENonMultipart -from email.policy import Policy __all__ = ["MIMEText"]