From 6087745bcf41c73c02f9c9eabee37ccbc983d9eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ho=C3=ABl=20Bagard?= <34478245+hoel-bagard@users.noreply.github.com> Date: Sat, 17 Feb 2024 14:52:05 +0900 Subject: [PATCH] `tensorflow`: add `tf.ones`, `tf.zeros`, `tf.zeros_like` and `tf.ones_like` functions (#11368) --- stubs/tensorflow/tensorflow/__init__.pyi | 19 +++++++++++++++++++ .../tensorflow/experimental/dtensor.pyi | 2 ++ 2 files changed, 21 insertions(+) diff --git a/stubs/tensorflow/tensorflow/__init__.pyi b/stubs/tensorflow/tensorflow/__init__.pyi index dfe8b02c4..cb458d231 100644 --- a/stubs/tensorflow/tensorflow/__init__.pyi +++ b/stubs/tensorflow/tensorflow/__init__.pyi @@ -36,6 +36,7 @@ from tensorflow.core.protobuf import struct_pb2 # is necessary to avoid a crash in pytype. from tensorflow.dtypes import * from tensorflow.dtypes import DType as DType +from tensorflow.experimental.dtensor import Layout from tensorflow.keras import losses as losses # Most tf.math functions are exported as tf, but sadly not all are. @@ -438,5 +439,23 @@ def cast(x: TensorCompatible, dtype: DTypeLike, name: str | None = None) -> Tens def cast(x: SparseTensor, dtype: DTypeLike, name: str | None = None) -> SparseTensor: ... @overload def cast(x: RaggedTensor, dtype: DTypeLike, name: str | None = None) -> RaggedTensor: ... +def zeros(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ... +def ones(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ... +@overload +def zeros_like( + input: TensorCompatible | IndexedSlices, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> Tensor: ... +@overload +def zeros_like( + input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> RaggedTensor: ... +@overload +def ones_like( + input: TensorCompatible, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> Tensor: ... +@overload +def ones_like( + input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> RaggedTensor: ... def reshape(tensor: TensorCompatible, shape: ShapeLike | Tensor, name: str | None = None) -> Tensor: ... def __getattr__(name: str) -> Incomplete: ... diff --git a/stubs/tensorflow/tensorflow/experimental/dtensor.pyi b/stubs/tensorflow/tensorflow/experimental/dtensor.pyi index ffb22d344..178d1211f 100644 --- a/stubs/tensorflow/tensorflow/experimental/dtensor.pyi +++ b/stubs/tensorflow/tensorflow/experimental/dtensor.pyi @@ -2,6 +2,8 @@ from _typeshed import Incomplete from tensorflow._aliases import IntArray, IntDataSequence +Layout = Incomplete + class Mesh: def __init__( self,