tensorflow: add tf.linalg module (#11386)

Taken from:
https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/linalg.pyi
This commit is contained in:
Hoël Bagard
2024-02-18 00:45:11 +09:00
committed by GitHub
parent 2e85a70c4c
commit 955cdf50d5
3 changed files with 54 additions and 0 deletions

View File

@@ -38,6 +38,7 @@ from tensorflow.dtypes import *
from tensorflow.dtypes import DType as DType
from tensorflow.experimental.dtensor import Layout
from tensorflow.keras import losses as losses
from tensorflow.linalg import eye as eye
# Most tf.math functions are exported as tf, but sadly not all are.
from tensorflow.math import (

View File

@@ -27,6 +27,7 @@ class KerasSerializable2(Protocol):
KerasSerializable: TypeAlias = KerasSerializable1 | KerasSerializable2
Integer: TypeAlias = tf.Tensor | int | IntArray | np.number[Any] # Here tf.Tensor and IntArray are assumed to be 0D.
Slice: TypeAlias = int | slice | None
FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence]
IntDataSequence: TypeAlias = Sequence[int] | Sequence[IntDataSequence]

View File

@@ -0,0 +1,52 @@
from _typeshed import Incomplete
from builtins import bool as _bool
from collections.abc import Iterable
from typing import Literal, overload
import tensorflow as tf
from tensorflow import RaggedTensor, Tensor, norm as norm
from tensorflow._aliases import DTypeLike, IntArray, Integer, ScalarTensorCompatible, TensorCompatible
from tensorflow.math import l2_normalize as l2_normalize
@overload
def matmul(
a: TensorCompatible,
b: TensorCompatible,
transpose_a: _bool = False,
transpose_b: _bool = False,
adjoint_a: _bool = False,
adjoint_b: _bool = False,
a_is_sparse: _bool = False,
b_is_sparse: _bool = False,
output_type: DTypeLike | None = None,
name: str | None = None,
) -> Tensor: ...
@overload
def matmul(
a: RaggedTensor,
b: RaggedTensor,
transpose_a: _bool = False,
transpose_b: _bool = False,
adjoint_a: _bool = False,
adjoint_b: _bool = False,
a_is_sparse: _bool = False,
b_is_sparse: _bool = False,
output_type: DTypeLike | None = None,
name: str | None = None,
) -> RaggedTensor: ...
def set_diag(
input: TensorCompatible,
diagonal: TensorCompatible,
name: str | None = "set_diag",
k: int = 0,
align: Literal["RIGHT_LEFT", "RIGHT_RIGHT", "LEFT_LEFT", "LEFT_RIGHT"] = "RIGHT_LEFT",
) -> Tensor: ...
def eye(
num_rows: ScalarTensorCompatible,
num_columns: ScalarTensorCompatible | None = None,
batch_shape: Iterable[int] | IntArray | tf.Tensor | None = None,
dtype: DTypeLike = ...,
name: str | None = None,
) -> Tensor: ...
def band_part(input: TensorCompatible, num_lower: Integer, num_upper: Integer, name: str | None = None) -> Tensor: ...
def __getattr__(name: str) -> Incomplete: ...