diff --git a/stubs/tensorflow/tensorflow/bitwise.pyi b/stubs/tensorflow/tensorflow/bitwise.pyi new file mode 100644 index 000000000..2045d9987 --- /dev/null +++ b/stubs/tensorflow/tensorflow/bitwise.pyi @@ -0,0 +1,37 @@ +from typing import Any, overload +from typing_extensions import TypeAlias + +import numpy as np +import tensorflow as tf +from tensorflow._aliases import FloatArray, IntArray + +# The alias below is not fully accurate, since TensorFlow casts the inputs, they have some additional +# requirements. For example y needs to be castable into x's dtype. Moreover, x and y cannot both be booleans. +# Properly typing the bitwise functions would be overly complicated and unlikely to provide much benefits +# since most people use Tensors, it was therefore not done. +_BitwiseCompatible: TypeAlias = tf.Tensor | int | FloatArray | IntArray | np.number[Any] + +@overload +def bitwise_and(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ... +@overload +def bitwise_and(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ... +@overload +def bitwise_or(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ... +@overload +def bitwise_or(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ... +@overload +def bitwise_xor(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ... +@overload +def bitwise_xor(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ... +@overload +def invert(x: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ... +@overload +def invert(x: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ... +@overload +def left_shift(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ... +@overload +def left_shift(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ... +@overload +def right_shift(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ... +@overload +def right_shift(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ...