From 56db30e8eb4b456e892449491558af81a9eb201f Mon Sep 17 00:00:00 2001 From: Sebastian Rittau Date: Tue, 23 Apr 2024 17:41:10 +0200 Subject: [PATCH] Make email.message.Message generic over the header type (#11732) Co-authored-by: Avasam --- stdlib/email/message.pyi | 53 ++++++++++++++++++++-------------------- stdlib/email/parser.pyi | 40 ++++++++++++++++++------------ stdlib/http/client.pyi | 28 ++++++--------------- 3 files changed, 58 insertions(+), 63 deletions(-) diff --git a/stdlib/email/message.pyi b/stdlib/email/message.pyi index d7d7e8c8e..4032bc613 100644 --- a/stdlib/email/message.pyi +++ b/stdlib/email/message.pyi @@ -3,22 +3,25 @@ from email import _ParamsType, _ParamType from email.charset import Charset from email.contentmanager import ContentManager from email.errors import MessageDefect -from email.header import Header from email.policy import Policy -from typing import Any, Literal, Protocol, TypeVar, overload +from typing import Any, Generic, Literal, Protocol, TypeVar, overload from typing_extensions import Self, TypeAlias __all__ = ["Message", "EmailMessage"] _T = TypeVar("_T") +# Type returned by Policy.header_fetch_parse, often str or Header. +_HeaderT = TypeVar("_HeaderT", default=str) +_HeaderParamT = TypeVar("_HeaderParamT", default=str) +# Represents headers constructed by HeaderRegistry. Those are sub-classes +# of BaseHeader and another header type. +_HeaderRegistryT = TypeVar("_HeaderRegistryT", default=Any) +_HeaderRegistryParamT = TypeVar("_HeaderRegistryParamT", default=Any) + _PayloadType: TypeAlias = Message | str _EncodedPayloadType: TypeAlias = Message | bytes _MultipartPayloadType: TypeAlias = list[_PayloadType] _CharsetType: TypeAlias = Charset | str | None -# Type returned by Policy.header_fetch_parse, often str or Header. -_HeaderType: TypeAlias = Any -# Type accepted by Policy.header_store_parse. -_HeaderTypeParam: TypeAlias = str | Header | Any class _SupportsEncodeToPayload(Protocol): def encode(self, encoding: str, /) -> _PayloadType | _MultipartPayloadType | _SupportsDecodeToPayload: ... @@ -26,10 +29,7 @@ class _SupportsEncodeToPayload(Protocol): class _SupportsDecodeToPayload(Protocol): def decode(self, encoding: str, errors: str, /) -> _PayloadType | _MultipartPayloadType: ... -# TODO: This class should be generic over the header policy and/or the header -# value types allowed by the policy. This depends on PEP 696 support -# (https://github.com/python/typeshed/issues/11422). -class Message: +class Message(Generic[_HeaderT, _HeaderParamT]): policy: Policy # undocumented preamble: str | None epilogue: str | None @@ -70,24 +70,23 @@ class Message: # Same as `get` with `failobj=None`, but with the expectation that it won't return None in most scenarios # This is important for protocols using __getitem__, like SupportsKeysAndGetItem # Morally, the return type should be `AnyOf[_HeaderType, None]`, - # which we could spell as `_HeaderType | Any`, - # *but* `_HeaderType` itself is currently an alias to `Any`... - def __getitem__(self, name: str) -> _HeaderType: ... - def __setitem__(self, name: str, val: _HeaderTypeParam) -> None: ... + # so using "the Any trick" instead. + def __getitem__(self, name: str) -> _HeaderT | Any: ... + def __setitem__(self, name: str, val: _HeaderParamT) -> None: ... def __delitem__(self, name: str) -> None: ... def keys(self) -> list[str]: ... - def values(self) -> list[_HeaderType]: ... - def items(self) -> list[tuple[str, _HeaderType]]: ... + def values(self) -> list[_HeaderT]: ... + def items(self) -> list[tuple[str, _HeaderT]]: ... @overload - def get(self, name: str, failobj: None = None) -> _HeaderType | None: ... + def get(self, name: str, failobj: None = None) -> _HeaderT | None: ... @overload - def get(self, name: str, failobj: _T) -> _HeaderType | _T: ... + def get(self, name: str, failobj: _T) -> _HeaderT | _T: ... @overload - def get_all(self, name: str, failobj: None = None) -> list[_HeaderType] | None: ... + def get_all(self, name: str, failobj: None = None) -> list[_HeaderT] | None: ... @overload - def get_all(self, name: str, failobj: _T) -> list[_HeaderType] | _T: ... + def get_all(self, name: str, failobj: _T) -> list[_HeaderT] | _T: ... def add_header(self, _name: str, _value: str, **_params: _ParamsType) -> None: ... - def replace_header(self, _name: str, _value: _HeaderTypeParam) -> None: ... + def replace_header(self, _name: str, _value: _HeaderParamT) -> None: ... def get_content_type(self) -> str: ... def get_content_maintype(self) -> str: ... def get_content_subtype(self) -> str: ... @@ -141,14 +140,14 @@ class Message: ) -> None: ... def __init__(self, policy: Policy = ...) -> None: ... # The following two methods are undocumented, but a source code comment states that they are public API - def set_raw(self, name: str, value: _HeaderTypeParam) -> None: ... - def raw_items(self) -> Iterator[tuple[str, _HeaderType]]: ... + def set_raw(self, name: str, value: _HeaderParamT) -> None: ... + def raw_items(self) -> Iterator[tuple[str, _HeaderT]]: ... -class MIMEPart(Message): +class MIMEPart(Message[_HeaderRegistryT, _HeaderRegistryParamT]): def __init__(self, policy: Policy | None = None) -> None: ... - def get_body(self, preferencelist: Sequence[str] = ("related", "html", "plain")) -> Message | None: ... - def iter_attachments(self) -> Iterator[Message]: ... - def iter_parts(self) -> Iterator[Message]: ... + def get_body(self, preferencelist: Sequence[str] = ("related", "html", "plain")) -> MIMEPart[_HeaderRegistryT] | None: ... + def iter_attachments(self) -> Iterator[MIMEPart[_HeaderRegistryT]]: ... + def iter_parts(self) -> Iterator[MIMEPart[_HeaderRegistryT]]: ... def get_content(self, *args: Any, content_manager: ContentManager | None = None, **kw: Any) -> Any: ... def set_content(self, *args: Any, content_manager: ContentManager | None = None, **kw: Any) -> None: ... def make_related(self, boundary: str | None = None) -> None: ... diff --git a/stdlib/email/parser.pyi b/stdlib/email/parser.pyi index 28b6aca85..fecb29d90 100644 --- a/stdlib/email/parser.pyi +++ b/stdlib/email/parser.pyi @@ -3,24 +3,34 @@ from collections.abc import Callable from email.feedparser import BytesFeedParser as BytesFeedParser, FeedParser as FeedParser from email.message import Message from email.policy import Policy -from typing import IO +from io import _WrappedBuffer +from typing import Generic, TypeVar, overload __all__ = ["Parser", "HeaderParser", "BytesParser", "BytesHeaderParser", "FeedParser", "BytesFeedParser"] -class Parser: - def __init__(self, _class: Callable[[], Message] | None = None, *, policy: Policy = ...) -> None: ... - def parse(self, fp: SupportsRead[str], headersonly: bool = False) -> Message: ... - def parsestr(self, text: str, headersonly: bool = False) -> Message: ... +_MessageT = TypeVar("_MessageT", bound=Message, default=Message) -class HeaderParser(Parser): - def parse(self, fp: SupportsRead[str], headersonly: bool = True) -> Message: ... - def parsestr(self, text: str, headersonly: bool = True) -> Message: ... +class Parser(Generic[_MessageT]): + @overload + def __init__(self: Parser[Message[str, str]], _class: None = None, *, policy: Policy = ...) -> None: ... + @overload + def __init__(self, _class: Callable[[], _MessageT], *, policy: Policy = ...) -> None: ... + def parse(self, fp: SupportsRead[str], headersonly: bool = False) -> _MessageT: ... + def parsestr(self, text: str, headersonly: bool = False) -> _MessageT: ... -class BytesParser: - def __init__(self, _class: Callable[[], Message] = ..., *, policy: Policy = ...) -> None: ... - def parse(self, fp: IO[bytes], headersonly: bool = False) -> Message: ... - def parsebytes(self, text: bytes | bytearray, headersonly: bool = False) -> Message: ... +class HeaderParser(Parser[_MessageT]): + def parse(self, fp: SupportsRead[str], headersonly: bool = True) -> _MessageT: ... + def parsestr(self, text: str, headersonly: bool = True) -> _MessageT: ... -class BytesHeaderParser(BytesParser): - def parse(self, fp: IO[bytes], headersonly: bool = True) -> Message: ... - def parsebytes(self, text: bytes | bytearray, headersonly: bool = True) -> Message: ... +class BytesParser(Generic[_MessageT]): + parser: Parser[_MessageT] + @overload + def __init__(self: BytesParser[Message[str, str]], _class: None = None, *, policy: Policy = ...) -> None: ... + @overload + def __init__(self, _class: Callable[[], _MessageT], *, policy: Policy = ...) -> None: ... + def parse(self, fp: _WrappedBuffer, headersonly: bool = False) -> _MessageT: ... + def parsebytes(self, text: bytes | bytearray, headersonly: bool = False) -> _MessageT: ... + +class BytesHeaderParser(BytesParser[_MessageT]): + def parse(self, fp: _WrappedBuffer, headersonly: bool = True) -> _MessageT: ... + def parsebytes(self, text: bytes | bytearray, headersonly: bool = True) -> _MessageT: ... diff --git a/stdlib/http/client.pyi b/stdlib/http/client.pyi index fb5450730..f68d9d0ca 100644 --- a/stdlib/http/client.pyi +++ b/stdlib/http/client.pyi @@ -3,7 +3,7 @@ import io import ssl import sys import types -from _typeshed import ReadableBuffer, SupportsRead, WriteableBuffer +from _typeshed import ReadableBuffer, SupportsRead, SupportsReadline, WriteableBuffer from collections.abc import Callable, Iterable, Iterator, Mapping from socket import socket from typing import Any, BinaryIO, TypeVar, overload @@ -33,6 +33,7 @@ __all__ = [ _DataType: TypeAlias = SupportsRead[bytes] | Iterable[ReadableBuffer] | ReadableBuffer _T = TypeVar("_T") +_MessageT = TypeVar("_MessageT", bound=email.message.Message) HTTP_PORT: int HTTPS_PORT: int @@ -97,28 +98,13 @@ NETWORK_AUTHENTICATION_REQUIRED: int responses: dict[int, str] -class HTTPMessage(email.message.Message): +class HTTPMessage(email.message.Message[str, str]): def getallmatchingheaders(self, name: str) -> list[str]: ... # undocumented - # override below all of Message's methods that use `_HeaderType` / `_HeaderTypeParam` with `str` - # `HTTPMessage` breaks the Liskov substitution principle by only intending for `str` headers - # This is easier than making `Message` generic - def __getitem__(self, name: str) -> str | None: ... - def __setitem__(self, name: str, val: str) -> None: ... # type: ignore[override] - def values(self) -> list[str]: ... - def items(self) -> list[tuple[str, str]]: ... - @overload - def get(self, name: str, failobj: None = None) -> str | None: ... - @overload - def get(self, name: str, failobj: _T) -> str | _T: ... - @overload - def get_all(self, name: str, failobj: None = None) -> list[str] | None: ... - @overload - def get_all(self, name: str, failobj: _T) -> list[str] | _T: ... - def replace_header(self, _name: str, _value: str) -> None: ... # type: ignore[override] - def set_raw(self, name: str, value: str) -> None: ... # type: ignore[override] - def raw_items(self) -> Iterator[tuple[str, str]]: ... -def parse_headers(fp: io.BufferedIOBase, _class: Callable[[], email.message.Message] = ...) -> HTTPMessage: ... +@overload +def parse_headers(fp: SupportsReadline[bytes], _class: Callable[[], _MessageT]) -> _MessageT: ... +@overload +def parse_headers(fp: SupportsReadline[bytes]) -> HTTPMessage: ... class HTTPResponse(io.BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible method definitions in the base classes msg: HTTPMessage