tensorflow: add a few TensorFlow functions (#13364)

This commit is contained in:
Hoël Bagard
2025-02-26 23:07:07 +09:00
committed by GitHub
parent 84c78c6442
commit a1c185b0b2
3 changed files with 43 additions and 2 deletions
+31 -2
View File
@@ -6,7 +6,7 @@ from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
from contextlib import contextmanager
from enum import Enum
from types import TracebackType
from typing import Any, Generic, NoReturn, TypeVar, overload
from typing import Any, Generic, Literal, NoReturn, TypeVar, overload
from typing_extensions import ParamSpec, Self
from google.protobuf.message import Message
@@ -20,7 +20,17 @@ from tensorflow import (
math as math,
types as types,
)
from tensorflow._aliases import AnyArray, DTypeLike, ShapeLike, Slice, TensorCompatible
from tensorflow._aliases import (
AnyArray,
DTypeLike,
IntArray,
ScalarTensorCompatible,
ShapeLike,
Slice,
SparseTensorCompatible,
TensorCompatible,
UIntTensorCompatible,
)
from tensorflow.autodiff import GradientTape as GradientTape
from tensorflow.core.protobuf import struct_pb2
from tensorflow.dtypes import *
@@ -56,6 +66,7 @@ from tensorflow.math import (
reduce_min as reduce_min,
reduce_prod as reduce_prod,
reduce_sum as reduce_sum,
round as round,
sigmoid as sigmoid,
sign as sign,
sin as sin,
@@ -403,4 +414,22 @@ def ones_like(
input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
) -> RaggedTensor: ...
def reshape(tensor: TensorCompatible, shape: ShapeLike | Tensor, name: str | None = None) -> Tensor: ...
def pad(
tensor: TensorCompatible,
paddings: Tensor | IntArray | Iterable[Iterable[int]],
mode: Literal["CONSTANT", "constant", "REFLECT", "reflect", "SYMMETRIC", "symmectric"] = "CONSTANT",
constant_values: ScalarTensorCompatible = 0,
name: str | None = None,
) -> Tensor: ...
def shape(input: SparseTensorCompatible, out_type: DTypeLike | None = None, name: str | None = None) -> Tensor: ...
def where(
condition: TensorCompatible, x: TensorCompatible | None = None, y: TensorCompatible | None = None, name: str | None = None
) -> Tensor: ...
def gather_nd(
params: TensorCompatible,
indices: UIntTensorCompatible,
batch_dims: UIntTensorCompatible = 0,
name: str | None = None,
bad_indices_policy: Literal["", "DEFAULT", "ERROR", "IGNORE"] = "",
) -> Tensor: ...
def __getattr__(name: str) -> Incomplete: ...
+6
View File
@@ -219,6 +219,12 @@ def square(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
def softplus(features: TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def softplus(features: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def round(x: TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def round(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
@overload
def round(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
# Depending on the method axis is either a rank 0 tensor or a rank 0/1 tensor.
def reduce_mean(
+6
View File
@@ -0,0 +1,6 @@
from tensorflow import Tensor
from tensorflow._aliases import DTypeLike, TensorCompatible
def hamming_window(
window_length: TensorCompatible, periodic: bool | TensorCompatible = True, dtype: DTypeLike = ..., name: str | None = None
) -> Tensor: ...