Tensorflow: Add more stubs (#9560)

Co-authored-by: Mehdi Drissi <mdrissi@snapchat.com>
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Author: Mehdi Drissi
Date: 2023-01-31 22:12:41 -08:00
Committed by: GitHub
Parent: 1cc8080c7e
Commit: d755da86dd
2 changed files with 334 additions and 3 deletions


@@ -8,10 +8,49 @@ from typing import Any, NoReturn, overload
from typing_extensions import TypeAlias
import numpy
# The explicit DType import is redundant with the wildcard above, but
# is necessary to avoid a crash in pytype.
from tensorflow.dtypes import *
from tensorflow.dtypes import DType as DType
# Most tf.math functions are also re-exported at the top level as tf.*, but sadly not all are.
from tensorflow.math import (
abs as abs,
add as add,
add_n as add_n,
argmax as argmax,
argmin as argmin,
cos as cos,
cosh as cosh,
divide as divide,
equal as equal,
greater as greater,
greater_equal as greater_equal,
less as less,
less_equal as less_equal,
logical_and as logical_and,
logical_not as logical_not,
logical_or as logical_or,
maximum as maximum,
minimum as minimum,
multiply as multiply,
not_equal as not_equal,
pow as pow,
reduce_max as reduce_max,
reduce_mean as reduce_mean,
reduce_min as reduce_min,
reduce_prod as reduce_prod,
reduce_sum as reduce_sum,
sigmoid as sigmoid,
sign as sign,
sin as sin,
sinh as sinh,
sqrt as sqrt,
square as square,
subtract as subtract,
tanh as tanh,
)
from tensorflow.sparse import SparseTensor
# Tensors ideally should be a generic type, but properly typing data type/shape
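A minimal illustration of the tf.math re-export comment above (not part of the commit, and grounded only in the names listed in this diff): functions in the re-export block are reachable both as tf.<name> and tf.math.<name>, while names stubbed only in the tf.math stub below are reached through the submodule.

import tensorflow as tf

tf.abs, tf.reduce_sum, tf.tanh   # covered by the re-export block above
tf.math.expm1, tf.math.rsqrt     # stubbed only under tf.math in this commit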
@@ -191,4 +230,27 @@ class Graph:
def get_name_scope(self) -> str: ...
def __getattr__(self, name: str) -> Incomplete: ...
class IndexedSlices(metaclass=ABCMeta):
def __init__(self, values: Tensor, indices: Tensor, dense_shape: None | Tensor = None) -> None: ...
@property
def values(self) -> Tensor: ...
@property
def indices(self) -> Tensor: ...
@property
def dense_shape(self) -> None | Tensor: ...
@property
def shape(self) -> TensorShape: ...
@property
def dtype(self) -> DType: ...
@property
def name(self) -> str: ...
@property
def op(self) -> Operation: ...
@property
def graph(self) -> Graph: ...
@property
def device(self) -> str: ...
def __neg__(self) -> IndexedSlices: ...
def consumers(self) -> list[Operation]: ...
def __getattr__(name: str) -> Incomplete: ...
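Below is a small, purely illustrative example (not part of the commit) of the user code the new IndexedSlices annotations are meant to cover; the constructor and property types mirror the stub above.

import tensorflow as tf

slices = tf.IndexedSlices(
    values=tf.constant([[1.0, 2.0], [3.0, 4.0]]),  # Tensor, per the stub
    indices=tf.constant([0, 2]),                   # Tensor, per the stub
    dense_shape=tf.constant([4, 2]),               # None | Tensor
)
v: tf.Tensor = slices.values   # property types come from the stub above
shape = slices.dense_shape     # None | Tensor
neg = -slices                  # __neg__ -> IndexedSlices, per the stub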


@@ -1,13 +1,282 @@
from _typeshed import Incomplete
from collections.abc import Iterable
from typing import TypeVar, overload
from typing_extensions import TypeAlias
from tensorflow import IndexedSlices, RaggedTensor, Tensor, _DTypeLike, _ShapeLike, _TensorCompatible
from tensorflow.sparse import SparseTensor
_TensorCompatibleT = TypeVar("_TensorCompatibleT", bound=_TensorCompatible)
_SparseTensorCompatible: TypeAlias = _TensorCompatible | SparseTensor
# Most operations support RaggedTensor; their documentation lives at
# https://www.tensorflow.org/api_docs/python/tf/ragged.
# Most operations do not support SparseTensor. Operations often don't document
# whether they support SparseTensor, so it is best to test them manually. Typically,
# if an operation outputs a non-zero value for a zero input, it will not support
# SparseTensors. Binary operations with ragged tensors usually only work
# if both operands are ragged.
@overload
def abs(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def abs(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
@overload
def abs(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
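The overloads above (and the analogous ones below) are what let a type checker propagate the container type through elementwise ops; a rough sketch of the intent (illustrative, not part of the commit):

import tensorflow as tf

dense = tf.constant([-1.0, 2.0])
sparse = tf.sparse.from_dense(tf.constant([[-1.0, 0.0], [0.0, 3.0]]))
ragged = tf.ragged.constant([[-1.0, 2.0], [3.0]])

d = tf.abs(dense)   # matches the _TensorCompatible overload -> Tensor
s = tf.abs(sparse)  # matches the SparseTensor overload -> SparseTensor
r = tf.abs(ragged)  # matches the RaggedTensor overload -> RaggedTensor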
@overload
def sin(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def sin(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def cos(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def cos(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def exp(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def exp(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def sinh(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def sinh(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def cosh(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def cosh(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def tanh(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def tanh(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
@overload
def tanh(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def expm1(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def expm1(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def log(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def log(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def log1p(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def log1p(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def negative(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def negative(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
@overload
def negative(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def sigmoid(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def sigmoid(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
@overload
def add(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def add(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def add_n(inputs: Iterable[_TensorCompatible | IndexedSlices], name: str | None = None) -> Tensor: ...
@overload
def add_n(inputs: Iterable[RaggedTensor], name: str | None = None) -> RaggedTensor: ...
@overload
def subtract(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def subtract(x: _TensorCompatible | RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def subtract(
x: _TensorCompatible | RaggedTensor, y: _TensorCompatible | RaggedTensor, name: str | None = None
) -> Tensor | RaggedTensor: ...
@overload
def multiply(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def multiply(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def multiply_no_nan(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def multiply_no_nan(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def divide(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def divide(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def divide_no_nan(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def divide_no_nan(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def floormod(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def floormod(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def ceil(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def ceil(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def floor(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def floor(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
# The runtime uses isinstance checks on list/tuple, so other Sequence types are not supported.
# The TypeVar makes the parameter behave covariantly.
def accumulate_n(
inputs: list[_TensorCompatibleT] | tuple[_TensorCompatibleT, ...],
shape: _ShapeLike | None = None,
tensor_dtype: _DTypeLike | None = None,
name: str | None = None,
) -> Tensor: ...
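A sketch (not from the commit) of why accumulate_n is typed with list/tuple and a bound TypeVar: the runtime isinstance check rejects other sequence types, and the TypeVar lets, e.g., list[Tensor] match without invariance problems.

import tensorflow as tf

a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])

tf.math.accumulate_n([a, b])   # list[Tensor] is accepted
tf.math.accumulate_n((a, b))   # tuple works too
# A generator is rejected both by the runtime isinstance check and by this stub:
# tf.math.accumulate_n(x for x in (a, b))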
@overload
def pow(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def pow(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def reciprocal(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def reciprocal(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def is_nan(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def is_nan(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def minimum(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def minimum(x: RaggedTensor, y: _TensorCompatible | RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def minimum(x: _TensorCompatible | RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def maximum(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def maximum(x: RaggedTensor, y: _TensorCompatible | RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def maximum(x: _TensorCompatible | RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def logical_not(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def logical_not(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def logical_and(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def logical_and(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def logical_or(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def logical_or(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def logical_xor(x: _TensorCompatible, y: _TensorCompatible, name: str | None = "LogicalXor") -> Tensor: ...
@overload
def logical_xor(x: RaggedTensor, y: RaggedTensor, name: str | None = "LogicalXor") -> RaggedTensor: ...
@overload
def equal(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def equal(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
@overload
def not_equal(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def not_equal(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
@overload
def greater(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def greater(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
@overload
def greater_equal(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def greater_equal(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
@overload
def less(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def less(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
@overload
def less_equal(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def less_equal(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
def segment_sum(data: _TensorCompatible, segment_ids: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def sign(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def sign(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
@overload
def sign(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def sqrt(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def sqrt(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
@overload
def sqrt(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def rsqrt(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def rsqrt(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def square(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def square(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
@overload
def square(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
@overload
def softplus(features: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def softplus(features: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
# Depending on the method, axis is either a rank-0 tensor or a rank-0/1 tensor.
def reduce_mean(
input_tensor: _TensorCompatible | RaggedTensor,
axis: _TensorCompatible | None = None,
keepdims: bool = False,
name: str | None = None,
) -> Tensor: ...
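Per the comment above, axis is _TensorCompatible, so plain Python ints and lists of ints type-check; an illustrative use of reduce_mean (not part of the commit):

import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])

tf.reduce_mean(x)                              # no axis: reduce over all dimensions
tf.reduce_mean(x, axis=0)                      # rank-0 axis
tf.reduce_mean(x, axis=[0, 1], keepdims=True)  # rank-1 axis, reduced dims kept as size 1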
def reduce_sum(
input_tensor: _TensorCompatible | RaggedTensor,
axis: _TensorCompatible | None = None,
keepdims: bool = False,
name: str | None = None,
) -> Tensor: ...
def reduce_max(
input_tensor: _TensorCompatible | RaggedTensor,
axis: _TensorCompatible | None = None,
keepdims: bool = False,
name: str | None = None,
) -> Tensor: ...
def reduce_min(
input_tensor: _TensorCompatible | RaggedTensor,
axis: _TensorCompatible | None = None,
keepdims: bool = False,
name: str | None = None,
) -> Tensor: ...
def reduce_prod(
input_tensor: _TensorCompatible | RaggedTensor,
axis: _TensorCompatible | None = None,
keepdims: bool = False,
name: str | None = None,
) -> Tensor: ...
def reduce_std(
input_tensor: _TensorCompatible | RaggedTensor,
axis: _TensorCompatible | None = None,
keepdims: bool = False,
name: str | None = None,
) -> Tensor: ...
def argmax(
input: _TensorCompatible, axis: _TensorCompatible | None = None, output_type: _DTypeLike = ..., name: str | None = None
) -> Tensor: ...
def argmin(
input: _TensorCompatible, axis: _TensorCompatible | None = None, output_type: _DTypeLike = ..., name: str | None = None
) -> Tensor: ...
# Only for bool tensors.
def reduce_any(
input_tensor: _TensorCompatible | RaggedTensor,
axis: _TensorCompatible | None = None,
keepdims: bool = False,
name: str | None = None,
) -> Tensor: ...
def reduce_all(
input_tensor: _TensorCompatible | RaggedTensor,
axis: _TensorCompatible | None = None,
keepdims: bool = False,
name: str | None = None,
) -> Tensor: ...
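As noted above, these two are only for bool tensors; illustrative usage (not part of the commit):

import tensorflow as tf

mask = tf.constant([[True, False], [False, False]])

tf.reduce_any(mask, axis=1)   # -> [True, False]
tf.reduce_all(mask)           # -> False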
def count_nonzero(
input: _SparseTensorCompatible,
axis: _TensorCompatible | None = None,
keepdims: bool | None = None,
dtype: _DTypeLike = ...,
name: str | None = None,
) -> Tensor: ...
def __getattr__(name: str) -> Incomplete: ...
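The trailing module-level __getattr__ keeps the stub usable while it is incomplete: names not yet annotated above still resolve, just typed as Incomplete. A minimal sketch (not part of the commit; lgamma is simply an example of a function this diff does not stub):

import tensorflow as tf

tf.math.sqrt(tf.constant([4.0]))    # stubbed above: a type checker sees Tensor
tf.math.lgamma(tf.constant([1.0]))  # not stubbed in this diff: falls back to Incomplete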