tensorflow: add tensorflow.bitwise (#11440)

Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
This commit is contained in:
Hoël Bagard
2024-03-02 09:00:02 +09:00
committed by GitHub
parent 9e5bced2d1
commit 5e4483618a

View File

@@ -0,0 +1,37 @@
from typing import Any, overload
from typing_extensions import TypeAlias
import numpy as np
import tensorflow as tf
from tensorflow._aliases import FloatArray, IntArray
# The alias below is not fully accurate, since TensorFlow casts the inputs, they have some additional
# requirements. For example y needs to be castable into x's dtype. Moreover, x and y cannot both be booleans.
# Properly typing the bitwise functions would be overly complicated and unlikely to provide much benefits
# since most people use Tensors, it was therefore not done.
_BitwiseCompatible: TypeAlias = tf.Tensor | int | FloatArray | IntArray | np.number[Any]
@overload
def bitwise_and(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ...
@overload
def bitwise_and(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ...
@overload
def bitwise_or(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ...
@overload
def bitwise_or(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ...
@overload
def bitwise_xor(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ...
@overload
def bitwise_xor(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ...
@overload
def invert(x: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ...
@overload
def invert(x: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ...
@overload
def left_shift(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ...
@overload
def left_shift(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ...
@overload
def right_shift(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ...
@overload
def right_shift(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ...