mirror of
https://github.com/davidhalter/typeshed.git
synced 2025-12-09 13:34:58 +08:00
Tensorflow: Add more stubs (#9560)
Co-authored-by: Mehdi Drissi <mdrissi@snapchat.com> Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
@@ -8,10 +8,49 @@ from typing import Any, NoReturn, overload
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import numpy
|
||||
|
||||
# Explicit import of DType is covered by the wildcard, but
|
||||
# is necessary to avoid a crash in pytype.
|
||||
from tensorflow.dtypes import *
|
||||
from tensorflow.dtypes import DType as DType
|
||||
|
||||
# Most tf.math functions are exported as tf, but sadly not all are.
|
||||
from tensorflow.math import abs as abs
|
||||
from tensorflow.math import (
|
||||
abs as abs,
|
||||
add as add,
|
||||
add_n as add_n,
|
||||
argmax as argmax,
|
||||
argmin as argmin,
|
||||
cos as cos,
|
||||
cosh as cosh,
|
||||
divide as divide,
|
||||
equal as equal,
|
||||
greater as greater,
|
||||
greater_equal as greater_equal,
|
||||
less as less,
|
||||
less_equal as less_equal,
|
||||
logical_and as logical_and,
|
||||
logical_not as logical_not,
|
||||
logical_or as logical_or,
|
||||
maximum as maximum,
|
||||
minimum as minimum,
|
||||
multiply as multiply,
|
||||
not_equal as not_equal,
|
||||
pow as pow,
|
||||
reduce_max as reduce_max,
|
||||
reduce_mean as reduce_mean,
|
||||
reduce_min as reduce_min,
|
||||
reduce_prod as reduce_prod,
|
||||
reduce_sum as reduce_sum,
|
||||
sigmoid as sigmoid,
|
||||
sign as sign,
|
||||
sin as sin,
|
||||
sinh as sinh,
|
||||
sqrt as sqrt,
|
||||
square as square,
|
||||
subtract as subtract,
|
||||
tanh as tanh,
|
||||
)
|
||||
from tensorflow.sparse import SparseTensor
|
||||
|
||||
# Tensors ideally should be a generic type, but properly typing data type/shape
|
||||
@@ -191,4 +230,27 @@ class Graph:
|
||||
def get_name_scope(self) -> str: ...
|
||||
def __getattr__(self, name: str) -> Incomplete: ...
|
||||
|
||||
class IndexedSlices(metaclass=ABCMeta):
|
||||
def __init__(self, values: Tensor, indices: Tensor, dense_shape: None | Tensor = None) -> None: ...
|
||||
@property
|
||||
def values(self) -> Tensor: ...
|
||||
@property
|
||||
def indices(self) -> Tensor: ...
|
||||
@property
|
||||
def dense_shape(self) -> None | Tensor: ...
|
||||
@property
|
||||
def shape(self) -> TensorShape: ...
|
||||
@property
|
||||
def dtype(self) -> DType: ...
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
@property
|
||||
def op(self) -> Operation: ...
|
||||
@property
|
||||
def graph(self) -> Graph: ...
|
||||
@property
|
||||
def device(self) -> str: ...
|
||||
def __neg__(self) -> IndexedSlices: ...
|
||||
def consumers(self) -> list[Operation]: ...
|
||||
|
||||
def __getattr__(name: str) -> Incomplete: ...
|
||||
|
||||
@@ -1,13 +1,282 @@
|
||||
from _typeshed import Incomplete
|
||||
from typing import overload
|
||||
from collections.abc import Iterable
|
||||
from typing import TypeVar, overload
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from tensorflow import RaggedTensor, Tensor, _TensorCompatible
|
||||
from tensorflow import IndexedSlices, RaggedTensor, Tensor, _DTypeLike, _ShapeLike, _TensorCompatible
|
||||
from tensorflow.sparse import SparseTensor
|
||||
|
||||
_TensorCompatibleT = TypeVar("_TensorCompatibleT", bound=_TensorCompatible)
|
||||
_SparseTensorCompatible: TypeAlias = _TensorCompatible | SparseTensor
|
||||
|
||||
# Most operations support RaggedTensor. Documentation for them is here,
|
||||
# https://www.tensorflow.org/api_docs/python/tf/ragged.
|
||||
# Most operations do not support SparseTensor. Operations often don't document
|
||||
# whether they support SparseTensor and it is best to test them manually. Typically
|
||||
# if an operation outputs non-zero value for a zero input, it will not support
|
||||
# SparseTensors. Binary operations with ragged tensors usually only work
|
||||
# if both operands are ragged.
|
||||
@overload
|
||||
def abs(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def abs(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
|
||||
@overload
|
||||
def abs(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def sin(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def sin(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def cos(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def cos(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def exp(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def exp(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def sinh(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def sinh(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def cosh(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def cosh(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def tanh(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def tanh(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
|
||||
@overload
|
||||
def tanh(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def expm1(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def expm1(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def log(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def log(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def log1p(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def log1p(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def negative(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def negative(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
|
||||
@overload
|
||||
def negative(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def sigmoid(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def sigmoid(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
|
||||
@overload
|
||||
def add(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def add(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def add_n(inputs: Iterable[_TensorCompatible | IndexedSlices], name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def add_n(inputs: Iterable[RaggedTensor], name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def subtract(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def subtract(x: _TensorCompatible | RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def subtract(
|
||||
x: _TensorCompatible | RaggedTensor, y: _TensorCompatible | RaggedTensor, name: str | None = None
|
||||
) -> Tensor | RaggedTensor: ...
|
||||
@overload
|
||||
def multiply(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def multiply(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def multiply_no_nan(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def multiply_no_nan(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def divide(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def divide(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def divide_no_nan(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def divide_no_nan(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def floormod(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def floormod(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def ceil(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def ceil(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def floor(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def floor(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
|
||||
# Uses isinstance on list/tuple so other Sequence types are not supported. The TypeVar is to
|
||||
# behave covariantly.
|
||||
def accumulate_n(
|
||||
inputs: list[_TensorCompatibleT] | tuple[_TensorCompatibleT, ...],
|
||||
shape: _ShapeLike | None = None,
|
||||
tensor_dtype: _DTypeLike | None = None,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def pow(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def pow(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def reciprocal(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def reciprocal(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def is_nan(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def is_nan(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def minimum(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def minimum(x: RaggedTensor, y: _TensorCompatible | RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def minimum(x: _TensorCompatible | RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def maximum(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def maximum(x: RaggedTensor, y: _TensorCompatible | RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def maximum(x: _TensorCompatible | RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def logical_not(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def logical_not(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def logical_and(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def logical_and(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def logical_or(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def logical_or(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def logical_xor(x: _TensorCompatible, y: _TensorCompatible, name: str | None = "LogicalXor") -> Tensor: ...
|
||||
@overload
|
||||
def logical_xor(x: RaggedTensor, y: RaggedTensor, name: str | None = "LogicalXor") -> RaggedTensor: ...
|
||||
@overload
|
||||
def equal(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def equal(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def not_equal(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def not_equal(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def greater(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def greater(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def greater_equal(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def greater_equal(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def less(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def less(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def less_equal(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def less_equal(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
|
||||
def segment_sum(data: _TensorCompatible, segment_ids: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def sign(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def sign(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
|
||||
@overload
|
||||
def sign(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def sqrt(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def sqrt(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
|
||||
@overload
|
||||
def sqrt(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def rsqrt(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def rsqrt(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def square(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def square(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
|
||||
@overload
|
||||
def square(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def softplus(features: _TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def softplus(features: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
|
||||
# Depending on the method axis is either a rank 0 tensor or a rank 0/1 tensor.
|
||||
def reduce_mean(
|
||||
input_tensor: _TensorCompatible | RaggedTensor,
|
||||
axis: _TensorCompatible | None = None,
|
||||
keepdims: bool = False,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def reduce_sum(
|
||||
input_tensor: _TensorCompatible | RaggedTensor,
|
||||
axis: _TensorCompatible | None = None,
|
||||
keepdims: bool = False,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def reduce_max(
|
||||
input_tensor: _TensorCompatible | RaggedTensor,
|
||||
axis: _TensorCompatible | None = None,
|
||||
keepdims: bool = False,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def reduce_min(
|
||||
input_tensor: _TensorCompatible | RaggedTensor,
|
||||
axis: _TensorCompatible | None = None,
|
||||
keepdims: bool = False,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def reduce_prod(
|
||||
input_tensor: _TensorCompatible | RaggedTensor,
|
||||
axis: _TensorCompatible | None = None,
|
||||
keepdims: bool = False,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def reduce_std(
|
||||
input_tensor: _TensorCompatible | RaggedTensor,
|
||||
axis: _TensorCompatible | None = None,
|
||||
keepdims: bool = False,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def argmax(
|
||||
input: _TensorCompatible, axis: _TensorCompatible | None = None, output_type: _DTypeLike = ..., name: str | None = None
|
||||
) -> Tensor: ...
|
||||
def argmin(
|
||||
input: _TensorCompatible, axis: _TensorCompatible | None = None, output_type: _DTypeLike = ..., name: str | None = None
|
||||
) -> Tensor: ...
|
||||
|
||||
# Only for bool tensors.
|
||||
def reduce_any(
|
||||
input_tensor: _TensorCompatible | RaggedTensor,
|
||||
axis: _TensorCompatible | None = None,
|
||||
keepdims: bool = False,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def reduce_all(
|
||||
input_tensor: _TensorCompatible | RaggedTensor,
|
||||
axis: _TensorCompatible | None = None,
|
||||
keepdims: bool = False,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def count_nonzero(
|
||||
input: _SparseTensorCompatible,
|
||||
axis: _TensorCompatible | None = None,
|
||||
keepdims: bool | None = None,
|
||||
dtype: _DTypeLike = ...,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def __getattr__(name: str) -> Incomplete: ...
|
||||
|
||||
Reference in New Issue
Block a user