mirror of
https://github.com/davidhalter/typeshed.git
synced 2025-12-08 04:54:47 +08:00
tensorflow Add tensorflow.summary module (#11358)
Partially derived from https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/summary.pyi
This commit is contained in:
@@ -91,6 +91,7 @@ tensorflow.io.SparseFeature.__new__
|
||||
|
||||
# Metaclass inconsistency. The runtime metaclass is defined from c++ extension and is undocumented.
|
||||
tensorflow.io.TFRecordWriter
|
||||
tensorflow.experimental.dtensor.Mesh
|
||||
|
||||
# stubtest does not pass for protobuf generated stubs.
|
||||
tensorflow.train.Example.*
|
||||
|
||||
@@ -29,6 +29,7 @@ KerasSerializable: TypeAlias = KerasSerializable1 | KerasSerializable2
|
||||
|
||||
Slice: TypeAlias = int | slice | None
|
||||
FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence]
|
||||
IntDataSequence: TypeAlias = Sequence[int] | Sequence[IntDataSequence]
|
||||
StrDataSequence: TypeAlias = Sequence[str] | Sequence[StrDataSequence]
|
||||
ScalarTensorCompatible: TypeAlias = tf.Tensor | str | float | np.ndarray[Any, Any] | np.number[Any]
|
||||
|
||||
@@ -52,4 +53,6 @@ ContainerInputSpec: TypeAlias = ContainerGeneric[InputSpec]
|
||||
|
||||
AnyArray: TypeAlias = npt.NDArray[Any]
|
||||
FloatArray: TypeAlias = npt.NDArray[np.float_ | np.float16 | np.float32 | np.float64]
|
||||
IntArray: TypeAlias = npt.NDArray[np.int_ | np.uint8 | np.int32 | np.int64]
|
||||
UIntArray: TypeAlias = npt.NDArray[np.uint | np.uint8 | np.uint16 | np.uint32 | np.uint64]
|
||||
SignedIntArray: TypeAlias = npt.NDArray[np.int_ | np.int8 | np.int16 | np.int32 | np.int64]
|
||||
IntArray: TypeAlias = UIntArray | SignedIntArray
|
||||
|
||||
17
stubs/tensorflow/tensorflow/experimental/dtensor.pyi
Normal file
17
stubs/tensorflow/tensorflow/experimental/dtensor.pyi
Normal file
@@ -0,0 +1,17 @@
|
||||
from _typeshed import Incomplete
|
||||
|
||||
from tensorflow._aliases import IntArray, IntDataSequence
|
||||
|
||||
class Mesh:
|
||||
def __init__(
|
||||
self,
|
||||
dim_names: list[str],
|
||||
global_device_ids: IntArray | IntDataSequence,
|
||||
local_device_ids: list[int],
|
||||
local_devices: list[Incomplete | str],
|
||||
mesh_name: str = "",
|
||||
global_devices: list[Incomplete | str] | None = None,
|
||||
use_xla_spmd: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
def __getattr__(name: str) -> Incomplete: ...
|
||||
60
stubs/tensorflow/tensorflow/summary.pyi
Normal file
60
stubs/tensorflow/tensorflow/summary.pyi
Normal file
@@ -0,0 +1,60 @@
|
||||
import abc
|
||||
from _typeshed import Incomplete
|
||||
from collections.abc import Callable, Iterator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Literal
|
||||
from typing_extensions import Self
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow._aliases import FloatArray, IntArray
|
||||
from tensorflow.experimental.dtensor import Mesh
|
||||
|
||||
class SummaryWriter(metaclass=abc.ABCMeta):
|
||||
def as_default(self, step: int | None = None) -> AbstractContextManager[Self]: ...
|
||||
def close(self) -> None: ...
|
||||
def flush(self) -> None: ...
|
||||
def init(self) -> None: ...
|
||||
def set_as_default(self, step: int | None = None) -> None: ...
|
||||
|
||||
def audio(
|
||||
name: str,
|
||||
data: tf.Tensor,
|
||||
sample_rate: int | tf.Tensor,
|
||||
step: int | tf.Tensor | None = None,
|
||||
max_outputs: int | tf.Tensor | None = 3,
|
||||
encoding: Literal["wav"] | None = None,
|
||||
description: str | None = None,
|
||||
) -> bool: ...
|
||||
def create_file_writer(
|
||||
logdir: str,
|
||||
max_queue: int | None = None,
|
||||
flush_millis: int | None = None,
|
||||
filename_suffix: str | None = None,
|
||||
name: str | None = None,
|
||||
experimental_trackable: bool = False,
|
||||
experimental_mesh: Mesh | None = None,
|
||||
) -> SummaryWriter: ...
|
||||
def create_noop_writer() -> SummaryWriter: ...
|
||||
def flush(writer: SummaryWriter | None = None, name: str | None = None) -> tf.Operation: ...
|
||||
def graph(graph_data: tf.Graph | tf.compat.v1.GraphDef) -> bool: ...
|
||||
def histogram(
|
||||
name: str, data: tf.Tensor, step: int | None = None, buckets: int | None = None, description: str | None = None
|
||||
) -> bool: ...
|
||||
def image(
|
||||
name: str,
|
||||
data: tf.Tensor | FloatArray | IntArray,
|
||||
step: int | tf.Tensor | None = None,
|
||||
max_outputs: int | None = 3,
|
||||
description: str | None = None,
|
||||
) -> bool: ...
|
||||
@contextmanager
|
||||
def record_if(condition: bool | tf.Tensor | Callable[[], bool]) -> Iterator[None]: ...
|
||||
def scalar(name: str, data: float | tf.Tensor, step: int | tf.Tensor | None = None, description: str | None = None) -> bool: ...
|
||||
def should_record_summaries() -> bool: ...
|
||||
def text(name: str, data: str | tf.Tensor, step: int | tf.Tensor | None = None, description: str | None = None) -> bool: ...
|
||||
def trace_export(name: str, step: int | tf.Tensor | None = None, profiler_outdir: str | None = None) -> None: ...
|
||||
def trace_off() -> None: ...
|
||||
def trace_on(graph: bool = True, profiler: bool = False) -> None: ...
|
||||
def write(
|
||||
tag: str, tensor: tf.Tensor, step: int | tf.Tensor | None = None, metadata: Incomplete | None = None, name: str | None = None
|
||||
) -> bool: ...
|
||||
Reference in New Issue
Block a user