tensorflow: add some tensorflow functions (#11326)

This commit is contained in:
Hoël Bagard
2024-01-31 12:57:25 +09:00
committed by GitHub
parent 68ae493297
commit 98ba275d5e

View File

@@ -412,3 +412,30 @@ def convert_to_tensor(
dtype_hint: _DTypeLike | None = None,
name: str | None = None,
) -> Tensor: ...
@overload
def expand_dims(input: _TensorCompatible, axis: int, name: str | None = None) -> Tensor: ...
@overload
def expand_dims(input: RaggedTensor, axis: int, name: str | None = None) -> RaggedTensor: ...
@overload
def concat(values: _TensorCompatible, axis: int, name: str | None = "concat") -> Tensor: ...
@overload
def concat(values: Sequence[RaggedTensor], axis: int, name: str | None = "concat") -> RaggedTensor: ...
@overload
def squeeze(
input: _TensorCompatible, axis: int | tuple[int, ...] | list[int] | None = None, name: str | None = None
) -> Tensor: ...
@overload
def squeeze(input: RaggedTensor, axis: int | tuple[int, ...] | list[int], name: str | None = None) -> RaggedTensor: ...
def tensor_scatter_nd_update(
tensor: _TensorCompatible, indices: _TensorCompatible, updates: _TensorCompatible, name: str | None = None
) -> Tensor: ...
def constant(
value: _TensorCompatible, dtype: _DTypeLike | None = None, shape: _ShapeLike | None = None, name: str | None = "Const"
) -> Tensor: ...
@overload
def cast(x: _TensorCompatible, dtype: _DTypeLike, name: str | None = None) -> Tensor: ...
@overload
def cast(x: SparseTensor, dtype: _DTypeLike, name: str | None = None) -> SparseTensor: ...
@overload
def cast(x: RaggedTensor, dtype: _DTypeLike, name: str | None = None) -> RaggedTensor: ...
def reshape(tensor: _TensorCompatible, shape: _ShapeLike | Tensor, name: str | None = None) -> Tensor: ...