From 98ba275d5e4e338aafa5f0acaad3d944fd42f28e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ho=C3=ABl=20Bagard?= <34478245+hoel-bagard@users.noreply.github.com> Date: Wed, 31 Jan 2024 12:57:25 +0900 Subject: [PATCH] `tensorflow`: add some tensorflow functions (#11326) --- stubs/tensorflow/tensorflow/__init__.pyi | 27 ++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/stubs/tensorflow/tensorflow/__init__.pyi b/stubs/tensorflow/tensorflow/__init__.pyi index 9efff3cf3..d773d5c48 100644 --- a/stubs/tensorflow/tensorflow/__init__.pyi +++ b/stubs/tensorflow/tensorflow/__init__.pyi @@ -412,3 +412,30 @@ def convert_to_tensor( dtype_hint: _DTypeLike | None = None, name: str | None = None, ) -> Tensor: ... +@overload +def expand_dims(input: _TensorCompatible, axis: int, name: str | None = None) -> Tensor: ... +@overload +def expand_dims(input: RaggedTensor, axis: int, name: str | None = None) -> RaggedTensor: ... +@overload +def concat(values: _TensorCompatible, axis: int, name: str | None = "concat") -> Tensor: ... +@overload +def concat(values: Sequence[RaggedTensor], axis: int, name: str | None = "concat") -> RaggedTensor: ... +@overload +def squeeze( + input: _TensorCompatible, axis: int | tuple[int, ...] | list[int] | None = None, name: str | None = None +) -> Tensor: ... +@overload +def squeeze(input: RaggedTensor, axis: int | tuple[int, ...] | list[int], name: str | None = None) -> RaggedTensor: ... +def tensor_scatter_nd_update( + tensor: _TensorCompatible, indices: _TensorCompatible, updates: _TensorCompatible, name: str | None = None +) -> Tensor: ... +def constant( + value: _TensorCompatible, dtype: _DTypeLike | None = None, shape: _ShapeLike | None = None, name: str | None = "Const" +) -> Tensor: ... +@overload +def cast(x: _TensorCompatible, dtype: _DTypeLike, name: str | None = None) -> Tensor: ... +@overload +def cast(x: SparseTensor, dtype: _DTypeLike, name: str | None = None) -> SparseTensor: ... +@overload +def cast(x: RaggedTensor, dtype: _DTypeLike, name: str | None = None) -> RaggedTensor: ... +def reshape(tensor: _TensorCompatible, shape: _ShapeLike | Tensor, name: str | None = None) -> Tensor: ...