mirror of
https://github.com/davidhalter/typeshed.git
synced 2025-12-10 05:51:52 +08:00
tensorflow: add some tensorflow functions (#11326)
This commit is contained in:
@@ -412,3 +412,30 @@ def convert_to_tensor(
|
||||
dtype_hint: _DTypeLike | None = None,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def expand_dims(input: _TensorCompatible, axis: int, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def expand_dims(input: RaggedTensor, axis: int, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def concat(values: _TensorCompatible, axis: int, name: str | None = "concat") -> Tensor: ...
|
||||
@overload
|
||||
def concat(values: Sequence[RaggedTensor], axis: int, name: str | None = "concat") -> RaggedTensor: ...
|
||||
@overload
|
||||
def squeeze(
|
||||
input: _TensorCompatible, axis: int | tuple[int, ...] | list[int] | None = None, name: str | None = None
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def squeeze(input: RaggedTensor, axis: int | tuple[int, ...] | list[int], name: str | None = None) -> RaggedTensor: ...
|
||||
def tensor_scatter_nd_update(
|
||||
tensor: _TensorCompatible, indices: _TensorCompatible, updates: _TensorCompatible, name: str | None = None
|
||||
) -> Tensor: ...
|
||||
def constant(
|
||||
value: _TensorCompatible, dtype: _DTypeLike | None = None, shape: _ShapeLike | None = None, name: str | None = "Const"
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def cast(x: _TensorCompatible, dtype: _DTypeLike, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def cast(x: SparseTensor, dtype: _DTypeLike, name: str | None = None) -> SparseTensor: ...
|
||||
@overload
|
||||
def cast(x: RaggedTensor, dtype: _DTypeLike, name: str | None = None) -> RaggedTensor: ...
|
||||
def reshape(tensor: _TensorCompatible, shape: _ShapeLike | Tensor, name: str | None = None) -> Tensor: ...
|
||||
|
||||
Reference in New Issue
Block a user