Make MemoryView Generic, make cast accurate (#12247)

This commit is contained in:
Max Muoto
2024-07-11 19:12:44 -05:00
committed by GitHub
parent 3b5b642fdb
commit 4316e00c9e
2 changed files with 75 additions and 5 deletions

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
import array
from typing_extensions import assert_type
# Casting to bytes.
buf = b"abcdefg"
view = memoryview(buf).cast("c")
elm = view[0]
assert_type(elm, bytes)
assert_type(view[0:2], memoryview[bytes])
# Casting to a bool.
a = array.array("B", [0, 1, 2, 3])
mv = memoryview(a)
bool_mv = mv.cast("?")
assert_type(bool_mv[0], bool)
assert_type(bool_mv[0:2], memoryview[bool])
# Casting to a signed char.
a = array.array("B", [0, 1, 2, 3])
mv = memoryview(a)
signed_mv = mv.cast("b")
assert_type(signed_mv[0], int)
assert_type(signed_mv[0:2], memoryview[int])
# Casting to a signed short.
a = array.array("B", [0, 1, 2, 3])
mv = memoryview(a)
signed_mv = mv.cast("h")
assert_type(signed_mv[0], int)
assert_type(signed_mv[0:2], memoryview[int])
# Casting to a signed int.
a = array.array("B", [0, 1, 2, 3])
mv = memoryview(a)
signed_mv = mv.cast("i")
assert_type(signed_mv[0], int)
assert_type(signed_mv[0:2], memoryview[int])
# Casting to a signed long.
a = array.array("B", [0, 1, 2, 3])
mv = memoryview(a)
signed_mv = mv.cast("l")
assert_type(signed_mv[0], int)
assert_type(signed_mv[0:2], memoryview[int])
# Casting to a float.
a = array.array("B", [0, 1, 2, 3])
mv = memoryview(a)
float_mv = mv.cast("f")
assert_type(float_mv[0], float)
assert_type(float_mv[0:2], memoryview[float])
# An invalid literal should raise an error.
mv = memoryview(b"abc")
mv.cast("abc") # type: ignore

View File

@@ -75,6 +75,7 @@ if sys.version_info >= (3, 9):
from types import GenericAlias
_T = TypeVar("_T")
_I = TypeVar("_I", default=int)
_T_co = TypeVar("_T_co", covariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)
_R_co = TypeVar("_R_co", covariant=True)
@@ -823,8 +824,12 @@ class bytearray(MutableSequence[int]):
def __buffer__(self, flags: int, /) -> memoryview: ...
def __release_buffer__(self, buffer: memoryview, /) -> None: ...
_IntegerFormats: TypeAlias = Literal[
"b", "B", "@b", "@B", "h", "H", "@h", "@H", "i", "I", "@i", "@I", "l", "L", "@l", "@L", "q", "Q", "@q", "@Q", "P", "@P"
]
@final
class memoryview(Sequence[int]):
class memoryview(Sequence[_I]):
@property
def format(self) -> str: ...
@property
@@ -854,13 +859,20 @@ class memoryview(Sequence[int]):
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, /
) -> None: ...
def cast(self, format: str, shape: list[int] | tuple[int, ...] = ...) -> memoryview: ...
@overload
def __getitem__(self, key: SupportsIndex | tuple[SupportsIndex, ...], /) -> int: ...
def cast(self, format: Literal["c", "@c"], shape: list[int] | tuple[int, ...] = ...) -> memoryview[bytes]: ...
@overload
def __getitem__(self, key: slice, /) -> memoryview: ...
def cast(self, format: Literal["f", "@f", "d", "@d"], shape: list[int] | tuple[int, ...] = ...) -> memoryview[float]: ...
@overload
def cast(self, format: Literal["?"], shape: list[int] | tuple[int, ...] = ...) -> memoryview[bool]: ...
@overload
def cast(self, format: _IntegerFormats, shape: list[int] | tuple[int, ...] = ...) -> memoryview: ...
@overload
def __getitem__(self, key: SupportsIndex | tuple[SupportsIndex, ...], /) -> _I: ...
@overload
def __getitem__(self, key: slice, /) -> memoryview[_I]: ...
def __contains__(self, x: object, /) -> bool: ...
def __iter__(self) -> Iterator[int]: ...
def __iter__(self) -> Iterator[_I]: ...
def __len__(self) -> int: ...
def __eq__(self, value: object, /) -> bool: ...
def __hash__(self) -> int: ...