mirror of
https://github.com/davidhalter/typeshed.git
synced 2025-12-28 22:56:55 +08:00
tensorflow: add tf.linalg module (#11386)
Taken from: https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/linalg.pyi
This commit is contained in:
@@ -38,6 +38,7 @@ from tensorflow.dtypes import *
|
||||
from tensorflow.dtypes import DType as DType
|
||||
from tensorflow.experimental.dtensor import Layout
|
||||
from tensorflow.keras import losses as losses
|
||||
from tensorflow.linalg import eye as eye
|
||||
|
||||
# Most tf.math functions are exported as tf, but sadly not all are.
|
||||
from tensorflow.math import (
|
||||
|
||||
@@ -27,6 +27,7 @@ class KerasSerializable2(Protocol):
|
||||
|
||||
KerasSerializable: TypeAlias = KerasSerializable1 | KerasSerializable2
|
||||
|
||||
Integer: TypeAlias = tf.Tensor | int | IntArray | np.number[Any] # Here tf.Tensor and IntArray are assumed to be 0D.
|
||||
Slice: TypeAlias = int | slice | None
|
||||
FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence]
|
||||
IntDataSequence: TypeAlias = Sequence[int] | Sequence[IntDataSequence]
|
||||
|
||||
52
stubs/tensorflow/tensorflow/linalg.pyi
Normal file
52
stubs/tensorflow/tensorflow/linalg.pyi
Normal file
@@ -0,0 +1,52 @@
|
||||
from _typeshed import Incomplete
|
||||
from builtins import bool as _bool
|
||||
from collections.abc import Iterable
|
||||
from typing import Literal, overload
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow import RaggedTensor, Tensor, norm as norm
|
||||
from tensorflow._aliases import DTypeLike, IntArray, Integer, ScalarTensorCompatible, TensorCompatible
|
||||
from tensorflow.math import l2_normalize as l2_normalize
|
||||
|
||||
@overload
|
||||
def matmul(
|
||||
a: TensorCompatible,
|
||||
b: TensorCompatible,
|
||||
transpose_a: _bool = False,
|
||||
transpose_b: _bool = False,
|
||||
adjoint_a: _bool = False,
|
||||
adjoint_b: _bool = False,
|
||||
a_is_sparse: _bool = False,
|
||||
b_is_sparse: _bool = False,
|
||||
output_type: DTypeLike | None = None,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def matmul(
|
||||
a: RaggedTensor,
|
||||
b: RaggedTensor,
|
||||
transpose_a: _bool = False,
|
||||
transpose_b: _bool = False,
|
||||
adjoint_a: _bool = False,
|
||||
adjoint_b: _bool = False,
|
||||
a_is_sparse: _bool = False,
|
||||
b_is_sparse: _bool = False,
|
||||
output_type: DTypeLike | None = None,
|
||||
name: str | None = None,
|
||||
) -> RaggedTensor: ...
|
||||
def set_diag(
|
||||
input: TensorCompatible,
|
||||
diagonal: TensorCompatible,
|
||||
name: str | None = "set_diag",
|
||||
k: int = 0,
|
||||
align: Literal["RIGHT_LEFT", "RIGHT_RIGHT", "LEFT_LEFT", "LEFT_RIGHT"] = "RIGHT_LEFT",
|
||||
) -> Tensor: ...
|
||||
def eye(
|
||||
num_rows: ScalarTensorCompatible,
|
||||
num_columns: ScalarTensorCompatible | None = None,
|
||||
batch_shape: Iterable[int] | IntArray | tf.Tensor | None = None,
|
||||
dtype: DTypeLike = ...,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
def band_part(input: TensorCompatible, num_lower: Integer, num_upper: Integer, name: str | None = None) -> Tensor: ...
|
||||
def __getattr__(name: str) -> Incomplete: ...
|
||||
Reference in New Issue
Block a user