Make email.message.Message generic over the header type (#11732)

Co-authored-by: Avasam <samuel.06@hotmail.com>
This commit is contained in:
Sebastian Rittau
2024-04-23 17:41:10 +02:00
committed by GitHub
parent c13f6e166d
commit 56db30e8eb
3 changed files with 58 additions and 63 deletions

View File

@@ -3,22 +3,25 @@ from email import _ParamsType, _ParamType
from email.charset import Charset
from email.contentmanager import ContentManager
from email.errors import MessageDefect
from email.header import Header
from email.policy import Policy
from typing import Any, Literal, Protocol, TypeVar, overload
from typing import Any, Generic, Literal, Protocol, TypeVar, overload
from typing_extensions import Self, TypeAlias
__all__ = ["Message", "EmailMessage"]
_T = TypeVar("_T")
# Type returned by Policy.header_fetch_parse, often str or Header.
_HeaderT = TypeVar("_HeaderT", default=str)
_HeaderParamT = TypeVar("_HeaderParamT", default=str)
# Represents headers constructed by HeaderRegistry. Those are sub-classes
# of BaseHeader and another header type.
_HeaderRegistryT = TypeVar("_HeaderRegistryT", default=Any)
_HeaderRegistryParamT = TypeVar("_HeaderRegistryParamT", default=Any)
_PayloadType: TypeAlias = Message | str
_EncodedPayloadType: TypeAlias = Message | bytes
_MultipartPayloadType: TypeAlias = list[_PayloadType]
_CharsetType: TypeAlias = Charset | str | None
# Type returned by Policy.header_fetch_parse, often str or Header.
_HeaderType: TypeAlias = Any
# Type accepted by Policy.header_store_parse.
_HeaderTypeParam: TypeAlias = str | Header | Any
class _SupportsEncodeToPayload(Protocol):
def encode(self, encoding: str, /) -> _PayloadType | _MultipartPayloadType | _SupportsDecodeToPayload: ...
@@ -26,10 +29,7 @@ class _SupportsEncodeToPayload(Protocol):
class _SupportsDecodeToPayload(Protocol):
def decode(self, encoding: str, errors: str, /) -> _PayloadType | _MultipartPayloadType: ...
# TODO: This class should be generic over the header policy and/or the header
# value types allowed by the policy. This depends on PEP 696 support
# (https://github.com/python/typeshed/issues/11422).
class Message:
class Message(Generic[_HeaderT, _HeaderParamT]):
policy: Policy # undocumented
preamble: str | None
epilogue: str | None
@@ -70,24 +70,23 @@ class Message:
# Same as `get` with `failobj=None`, but with the expectation that it won't return None in most scenarios
# This is important for protocols using __getitem__, like SupportsKeysAndGetItem
# Morally, the return type should be `AnyOf[_HeaderType, None]`,
# which we could spell as `_HeaderType | Any`,
# *but* `_HeaderType` itself is currently an alias to `Any`...
def __getitem__(self, name: str) -> _HeaderType: ...
def __setitem__(self, name: str, val: _HeaderTypeParam) -> None: ...
# so using "the Any trick" instead.
def __getitem__(self, name: str) -> _HeaderT | Any: ...
def __setitem__(self, name: str, val: _HeaderParamT) -> None: ...
def __delitem__(self, name: str) -> None: ...
def keys(self) -> list[str]: ...
def values(self) -> list[_HeaderType]: ...
def items(self) -> list[tuple[str, _HeaderType]]: ...
def values(self) -> list[_HeaderT]: ...
def items(self) -> list[tuple[str, _HeaderT]]: ...
@overload
def get(self, name: str, failobj: None = None) -> _HeaderType | None: ...
def get(self, name: str, failobj: None = None) -> _HeaderT | None: ...
@overload
def get(self, name: str, failobj: _T) -> _HeaderType | _T: ...
def get(self, name: str, failobj: _T) -> _HeaderT | _T: ...
@overload
def get_all(self, name: str, failobj: None = None) -> list[_HeaderType] | None: ...
def get_all(self, name: str, failobj: None = None) -> list[_HeaderT] | None: ...
@overload
def get_all(self, name: str, failobj: _T) -> list[_HeaderType] | _T: ...
def get_all(self, name: str, failobj: _T) -> list[_HeaderT] | _T: ...
def add_header(self, _name: str, _value: str, **_params: _ParamsType) -> None: ...
def replace_header(self, _name: str, _value: _HeaderTypeParam) -> None: ...
def replace_header(self, _name: str, _value: _HeaderParamT) -> None: ...
def get_content_type(self) -> str: ...
def get_content_maintype(self) -> str: ...
def get_content_subtype(self) -> str: ...
@@ -141,14 +140,14 @@ class Message:
) -> None: ...
def __init__(self, policy: Policy = ...) -> None: ...
# The following two methods are undocumented, but a source code comment states that they are public API
def set_raw(self, name: str, value: _HeaderTypeParam) -> None: ...
def raw_items(self) -> Iterator[tuple[str, _HeaderType]]: ...
def set_raw(self, name: str, value: _HeaderParamT) -> None: ...
def raw_items(self) -> Iterator[tuple[str, _HeaderT]]: ...
class MIMEPart(Message):
class MIMEPart(Message[_HeaderRegistryT, _HeaderRegistryParamT]):
def __init__(self, policy: Policy | None = None) -> None: ...
def get_body(self, preferencelist: Sequence[str] = ("related", "html", "plain")) -> Message | None: ...
def iter_attachments(self) -> Iterator[Message]: ...
def iter_parts(self) -> Iterator[Message]: ...
def get_body(self, preferencelist: Sequence[str] = ("related", "html", "plain")) -> MIMEPart[_HeaderRegistryT] | None: ...
def iter_attachments(self) -> Iterator[MIMEPart[_HeaderRegistryT]]: ...
def iter_parts(self) -> Iterator[MIMEPart[_HeaderRegistryT]]: ...
def get_content(self, *args: Any, content_manager: ContentManager | None = None, **kw: Any) -> Any: ...
def set_content(self, *args: Any, content_manager: ContentManager | None = None, **kw: Any) -> None: ...
def make_related(self, boundary: str | None = None) -> None: ...

