mirror of
https://github.com/davidhalter/typeshed.git
synced 2025-12-07 12:44:28 +08:00
tensorflow: add tf.train.CheckpointOptions and other tf.train members. (#11327)
This commit is contained in:
@@ -85,3 +85,12 @@ tensorflow.io.SparseFeature.__new__
|
||||
|
||||
# Metaclass inconsistency. The runtime metaclass is defined from c++ extension and is undocumented.
|
||||
tensorflow.io.TFRecordWriter
|
||||
|
||||
# stubtest does not pass for protobuf generated stubs.
|
||||
tensorflow.train.Example.*
|
||||
tensorflow.train.BytesList.*
|
||||
tensorflow.train.Feature.*
|
||||
tensorflow.train.FloatList.*
|
||||
tensorflow.train.Int64List.*
|
||||
tensorflow.train.ClusterDef.*
|
||||
tensorflow.train.ServerDef.*
|
||||
|
||||
74
stubs/tensorflow/tensorflow/train/__init__.pyi
Normal file
74
stubs/tensorflow/tensorflow/train/__init__.pyi
Normal file
@@ -0,0 +1,74 @@
|
||||
from _typeshed import Incomplete
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypeVar
|
||||
from typing_extensions import Self
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.core.example.example_pb2 import Example as Example
|
||||
from tensorflow.core.example.feature_pb2 import (
|
||||
BytesList as BytesList,
|
||||
Feature as Feature,
|
||||
Features as Features,
|
||||
FloatList as FloatList,
|
||||
Int64List as Int64List,
|
||||
)
|
||||
from tensorflow.core.protobuf.cluster_pb2 import ClusterDef as ClusterDef
|
||||
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef as ServerDef
|
||||
from tensorflow.python.trackable.base import Trackable
|
||||
|
||||
class CheckpointOptions:
|
||||
experimental_io_device: None | str
|
||||
experimental_enable_async_checkpoint: bool
|
||||
# Uncomment when the stubs' TF version is updated to 2.15
|
||||
# experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]]
|
||||
enable_async: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experimental_io_device: None | str = None,
|
||||
experimental_enable_async_checkpoint: bool = False,
|
||||
# Uncomment when the stubs' TF version is updated to 2.15
|
||||
# experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] = None,
|
||||
enable_async: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
_T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str])
|
||||
|
||||
class ClusterSpec:
|
||||
def __init__(self, cluster: dict[str, _T] | ClusterDef | ClusterSpec) -> None: ...
|
||||
def as_dict(self) -> dict[str, list[str] | tuple[str] | dict[int, str]]: ...
|
||||
def num_tasks(self, job_name: str) -> int: ...
|
||||
|
||||
class _CheckpointLoadStatus:
|
||||
def assert_consumed(self) -> Self: ...
|
||||
def assert_existing_objects_matched(self) -> Self: ...
|
||||
def assert_nontrivial_match(self) -> Self: ...
|
||||
def expect_partial(self) -> Self: ...
|
||||
|
||||
class Checkpoint:
|
||||
def __init__(self, root: Trackable | None = None, **kwargs: Trackable) -> None: ...
|
||||
def read(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ...
|
||||
def restore(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ...
|
||||
def save(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ...
|
||||
# def sync(self) -> None: ... # Uncomment when the stubs' TF version is updated to 2.15
|
||||
def write(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ...
|
||||
|
||||
class CheckpointManager:
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint: Checkpoint,
|
||||
directory: str,
|
||||
max_to_keep: int,
|
||||
keep_checkpoint_every_n_hours: int | None = None,
|
||||
checkpoint_name: str = "ckpt",
|
||||
step_counter: tf.Variable | None = None,
|
||||
checkpoint_interval: int | None = None,
|
||||
init_fn: Callable[[], object] | None = None,
|
||||
) -> None: ...
|
||||
def _sweep(self) -> None: ...
|
||||
|
||||
def latest_checkpoint(checkpoint_dir: str, latest_filename: str | None = None) -> str: ...
|
||||
def load_variable(ckpt_dir_or_file: str, name: str) -> np.ndarray[Any, Any]: ...
|
||||
def list_variables(ckpt_dir_or_file: str) -> list[tuple[str, list[int]]]: ...
|
||||
def __getattr__(name: str) -> Incomplete: ...
|
||||
13
stubs/tensorflow/tensorflow/train/experimental.pyi
Normal file
13
stubs/tensorflow/tensorflow/train/experimental.pyi
Normal file
@@ -0,0 +1,13 @@
|
||||
import abc
|
||||
from _typeshed import Incomplete
|
||||
from typing_extensions import Self
|
||||
|
||||
from tensorflow.python.trackable.base import Trackable
|
||||
|
||||
class PythonState(Trackable, metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def serialize(self) -> str: ...
|
||||
@abc.abstractmethod
|
||||
def deserialize(self, string_value: str) -> Self: ...
|
||||
|
||||
def __getattr__(name: str) -> Incomplete: ...
|
||||
Reference in New Issue
Block a user