diff --git a/stubs/tensorflow/@tests/stubtest_allowlist.txt b/stubs/tensorflow/@tests/stubtest_allowlist.txt index 437a99630..adb4f38f5 100644 --- a/stubs/tensorflow/@tests/stubtest_allowlist.txt +++ b/stubs/tensorflow/@tests/stubtest_allowlist.txt @@ -91,6 +91,7 @@ tensorflow.io.SparseFeature.__new__ # Metaclass inconsistency. The runtime metaclass is defined from c++ extension and is undocumented. tensorflow.io.TFRecordWriter +tensorflow.experimental.dtensor.Mesh # stubtest does not pass for protobuf generated stubs. tensorflow.train.Example.* diff --git a/stubs/tensorflow/tensorflow/_aliases.pyi b/stubs/tensorflow/tensorflow/_aliases.pyi index 7b92a6fa2..40042b807 100644 --- a/stubs/tensorflow/tensorflow/_aliases.pyi +++ b/stubs/tensorflow/tensorflow/_aliases.pyi @@ -29,6 +29,7 @@ KerasSerializable: TypeAlias = KerasSerializable1 | KerasSerializable2 Slice: TypeAlias = int | slice | None FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence] +IntDataSequence: TypeAlias = Sequence[int] | Sequence[IntDataSequence] StrDataSequence: TypeAlias = Sequence[str] | Sequence[StrDataSequence] ScalarTensorCompatible: TypeAlias = tf.Tensor | str | float | np.ndarray[Any, Any] | np.number[Any] @@ -52,4 +53,6 @@ ContainerInputSpec: TypeAlias = ContainerGeneric[InputSpec] AnyArray: TypeAlias = npt.NDArray[Any] FloatArray: TypeAlias = npt.NDArray[np.float_ | np.float16 | np.float32 | np.float64] -IntArray: TypeAlias = npt.NDArray[np.int_ | np.uint8 | np.int32 | np.int64] +UIntArray: TypeAlias = npt.NDArray[np.uint | np.uint8 | np.uint16 | np.uint32 | np.uint64] +SignedIntArray: TypeAlias = npt.NDArray[np.int_ | np.int8 | np.int16 | np.int32 | np.int64] +IntArray: TypeAlias = UIntArray | SignedIntArray diff --git a/stubs/tensorflow/tensorflow/experimental/dtensor.pyi b/stubs/tensorflow/tensorflow/experimental/dtensor.pyi new file mode 100644 index 000000000..ffb22d344 --- /dev/null +++ b/stubs/tensorflow/tensorflow/experimental/dtensor.pyi @@ -0,0 +1,17 @@ +from _typeshed import Incomplete + +from tensorflow._aliases import IntArray, IntDataSequence + +class Mesh: + def __init__( + self, + dim_names: list[str], + global_device_ids: IntArray | IntDataSequence, + local_device_ids: list[int], + local_devices: list[Incomplete | str], + mesh_name: str = "", + global_devices: list[Incomplete | str] | None = None, + use_xla_spmd: bool = False, + ) -> None: ... + +def __getattr__(name: str) -> Incomplete: ... diff --git a/stubs/tensorflow/tensorflow/summary.pyi b/stubs/tensorflow/tensorflow/summary.pyi new file mode 100644 index 000000000..e6bc8bdb7 --- /dev/null +++ b/stubs/tensorflow/tensorflow/summary.pyi @@ -0,0 +1,60 @@ +import abc +from _typeshed import Incomplete +from collections.abc import Callable, Iterator +from contextlib import AbstractContextManager, contextmanager +from typing import Literal +from typing_extensions import Self + +import tensorflow as tf +from tensorflow._aliases import FloatArray, IntArray +from tensorflow.experimental.dtensor import Mesh + +class SummaryWriter(metaclass=abc.ABCMeta): + def as_default(self, step: int | None = None) -> AbstractContextManager[Self]: ... + def close(self) -> None: ... + def flush(self) -> None: ... + def init(self) -> None: ... + def set_as_default(self, step: int | None = None) -> None: ... + +def audio( + name: str, + data: tf.Tensor, + sample_rate: int | tf.Tensor, + step: int | tf.Tensor | None = None, + max_outputs: int | tf.Tensor | None = 3, + encoding: Literal["wav"] | None = None, + description: str | None = None, +) -> bool: ... +def create_file_writer( + logdir: str, + max_queue: int | None = None, + flush_millis: int | None = None, + filename_suffix: str | None = None, + name: str | None = None, + experimental_trackable: bool = False, + experimental_mesh: Mesh | None = None, +) -> SummaryWriter: ... +def create_noop_writer() -> SummaryWriter: ... +def flush(writer: SummaryWriter | None = None, name: str | None = None) -> tf.Operation: ... +def graph(graph_data: tf.Graph | tf.compat.v1.GraphDef) -> bool: ... +def histogram( + name: str, data: tf.Tensor, step: int | None = None, buckets: int | None = None, description: str | None = None +) -> bool: ... +def image( + name: str, + data: tf.Tensor | FloatArray | IntArray, + step: int | tf.Tensor | None = None, + max_outputs: int | None = 3, + description: str | None = None, +) -> bool: ... +@contextmanager +def record_if(condition: bool | tf.Tensor | Callable[[], bool]) -> Iterator[None]: ... +def scalar(name: str, data: float | tf.Tensor, step: int | tf.Tensor | None = None, description: str | None = None) -> bool: ... +def should_record_summaries() -> bool: ... +def text(name: str, data: str | tf.Tensor, step: int | tf.Tensor | None = None, description: str | None = None) -> bool: ... +def trace_export(name: str, step: int | tf.Tensor | None = None, profiler_outdir: str | None = None) -> None: ... +def trace_off() -> None: ... +def trace_on(graph: bool = True, profiler: bool = False) -> None: ... +def write( + tag: str, tensor: tf.Tensor, step: int | tf.Tensor | None = None, metadata: Incomplete | None = None, name: str | None = None +) -> bool: ...