diff --git a/stdlib/@tests/test_cases/builtins/check_memoryview.py b/stdlib/@tests/test_cases/builtins/check_memoryview.py new file mode 100644 index 000000000..108fc8395 --- /dev/null +++ b/stdlib/@tests/test_cases/builtins/check_memoryview.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import array +from typing_extensions import assert_type + +# Casting to bytes. +buf = b"abcdefg" +view = memoryview(buf).cast("c") +elm = view[0] +assert_type(elm, bytes) +assert_type(view[0:2], memoryview[bytes]) + +# Casting to a bool. +a = array.array("B", [0, 1, 2, 3]) +mv = memoryview(a) +bool_mv = mv.cast("?") +assert_type(bool_mv[0], bool) +assert_type(bool_mv[0:2], memoryview[bool]) + + +# Casting to a signed char. +a = array.array("B", [0, 1, 2, 3]) +mv = memoryview(a) +signed_mv = mv.cast("b") +assert_type(signed_mv[0], int) +assert_type(signed_mv[0:2], memoryview[int]) + +# Casting to a signed short. +a = array.array("B", [0, 1, 2, 3]) +mv = memoryview(a) +signed_mv = mv.cast("h") +assert_type(signed_mv[0], int) +assert_type(signed_mv[0:2], memoryview[int]) + +# Casting to a signed int. +a = array.array("B", [0, 1, 2, 3]) +mv = memoryview(a) +signed_mv = mv.cast("i") +assert_type(signed_mv[0], int) +assert_type(signed_mv[0:2], memoryview[int]) + +# Casting to a signed long. +a = array.array("B", [0, 1, 2, 3]) +mv = memoryview(a) +signed_mv = mv.cast("l") +assert_type(signed_mv[0], int) +assert_type(signed_mv[0:2], memoryview[int]) + +# Casting to a float. +a = array.array("B", [0, 1, 2, 3]) +mv = memoryview(a) +float_mv = mv.cast("f") +assert_type(float_mv[0], float) +assert_type(float_mv[0:2], memoryview[float]) + +# An invalid literal should raise an error. +mv = memoryview(b"abc") +mv.cast("abc") # type: ignore diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index 2299e6ccb..6e0232f20 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -75,6 +75,7 @@ if sys.version_info >= (3, 9): from types import GenericAlias _T = TypeVar("_T") +_I = TypeVar("_I", default=int) _T_co = TypeVar("_T_co", covariant=True) _T_contra = TypeVar("_T_contra", contravariant=True) _R_co = TypeVar("_R_co", covariant=True) @@ -823,8 +824,12 @@ class bytearray(MutableSequence[int]): def __buffer__(self, flags: int, /) -> memoryview: ... def __release_buffer__(self, buffer: memoryview, /) -> None: ... +_IntegerFormats: TypeAlias = Literal[ + "b", "B", "@b", "@B", "h", "H", "@h", "@H", "i", "I", "@i", "@I", "l", "L", "@l", "@L", "q", "Q", "@q", "@Q", "P", "@P" +] + @final -class memoryview(Sequence[int]): +class memoryview(Sequence[_I]): @property def format(self) -> str: ... @property @@ -854,13 +859,20 @@ class memoryview(Sequence[int]): def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, / ) -> None: ... - def cast(self, format: str, shape: list[int] | tuple[int, ...] = ...) -> memoryview: ... @overload - def __getitem__(self, key: SupportsIndex | tuple[SupportsIndex, ...], /) -> int: ... + def cast(self, format: Literal["c", "@c"], shape: list[int] | tuple[int, ...] = ...) -> memoryview[bytes]: ... @overload - def __getitem__(self, key: slice, /) -> memoryview: ... + def cast(self, format: Literal["f", "@f", "d", "@d"], shape: list[int] | tuple[int, ...] = ...) -> memoryview[float]: ... + @overload + def cast(self, format: Literal["?"], shape: list[int] | tuple[int, ...] = ...) -> memoryview[bool]: ... + @overload + def cast(self, format: _IntegerFormats, shape: list[int] | tuple[int, ...] = ...) -> memoryview: ... + @overload + def __getitem__(self, key: SupportsIndex | tuple[SupportsIndex, ...], /) -> _I: ... + @overload + def __getitem__(self, key: slice, /) -> memoryview[_I]: ... def __contains__(self, x: object, /) -> bool: ... - def __iter__(self) -> Iterator[int]: ... + def __iter__(self) -> Iterator[_I]: ... def __len__(self) -> int: ... def __eq__(self, value: object, /) -> bool: ... def __hash__(self) -> int: ...