mirror of
https://github.com/davidhalter/typeshed.git
synced 2025-12-07 12:44:28 +08:00
tensorflow: add partial tf.nn module. (#11388)
Some derived from: https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/summary.pyi https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/nn.pyi
This commit is contained in:
@@ -104,3 +104,6 @@ tensorflow.train.FloatList.*
|
||||
tensorflow.train.Int64List.*
|
||||
tensorflow.train.ClusterDef.*
|
||||
tensorflow.train.ServerDef.*
|
||||
|
||||
# sigmoid_cross_entropy_with_logits has default values (None), however those values are not valid.
|
||||
tensorflow.nn.sigmoid_cross_entropy_with_logits
|
||||
|
||||
@@ -44,6 +44,10 @@ TensorCompatible: TypeAlias = ScalarTensorCompatible | Sequence[TensorCompatible
|
||||
# broader type than TensorCompatible and not all operations will support broader version. If unsure,
|
||||
# use TensorCompatible instead.
|
||||
SparseTensorCompatible: TypeAlias = TensorCompatible | tf.SparseTensor
|
||||
# TensorFlow tries to convert anything passed as input. Meaning that even if, for example, only a Tensor of int32
|
||||
# is allowed, a numpy array of strings that can be converted to int32 will work. Therefore having anything more specific
|
||||
# then AnyArray might cause false positives, while AnyArray might cause false negatives.
|
||||
TensorOrArray: TypeAlias = tf.Tensor | AnyArray
|
||||
|
||||
ShapeLike: TypeAlias = tf.TensorShape | Iterable[ScalarTensorCompatible | None] | int | tf.Tensor
|
||||
DTypeLike: TypeAlias = tf.DType | str | np.dtype[Any] | int
|
||||
|
||||
195
stubs/tensorflow/tensorflow/nn.pyi
Normal file
195
stubs/tensorflow/tensorflow/nn.pyi
Normal file
@@ -0,0 +1,195 @@
|
||||
from _typeshed import Incomplete
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
from tensorflow import RaggedTensor, Tensor
|
||||
from tensorflow._aliases import ScalarTensorCompatible, TensorCompatible, TensorOrArray
|
||||
from tensorflow.math import l2_normalize as l2_normalize, sigmoid as sigmoid, tanh as tanh
|
||||
from tensorflow.sparse import SparseTensor
|
||||
|
||||
def atrous_conv2d(
|
||||
value: TensorOrArray, filters: TensorOrArray, rate: int, padding: Literal["VALID", "SAME"], name: str | None = None
|
||||
) -> Tensor: ...
|
||||
def atrous_conv2d_transpose(
|
||||
value: TensorOrArray,
|
||||
filters: TensorOrArray,
|
||||
output_shape: TensorOrArray,
|
||||
rate: int,
|
||||
padding: Literal["VALID", "SAME"],
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def avg_pool(
|
||||
input: TensorOrArray,
|
||||
ksize: int | Sequence[int],
|
||||
strides: int | Sequence[int],
|
||||
padding: Literal["VALID", "SAME"],
|
||||
data_format: Literal["NWC", "NCW", "NHWC", "NCHW", "NDHWC", "NCDHW"] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def avg_pool1d(
|
||||
input: TensorOrArray,
|
||||
ksize: int | Sequence[int],
|
||||
strides: int | Sequence[int],
|
||||
padding: Literal["VALID", "SAME"],
|
||||
data_format: Literal["NWC", "NCW"] = "NWC",
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def avg_pool2d(
|
||||
input: TensorOrArray,
|
||||
ksize: int | Sequence[int],
|
||||
strides: int | Sequence[int],
|
||||
padding: Literal["VALID", "SAME"],
|
||||
data_format: Literal["NHWC", "NCHW"] = "NHWC",
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def avg_pool3d(
|
||||
input: TensorOrArray,
|
||||
ksize: int | Sequence[int],
|
||||
strides: int | Sequence[int],
|
||||
padding: Literal["VALID", "SAME"],
|
||||
data_format: Literal["NDHWC", "NCDHW"] = "NDHWC",
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def batch_normalization(
|
||||
x: Tensor, mean: Tensor, variance: Tensor, offset: Tensor, scale: Tensor, variance_epsilon: float, name: str | None = None
|
||||
) -> Tensor: ...
|
||||
def bias_add(
|
||||
value: Tensor, bias: Tensor, data_format: Literal["N...C", "NC..."] | None = None, name: str | None = None
|
||||
) -> Tensor: ...
|
||||
def collapse_repeated(labels: Tensor, seq_length: Tensor, name: str | None = None) -> tuple[Tensor, Tensor]: ...
|
||||
def compute_accidental_hits(
|
||||
true_classes: Tensor, sampled_candidates: Tensor, num_true: int, seed: int | None = None, name: str | None = None
|
||||
) -> tuple[Tensor, Tensor, Tensor]: ...
|
||||
def compute_average_loss(
|
||||
per_example_loss: Tensor, sample_weight: Tensor | None = None, global_batch_size: int | None = None
|
||||
) -> Tensor: ...
|
||||
def conv1d(
|
||||
input: TensorOrArray,
|
||||
filters: TensorOrArray,
|
||||
stride: int | Sequence[int],
|
||||
padding: Literal["VALID", "SAME"],
|
||||
data_format: Literal["NWC", "NCW"] = "NWC",
|
||||
dilations: int | Sequence[int] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def conv1d_transpose(
|
||||
input: TensorOrArray,
|
||||
filters: TensorOrArray,
|
||||
output_shape: TensorOrArray,
|
||||
strides: int | Sequence[int],
|
||||
padding: Literal["VALID", "SAME"] = "SAME",
|
||||
data_format: Literal["NWC", "NCW"] = "NWC",
|
||||
dilations: int | Sequence[int] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def conv2d(
|
||||
input: TensorOrArray,
|
||||
filters: TensorOrArray,
|
||||
strides: int | Sequence[int],
|
||||
padding: Literal["VALID", "SAME"],
|
||||
data_format: Literal["NHWC", "NCHW"] = "NHWC",
|
||||
dilations: int | Sequence[int] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def conv2d_transpose(
|
||||
input: TensorOrArray,
|
||||
filters: TensorOrArray,
|
||||
output_shape: TensorOrArray,
|
||||
strides: int | Sequence[int],
|
||||
padding: Literal["VALID", "SAME"] = "SAME",
|
||||
data_format: Literal["NHWC", "NCHW"] = "NHWC",
|
||||
dilations: int | Sequence[int] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def conv3d(
|
||||
input: TensorOrArray,
|
||||
filters: TensorOrArray,
|
||||
strides: int | Sequence[int],
|
||||
padding: Literal["VALID", "SAME"],
|
||||
data_format: Literal["NDHWC", "NCDHW"] = "NDHWC",
|
||||
dilations: int | Sequence[int] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def conv3d_transpose(
|
||||
input: TensorOrArray,
|
||||
filters: TensorOrArray,
|
||||
output_shape: TensorOrArray,
|
||||
strides: int | Sequence[int],
|
||||
padding: Literal["VALID", "SAME"] = "SAME",
|
||||
data_format: Literal["NDHWC", "NCDHW"] = "NDHWC",
|
||||
dilations: int | Sequence[int] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def conv_transpose(
|
||||
input: TensorOrArray,
|
||||
filters: TensorOrArray,
|
||||
output_shape: TensorOrArray,
|
||||
strides: int | Sequence[int],
|
||||
padding: Literal["VALID", "SAME"] = "SAME",
|
||||
data_format: str | None = None,
|
||||
dilations: int | Sequence[int] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def convolution(
|
||||
input: TensorOrArray,
|
||||
filters: TensorOrArray,
|
||||
strides: int | Sequence[int] | None = None,
|
||||
padding: Literal["VALID", "SAME"] = "VALID",
|
||||
data_format: Literal["NC", "NWC", "NCW", "NHWC", "NCHW", "NDHWC", "NCDHW"] | None = None,
|
||||
dilations: int | Sequence[int] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def crelu(features: TensorOrArray, axis: int = -1, name: str | None = None) -> Tensor: ...
|
||||
def ctc_beam_search_decoder(
|
||||
inputs: TensorOrArray, sequence_length: TensorOrArray | Sequence[int], beam_width: int = 100, top_paths: int = 1
|
||||
) -> tuple[list[SparseTensor], Tensor]: ...
|
||||
def ctc_greedy_decoder(
|
||||
inputs: Tensor, sequence_length: Tensor | Sequence[int], merge_repeated: bool = True, blank_index: int | None = None
|
||||
) -> tuple[list[SparseTensor], Tensor]: ...
|
||||
def ctc_loss(
|
||||
labels: Tensor,
|
||||
logits: Tensor,
|
||||
label_length: Tensor,
|
||||
logit_length: Tensor,
|
||||
logits_time_major: bool = True,
|
||||
unique: int | None = None,
|
||||
blank_index: int | None = None,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def ctc_unique_labels(labels: Tensor, name: str | None = None) -> tuple[Tensor, Tensor]: ...
|
||||
@overload
|
||||
def embedding_lookup(
|
||||
params: TensorCompatible, ids: TensorCompatible, max_norm: float | None = None, name: str | None = None
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def embedding_lookup(
|
||||
params: TensorCompatible, ids: RaggedTensor, max_norm: float | None = None, name: str | None = None
|
||||
) -> RaggedTensor: ...
|
||||
def leaky_relu(features: TensorCompatible, alpha: float = 0.2, name: str | None = None) -> Tensor: ...
|
||||
def log_poisson_loss(
|
||||
targets: TensorCompatible, log_input: TensorCompatible, compute_full_loss: bool = False, name: str | None = None
|
||||
) -> Tensor: ...
|
||||
|
||||
# tf.nn.moments's shift is not used in the current implementation.
|
||||
def moments(
|
||||
x: TensorCompatible | RaggedTensor,
|
||||
axes: TensorCompatible,
|
||||
shift: None | Any = None,
|
||||
keepdims: bool = False,
|
||||
name: str | None = None,
|
||||
) -> tuple[Tensor, Tensor]: ...
|
||||
def relu(features: TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
def sigmoid_cross_entropy_with_logits(labels: TensorCompatible, logits: TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
def softmax(logits: TensorCompatible, axis: ScalarTensorCompatible | None = None, name: str | None = None) -> Tensor: ...
|
||||
def selu(features: TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
def safe_embedding_lookup_sparse(
|
||||
embedding_weights: Tensor | list[Tensor],
|
||||
sparse_ids: SparseTensor,
|
||||
sparse_weights: SparseTensor | None = None,
|
||||
combiner: str = "mean",
|
||||
default_id: ScalarTensorCompatible | None = None,
|
||||
max_norm: float | None = None,
|
||||
name: str | None = None,
|
||||
allow_fast_lookup: bool = False,
|
||||
) -> Tensor: ...
|
||||
def __getattr__(name: str) -> Incomplete: ...
|
||||
Reference in New Issue
Block a user