diff --git a/stubs/tensorflow/tensorflow/__init__.pyi b/stubs/tensorflow/tensorflow/__init__.pyi index dfe8b02c4..cb458d231 100644 --- a/stubs/tensorflow/tensorflow/__init__.pyi +++ b/stubs/tensorflow/tensorflow/__init__.pyi @@ -36,6 +36,7 @@ from tensorflow.core.protobuf import struct_pb2 # is necessary to avoid a crash in pytype. from tensorflow.dtypes import * from tensorflow.dtypes import DType as DType +from tensorflow.experimental.dtensor import Layout from tensorflow.keras import losses as losses # Most tf.math functions are exported as tf, but sadly not all are. @@ -438,5 +439,23 @@ def cast(x: TensorCompatible, dtype: DTypeLike, name: str | None = None) -> Tens def cast(x: SparseTensor, dtype: DTypeLike, name: str | None = None) -> SparseTensor: ... @overload def cast(x: RaggedTensor, dtype: DTypeLike, name: str | None = None) -> RaggedTensor: ... +def zeros(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ... +def ones(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ... +@overload +def zeros_like( + input: TensorCompatible | IndexedSlices, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> Tensor: ... +@overload +def zeros_like( + input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> RaggedTensor: ... +@overload +def ones_like( + input: TensorCompatible, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> Tensor: ... +@overload +def ones_like( + input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> RaggedTensor: ... def reshape(tensor: TensorCompatible, shape: ShapeLike | Tensor, name: str | None = None) -> Tensor: ... def __getattr__(name: str) -> Incomplete: ... diff --git a/stubs/tensorflow/tensorflow/experimental/dtensor.pyi b/stubs/tensorflow/tensorflow/experimental/dtensor.pyi index ffb22d344..178d1211f 100644 --- a/stubs/tensorflow/tensorflow/experimental/dtensor.pyi +++ b/stubs/tensorflow/tensorflow/experimental/dtensor.pyi @@ -2,6 +2,8 @@ from _typeshed import Incomplete from tensorflow._aliases import IntArray, IntDataSequence +Layout = Incomplete + class Mesh: def __init__( self,