mirror of
https://github.com/davidhalter/typeshed.git
synced 2026-05-04 20:45:49 +08:00
tensorflow: add a few TensorFlow functions (#13364)
This commit is contained in:
@@ -6,7 +6,7 @@ from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic, NoReturn, TypeVar, overload
|
||||
from typing import Any, Generic, Literal, NoReturn, TypeVar, overload
|
||||
from typing_extensions import ParamSpec, Self
|
||||
|
||||
from google.protobuf.message import Message
|
||||
@@ -20,7 +20,17 @@ from tensorflow import (
|
||||
math as math,
|
||||
types as types,
|
||||
)
|
||||
from tensorflow._aliases import AnyArray, DTypeLike, ShapeLike, Slice, TensorCompatible
|
||||
from tensorflow._aliases import (
|
||||
AnyArray,
|
||||
DTypeLike,
|
||||
IntArray,
|
||||
ScalarTensorCompatible,
|
||||
ShapeLike,
|
||||
Slice,
|
||||
SparseTensorCompatible,
|
||||
TensorCompatible,
|
||||
UIntTensorCompatible,
|
||||
)
|
||||
from tensorflow.autodiff import GradientTape as GradientTape
|
||||
from tensorflow.core.protobuf import struct_pb2
|
||||
from tensorflow.dtypes import *
|
||||
@@ -56,6 +66,7 @@ from tensorflow.math import (
|
||||
reduce_min as reduce_min,
|
||||
reduce_prod as reduce_prod,
|
||||
reduce_sum as reduce_sum,
|
||||
round as round,
|
||||
sigmoid as sigmoid,
|
||||
sign as sign,
|
||||
sin as sin,
|
||||
@@ -403,4 +414,22 @@ def ones_like(
|
||||
input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
|
||||
) -> RaggedTensor: ...
|
||||
# Stub for tf.reshape: the target shape may be a static ShapeLike or a Tensor.
def reshape(
    tensor: TensorCompatible,
    shape: ShapeLike | Tensor,
    name: str | None = None,
) -> Tensor: ...
|
||||
def pad(
    tensor: TensorCompatible,
    paddings: Tensor | IntArray | Iterable[Iterable[int]],
    # "symmectric" in the previous revision was a typo: TensorFlow accepts the
    # mode case-insensitively as CONSTANT/REFLECT/SYMMETRIC, so the valid
    # lowercase spelling is "symmetric".
    mode: Literal["CONSTANT", "constant", "REFLECT", "reflect", "SYMMETRIC", "symmetric"] = "CONSTANT",
    constant_values: ScalarTensorCompatible = 0,
    name: str | None = None,
) -> Tensor: ...
|
||||
# Stub for tf.shape: returns the (dynamic) shape of `input` as a Tensor,
# optionally cast to `out_type`.
def shape(
    input: SparseTensorCompatible,
    out_type: DTypeLike | None = None,
    name: str | None = None,
) -> Tensor: ...
|
||||
# Stub for tf.where. Per the signature, `x` and `y` are optional together:
# with only `condition` it returns a Tensor of coordinates; with x and y it
# returns an elementwise selection.
def where(
    condition: TensorCompatible,
    x: TensorCompatible | None = None,
    y: TensorCompatible | None = None,
    name: str | None = None,
) -> Tensor: ...
|
||||
# Stub for tf.gather_nd: gathers slices of `params` addressed by `indices`.
# `bad_indices_policy` accepts "" (same as "DEFAULT"), "ERROR", or "IGNORE".
def gather_nd(
    params: TensorCompatible, indices: UIntTensorCompatible, batch_dims: UIntTensorCompatible = 0,
    name: str | None = None, bad_indices_policy: Literal["", "DEFAULT", "ERROR", "IGNORE"] = "",
) -> Tensor: ...
|
||||
# Module-level fallback: attributes not explicitly stubbed resolve to Incomplete
# (typeshed's marker for not-yet-annotated values) instead of a type error.
def __getattr__(name: str) -> Incomplete: ...
|
||||
|
||||
@@ -219,6 +219,12 @@ def square(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
def softplus(features: TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def softplus(features: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
# Overloads for tf.round: each supported container type maps back to itself
# (Tensor -> Tensor, SparseTensor -> SparseTensor, RaggedTensor -> RaggedTensor).
# Overload order is significant to type checkers and is preserved.
@overload
def round(x: TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def round(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
@overload
def round(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
|
||||
# Depending on the method axis is either a rank 0 tensor or a rank 0/1 tensor.
|
||||
def reduce_mean(
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
from tensorflow import Tensor
|
||||
from tensorflow._aliases import DTypeLike, TensorCompatible
|
||||
|
||||
# Stub for TensorFlow's hamming_window: a Hamming window of `window_length`
# samples. The stub leaves the dtype default unspecified (`...`);
# presumably it is a float dtype at runtime — confirm against TF docs.
def hamming_window(
    window_length: TensorCompatible,
    periodic: bool | TensorCompatible = True,
    dtype: DTypeLike = ...,
    name: str | None = None,
) -> Tensor: ...
|
||||
Reference in New Issue
Block a user