Use a TypeGuard for dataclasses.is_dataclass(); refine asdict(), astuple(), fields(), replace() (#9362)

This commit is contained in:
Alex Waygood
2023-01-28 15:14:22 +00:00
committed by GitHub
parent c216b74e39
commit 32ebe323f5
2 changed files with 100 additions and 9 deletions

View File

@@ -3,8 +3,8 @@ import sys
import types
from builtins import type as Type # alias to avoid name clashes with fields named "type"
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Generic, Protocol, TypeVar, overload
from typing_extensions import Literal, TypeAlias
from typing import Any, ClassVar, Generic, Protocol, TypeVar, overload
from typing_extensions import Literal, TypeAlias, TypeGuard
if sys.version_info >= (3, 9):
from types import GenericAlias
@@ -30,6 +30,11 @@ __all__ = [
if sys.version_info >= (3, 10):
__all__ += ["KW_ONLY"]
class _DataclassInstance(Protocol):
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
_DataclassT = TypeVar("_DataclassT", bound=_DataclassInstance)
# define _MISSING_TYPE as an enum within the type stubs,
# even though that is not really its type at runtime
# this allows us to use Literal[_MISSING_TYPE.MISSING]
@@ -44,13 +49,13 @@ if sys.version_info >= (3, 10):
class KW_ONLY: ...
@overload
def asdict(obj: Any) -> dict[str, Any]: ...
def asdict(obj: _DataclassInstance) -> dict[str, Any]: ...
@overload
def asdict(obj: Any, *, dict_factory: Callable[[list[tuple[str, Any]]], _T]) -> _T: ...
def asdict(obj: _DataclassInstance, *, dict_factory: Callable[[list[tuple[str, Any]]], _T]) -> _T: ...
@overload
def astuple(obj: Any) -> tuple[Any, ...]: ...
def astuple(obj: _DataclassInstance) -> tuple[Any, ...]: ...
@overload
def astuple(obj: Any, *, tuple_factory: Callable[[list[Any]], _T]) -> _T: ...
def astuple(obj: _DataclassInstance, *, tuple_factory: Callable[[list[Any]], _T]) -> _T: ...
if sys.version_info >= (3, 8):
# cls argument is now positional-only
@@ -212,8 +217,13 @@ else:
metadata: Mapping[Any, Any] | None = ...,
) -> Any: ...
def fields(class_or_instance: Any) -> tuple[Field[Any], ...]: ...
def is_dataclass(obj: Any) -> bool: ...
def fields(class_or_instance: _DataclassInstance | type[_DataclassInstance]) -> tuple[Field[Any], ...]: ...
@overload
def is_dataclass(obj: _DataclassInstance | type[_DataclassInstance]) -> Literal[True]: ...
@overload
def is_dataclass(obj: type) -> TypeGuard[type[_DataclassInstance]]: ...
@overload
def is_dataclass(obj: object) -> TypeGuard[_DataclassInstance | type[_DataclassInstance]]: ...
class FrozenInstanceError(AttributeError): ...
@@ -285,4 +295,4 @@ else:
frozen: bool = ...,
) -> type: ...
def replace(__obj: _T, **changes: Any) -> _T: ...
def replace(__obj: _DataclassT, **changes: Any) -> _DataclassT: ...