From d755da86dd8518dc40c55d1b753797d35959fe4a Mon Sep 17 00:00:00 2001 From: Mehdi Drissi Date: Tue, 31 Jan 2023 22:12:41 -0800 Subject: [PATCH] Tensorflow: Add more stubs (#9560) Co-authored-by: Mehdi Drissi Co-authored-by: Alex Waygood --- stubs/tensorflow/tensorflow/__init__.pyi | 64 +++++- stubs/tensorflow/tensorflow/math.pyi | 273 ++++++++++++++++++++++- 2 files changed, 334 insertions(+), 3 deletions(-) diff --git a/stubs/tensorflow/tensorflow/__init__.pyi b/stubs/tensorflow/tensorflow/__init__.pyi index eae7b85a7..62cab55b7 100644 --- a/stubs/tensorflow/tensorflow/__init__.pyi +++ b/stubs/tensorflow/tensorflow/__init__.pyi @@ -8,10 +8,49 @@ from typing import Any, NoReturn, overload from typing_extensions import TypeAlias import numpy + +# Explicit import of DType is covered by the wildcard, but +# is necessary to avoid a crash in pytype. from tensorflow.dtypes import * +from tensorflow.dtypes import DType as DType # Most tf.math functions are exported as tf, but sadly not all are. -from tensorflow.math import abs as abs +from tensorflow.math import ( + abs as abs, + add as add, + add_n as add_n, + argmax as argmax, + argmin as argmin, + cos as cos, + cosh as cosh, + divide as divide, + equal as equal, + greater as greater, + greater_equal as greater_equal, + less as less, + less_equal as less_equal, + logical_and as logical_and, + logical_not as logical_not, + logical_or as logical_or, + maximum as maximum, + minimum as minimum, + multiply as multiply, + not_equal as not_equal, + pow as pow, + reduce_max as reduce_max, + reduce_mean as reduce_mean, + reduce_min as reduce_min, + reduce_prod as reduce_prod, + reduce_sum as reduce_sum, + sigmoid as sigmoid, + sign as sign, + sin as sin, + sinh as sinh, + sqrt as sqrt, + square as square, + subtract as subtract, + tanh as tanh, +) from tensorflow.sparse import SparseTensor # Tensors ideally should be a generic type, but properly typing data type/shape @@ -191,4 +230,27 @@ class Graph: def get_name_scope(self) -> str: ... def __getattr__(self, name: str) -> Incomplete: ... +class IndexedSlices(metaclass=ABCMeta): + def __init__(self, values: Tensor, indices: Tensor, dense_shape: None | Tensor = None) -> None: ... + @property + def values(self) -> Tensor: ... + @property + def indices(self) -> Tensor: ... + @property + def dense_shape(self) -> None | Tensor: ... + @property + def shape(self) -> TensorShape: ... + @property + def dtype(self) -> DType: ... + @property + def name(self) -> str: ... + @property + def op(self) -> Operation: ... + @property + def graph(self) -> Graph: ... + @property + def device(self) -> str: ... + def __neg__(self) -> IndexedSlices: ... + def consumers(self) -> list[Operation]: ... + def __getattr__(name: str) -> Incomplete: ... diff --git a/stubs/tensorflow/tensorflow/math.pyi b/stubs/tensorflow/tensorflow/math.pyi index f97802a7a..66fc425db 100644 --- a/stubs/tensorflow/tensorflow/math.pyi +++ b/stubs/tensorflow/tensorflow/math.pyi @@ -1,13 +1,282 @@ from _typeshed import Incomplete -from typing import overload +from collections.abc import Iterable +from typing import TypeVar, overload +from typing_extensions import TypeAlias -from tensorflow import RaggedTensor, Tensor, _TensorCompatible +from tensorflow import IndexedSlices, RaggedTensor, Tensor, _DTypeLike, _ShapeLike, _TensorCompatible from tensorflow.sparse import SparseTensor +_TensorCompatibleT = TypeVar("_TensorCompatibleT", bound=_TensorCompatible) +_SparseTensorCompatible: TypeAlias = _TensorCompatible | SparseTensor + +# Most operations support RaggedTensor. Documentation for them is here, +# https://www.tensorflow.org/api_docs/python/tf/ragged. +# Most operations do not support SparseTensor. Operations often don't document +# whether they support SparseTensor and it is best to test them manually. Typically +# if an operation outputs non-zero value for a zero input, it will not support +# SparseTensors. Binary operations with ragged tensors usually only work +# if both operands are ragged. @overload def abs(x: _TensorCompatible, name: str | None = None) -> Tensor: ... @overload def abs(x: SparseTensor, name: str | None = None) -> SparseTensor: ... @overload def abs(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def sin(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def sin(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def cos(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def cos(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def exp(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def exp(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def sinh(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def sinh(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def cosh(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def cosh(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def tanh(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def tanh(x: SparseTensor, name: str | None = None) -> SparseTensor: ... +@overload +def tanh(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def expm1(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def expm1(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def log(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def log(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def log1p(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def log1p(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def negative(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def negative(x: SparseTensor, name: str | None = None) -> SparseTensor: ... +@overload +def negative(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def sigmoid(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def sigmoid(x: SparseTensor, name: str | None = None) -> SparseTensor: ... +@overload +def add(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def add(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def add_n(inputs: Iterable[_TensorCompatible | IndexedSlices], name: str | None = None) -> Tensor: ... +@overload +def add_n(inputs: Iterable[RaggedTensor], name: str | None = None) -> RaggedTensor: ... +@overload +def subtract(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def subtract(x: _TensorCompatible | RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def subtract( + x: _TensorCompatible | RaggedTensor, y: _TensorCompatible | RaggedTensor, name: str | None = None +) -> Tensor | RaggedTensor: ... +@overload +def multiply(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def multiply(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def multiply_no_nan(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def multiply_no_nan(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def divide(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def divide(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def divide_no_nan(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def divide_no_nan(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def floormod(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def floormod(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def ceil(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def ceil(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def floor(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def floor(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... + +# Uses isinstance on list/tuple so other Sequence types are not supported. The TypeVar is to +# behave covariantly. +def accumulate_n( + inputs: list[_TensorCompatibleT] | tuple[_TensorCompatibleT, ...], + shape: _ShapeLike | None = None, + tensor_dtype: _DTypeLike | None = None, + name: str | None = None, +) -> Tensor: ... +@overload +def pow(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def pow(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def reciprocal(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def reciprocal(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def is_nan(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def is_nan(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def minimum(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def minimum(x: RaggedTensor, y: _TensorCompatible | RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def minimum(x: _TensorCompatible | RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def maximum(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def maximum(x: RaggedTensor, y: _TensorCompatible | RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def maximum(x: _TensorCompatible | RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def logical_not(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def logical_not(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def logical_and(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def logical_and(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def logical_or(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def logical_or(x: RaggedTensor, y: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def logical_xor(x: _TensorCompatible, y: _TensorCompatible, name: str | None = "LogicalXor") -> Tensor: ... +@overload +def logical_xor(x: RaggedTensor, y: RaggedTensor, name: str | None = "LogicalXor") -> RaggedTensor: ... +@overload +def equal(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def equal(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ... +@overload +def not_equal(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def not_equal(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ... +@overload +def greater(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def greater(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ... +@overload +def greater_equal(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def greater_equal(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ... +@overload +def less(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def less(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ... +@overload +def less_equal(x: _TensorCompatible, y: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def less_equal(x: RaggedTensor, y: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ... +def segment_sum(data: _TensorCompatible, segment_ids: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def sign(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def sign(x: SparseTensor, name: str | None = None) -> SparseTensor: ... +@overload +def sign(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def sqrt(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def sqrt(x: SparseTensor, name: str | None = None) -> SparseTensor: ... +@overload +def sqrt(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def rsqrt(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def rsqrt(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def square(x: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def square(x: SparseTensor, name: str | None = None) -> SparseTensor: ... +@overload +def square(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def softplus(features: _TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def softplus(features: RaggedTensor, name: str | None = None) -> RaggedTensor: ... + +# Depending on the method axis is either a rank 0 tensor or a rank 0/1 tensor. +def reduce_mean( + input_tensor: _TensorCompatible | RaggedTensor, + axis: _TensorCompatible | None = None, + keepdims: bool = False, + name: str | None = None, +) -> Tensor: ... +def reduce_sum( + input_tensor: _TensorCompatible | RaggedTensor, + axis: _TensorCompatible | None = None, + keepdims: bool = False, + name: str | None = None, +) -> Tensor: ... +def reduce_max( + input_tensor: _TensorCompatible | RaggedTensor, + axis: _TensorCompatible | None = None, + keepdims: bool = False, + name: str | None = None, +) -> Tensor: ... +def reduce_min( + input_tensor: _TensorCompatible | RaggedTensor, + axis: _TensorCompatible | None = None, + keepdims: bool = False, + name: str | None = None, +) -> Tensor: ... +def reduce_prod( + input_tensor: _TensorCompatible | RaggedTensor, + axis: _TensorCompatible | None = None, + keepdims: bool = False, + name: str | None = None, +) -> Tensor: ... +def reduce_std( + input_tensor: _TensorCompatible | RaggedTensor, + axis: _TensorCompatible | None = None, + keepdims: bool = False, + name: str | None = None, +) -> Tensor: ... +def argmax( + input: _TensorCompatible, axis: _TensorCompatible | None = None, output_type: _DTypeLike = ..., name: str | None = None +) -> Tensor: ... +def argmin( + input: _TensorCompatible, axis: _TensorCompatible | None = None, output_type: _DTypeLike = ..., name: str | None = None +) -> Tensor: ... + +# Only for bool tensors. +def reduce_any( + input_tensor: _TensorCompatible | RaggedTensor, + axis: _TensorCompatible | None = None, + keepdims: bool = False, + name: str | None = None, +) -> Tensor: ... +def reduce_all( + input_tensor: _TensorCompatible | RaggedTensor, + axis: _TensorCompatible | None = None, + keepdims: bool = False, + name: str | None = None, +) -> Tensor: ... +def count_nonzero( + input: _SparseTensorCompatible, + axis: _TensorCompatible | None = None, + keepdims: bool | None = None, + dtype: _DTypeLike = ..., + name: str | None = None, +) -> Tensor: ... def __getattr__(name: str) -> Incomplete: ...