View File

@@ -3,24 +3,34 @@ from collections.abc import Callable
from email.feedparser import BytesFeedParser as BytesFeedParser, FeedParser as FeedParser
from email.message import Message
from email.policy import Policy
from typing import IO
from io import _WrappedBuffer
from typing import Generic, TypeVar, overload
__all__ = ["Parser", "HeaderParser", "BytesParser", "BytesHeaderParser", "FeedParser", "BytesFeedParser"]
class Parser:
def __init__(self, _class: Callable[[], Message] | None = None, *, policy: Policy = ...) -> None: ...
def parse(self, fp: SupportsRead[str], headersonly: bool = False) -> Message: ...
def parsestr(self, text: str, headersonly: bool = False) -> Message: ...
_MessageT = TypeVar("_MessageT", bound=Message, default=Message)
class HeaderParser(Parser):
def parse(self, fp: SupportsRead[str], headersonly: bool = True) -> Message: ...
def parsestr(self, text: str, headersonly: bool = True) -> Message: ...
class Parser(Generic[_MessageT]):
@overload
def __init__(self: Parser[Message[str, str]], _class: None = None, *, policy: Policy = ...) -> None: ...
@overload
def __init__(self, _class: Callable[[], _MessageT], *, policy: Policy = ...) -> None: ...
def parse(self, fp: SupportsRead[str], headersonly: bool = False) -> _MessageT: ...
def parsestr(self, text: str, headersonly: bool = False) -> _MessageT: ...
class BytesParser:
def __init__(self, _class: Callable[[], Message] = ..., *, policy: Policy = ...) -> None: ...
def parse(self, fp: IO[bytes], headersonly: bool = False) -> Message: ...
def parsebytes(self, text: bytes | bytearray, headersonly: bool = False) -> Message: ...
class HeaderParser(Parser[_MessageT]):
def parse(self, fp: SupportsRead[str], headersonly: bool = True) -> _MessageT: ...
def parsestr(self, text: str, headersonly: bool = True) -> _MessageT: ...
class BytesHeaderParser(BytesParser):
def parse(self, fp: IO[bytes], headersonly: bool = True) -> Message: ...
def parsebytes(self, text: bytes | bytearray, headersonly: bool = True) -> Message: ...
class BytesParser(Generic[_MessageT]):
parser: Parser[_MessageT]
@overload
def __init__(self: BytesParser[Message[str, str]], _class: None = None, *, policy: Policy = ...) -> None: ...
@overload
def __init__(self, _class: Callable[[], _MessageT], *, policy: Policy = ...) -> None: ...
def parse(self, fp: _WrappedBuffer, headersonly: bool = False) -> _MessageT: ...
def parsebytes(self, text: bytes | bytearray, headersonly: bool = False) -> _MessageT: ...
class BytesHeaderParser(BytesParser[_MessageT]):
def parse(self, fp: _WrappedBuffer, headersonly: bool = True) -> _MessageT: ...
def parsebytes(self, text: bytes | bytearray, headersonly: bool = True) -> _MessageT: ...

