mirror of
https://github.com/davidhalter/typeshed.git
synced 2025-12-08 04:54:47 +08:00
tensorflow: add tf.ones, tf.zeros, tf.zeros_like and tf.ones_like functions (#11368)
This commit is contained in:
@@ -36,6 +36,7 @@ from tensorflow.core.protobuf import struct_pb2
|
||||
# is necessary to avoid a crash in pytype.
|
||||
from tensorflow.dtypes import *
|
||||
from tensorflow.dtypes import DType as DType
|
||||
from tensorflow.experimental.dtensor import Layout
|
||||
from tensorflow.keras import losses as losses
|
||||
|
||||
# Most tf.math functions are exported as tf, but sadly not all are.
|
||||
@@ -438,5 +439,23 @@ def cast(x: TensorCompatible, dtype: DTypeLike, name: str | None = None) -> Tens
|
||||
def cast(x: SparseTensor, dtype: DTypeLike, name: str | None = None) -> SparseTensor: ...
|
||||
@overload
|
||||
def cast(x: RaggedTensor, dtype: DTypeLike, name: str | None = None) -> RaggedTensor: ...
|
||||
def zeros(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ...
|
||||
def ones(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def zeros_like(
|
||||
input: TensorCompatible | IndexedSlices, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def zeros_like(
|
||||
input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
|
||||
) -> RaggedTensor: ...
|
||||
@overload
|
||||
def ones_like(
|
||||
input: TensorCompatible, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def ones_like(
|
||||
input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None
|
||||
) -> RaggedTensor: ...
|
||||
def reshape(tensor: TensorCompatible, shape: ShapeLike | Tensor, name: str | None = None) -> Tensor: ...
|
||||
def __getattr__(name: str) -> Incomplete: ...
|
||||
|
||||
@@ -2,6 +2,8 @@ from _typeshed import Incomplete
|
||||
|
||||
from tensorflow._aliases import IntArray, IntDataSequence
|
||||
|
||||
Layout = Incomplete
|
||||
|
||||
class Mesh:
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user