tensorflow: add tf.ones, tf.zeros, tf.zeros_like and tf.ones_like functions (#11368)

This commit is contained in:
Hoël Bagard
2024-02-17 14:52:05 +09:00
committed by GitHub
parent d93ee88ba5
commit 6087745bcf
2 changed files with 21 additions and 0 deletions

View File

@@ -36,6 +36,7 @@ from tensorflow.core.protobuf import struct_pb2
# is necessary to avoid a crash in pytype.
from tensorflow.dtypes import *
from tensorflow.dtypes import DType as DType
from tensorflow.experimental.dtensor import Layout
from tensorflow.keras import losses as losses
# Most tf.math functions are exported as tf, but sadly not all are.
@@ -438,5 +439,23 @@ def cast(x: TensorCompatible, dtype: DTypeLike, name: str | None = None) -> Tens
def cast(x: SparseTensor, dtype: DTypeLike, name: str | None = None) -> SparseTensor: ...
@overload
def cast(x: RaggedTensor, dtype: DTypeLike, name: str | None = None) -> RaggedTensor: ...
def zeros(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ...
def ones(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ...
@overload
def zeros_like(
input: TensorCompatible | IndexedSlices, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
) -> Tensor: ...
@overload
def zeros_like(
input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
) -> RaggedTensor: ...
@overload
def ones_like(
input: TensorCompatible, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
) -> Tensor: ...
@overload
def ones_like(
input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
) -> RaggedTensor: ...
def reshape(tensor: TensorCompatible, shape: ShapeLike | Tensor, name: str | None = None) -> Tensor: ...
def __getattr__(name: str) -> Incomplete: ...

View File

@@ -2,6 +2,8 @@ from _typeshed import Incomplete
from tensorflow._aliases import IntArray, IntDataSequence
Layout = Incomplete
class Mesh:
def __init__(
self,