View File

@@ -3,7 +3,7 @@ import io
import ssl
import sys
import types
from _typeshed import ReadableBuffer, SupportsRead, WriteableBuffer
from _typeshed import ReadableBuffer, SupportsRead, SupportsReadline, WriteableBuffer
from collections.abc import Callable, Iterable, Iterator, Mapping
from socket import socket
from typing import Any, BinaryIO, TypeVar, overload
@@ -33,6 +33,7 @@ __all__ = [
_DataType: TypeAlias = SupportsRead[bytes] | Iterable[ReadableBuffer] | ReadableBuffer
_T = TypeVar("_T")
_MessageT = TypeVar("_MessageT", bound=email.message.Message)
HTTP_PORT: int
HTTPS_PORT: int
@@ -97,28 +98,13 @@ NETWORK_AUTHENTICATION_REQUIRED: int
responses: dict[int, str]
class HTTPMessage(email.message.Message):
class HTTPMessage(email.message.Message[str, str]):
def getallmatchingheaders(self, name: str) -> list[str]: ... # undocumented
# override below all of Message's methods that use `_HeaderType` / `_HeaderTypeParam` with `str`
# `HTTPMessage` breaks the Liskov substitution principle by only intending for `str` headers
# This is easier than making `Message` generic
def __getitem__(self, name: str) -> str | None: ...
def __setitem__(self, name: str, val: str) -> None: ... # type: ignore[override]
def values(self) -> list[str]: ...
def items(self) -> list[tuple[str, str]]: ...
@overload
def get(self, name: str, failobj: None = None) -> str | None: ...
@overload
def get(self, name: str, failobj: _T) -> str | _T: ...
@overload
def get_all(self, name: str, failobj: None = None) -> list[str] | None: ...
@overload
def get_all(self, name: str, failobj: _T) -> list[str] | _T: ...
def replace_header(self, _name: str, _value: str) -> None: ... # type: ignore[override]
def set_raw(self, name: str, value: str) -> None: ... # type: ignore[override]
def raw_items(self) -> Iterator[tuple[str, str]]: ...
def parse_headers(fp: io.BufferedIOBase, _class: Callable[[], email.message.Message] = ...) -> HTTPMessage: ...
@overload
def parse_headers(fp: SupportsReadline[bytes], _class: Callable[[], _MessageT]) -> _MessageT: ...
@overload
def parse_headers(fp: SupportsReadline[bytes]) -> HTTPMessage: ...
class HTTPResponse(io.BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible method definitions in the base classes
msg: HTTPMessage