mirror of
https://github.com/davidhalter/typeshed.git
synced 2025-12-08 04:54:47 +08:00
tensorflow add tensorflow.saved_model (#11439)
Based on: - https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/saved_model/__init__.pyi - https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/types/experimental.pyi
This commit is contained in:
@@ -105,5 +105,11 @@ tensorflow.train.Int64List.*
|
||||
tensorflow.train.ClusterDef.*
|
||||
tensorflow.train.ServerDef.*
|
||||
|
||||
# The python module cannot be accessed directly, so to stubtest it appears that it is not present at runtime.
|
||||
# However it can be accessed by doing:
|
||||
# from tensorflow import python
|
||||
# python.X
|
||||
tensorflow.python.*
|
||||
|
||||
# sigmoid_cross_entropy_with_logits has default values (None), however those values are not valid.
|
||||
tensorflow.nn.sigmoid_cross_entropy_with_logits
|
||||
|
||||
@@ -27,7 +27,8 @@ class KerasSerializable2(Protocol):
|
||||
|
||||
KerasSerializable: TypeAlias = KerasSerializable1 | KerasSerializable2
|
||||
|
||||
Integer: TypeAlias = tf.Tensor | int | IntArray | np.number[Any] # Here tf.Tensor and IntArray are assumed to be 0D.
|
||||
TensorValue: TypeAlias = tf.Tensor # Alias for a 0D Tensor
|
||||
Integer: TypeAlias = TensorValue | int | IntArray | np.number[Any] # Here IntArray are assumed to be 0D.
|
||||
Slice: TypeAlias = int | slice | None
|
||||
FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence]
|
||||
IntDataSequence: TypeAlias = Sequence[int] | Sequence[IntDataSequence]
|
||||
|
||||
11
stubs/tensorflow/tensorflow/python/trackable/resource.pyi
Normal file
11
stubs/tensorflow/tensorflow/python/trackable/resource.pyi
Normal file
@@ -0,0 +1,11 @@
|
||||
from _typeshed import Incomplete
|
||||
|
||||
from tensorflow.python.trackable.base import Trackable
|
||||
|
||||
class _ResourceMetaclass(type): ...
|
||||
|
||||
# Internal type that is commonly used as a base class
|
||||
# it is needed for the public signatures of some APIs.
|
||||
class CapturableResource(Trackable, metaclass=_ResourceMetaclass): ...
|
||||
|
||||
def __getattr__(name: str) -> Incomplete: ...
|
||||
@@ -0,0 +1,3 @@
|
||||
from _typeshed import Incomplete
|
||||
|
||||
AutoTrackable = Incomplete
|
||||
115
stubs/tensorflow/tensorflow/saved_model/__init__.pyi
Normal file
115
stubs/tensorflow/tensorflow/saved_model/__init__.pyi
Normal file
@@ -0,0 +1,115 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any, Generic, Literal, TypeVar
|
||||
from typing_extensions import ParamSpec, TypeAlias
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.training.tracking.autotrackable import AutoTrackable
|
||||
from tensorflow.saved_model.experimental import VariablePolicy
|
||||
from tensorflow.types.experimental import ConcreteFunction, GenericFunction
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R", covariant=True)
|
||||
|
||||
class Asset:
|
||||
@property
|
||||
def asset_path(self) -> tf.Tensor: ...
|
||||
def __init__(self, path: str | Path | tf.Tensor) -> None: ...
|
||||
|
||||
class LoadOptions:
|
||||
allow_partial_checkpoint: bool
|
||||
experimental_io_device: str | None
|
||||
experimental_skip_checkpoint: bool
|
||||
experimental_variable_policy: VariablePolicy | None
|
||||
experimental_load_function_aliases: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
allow_partial_checkpoint: bool = False,
|
||||
experimental_io_device: str | None = None,
|
||||
experimental_skip_checkpoint: bool = False,
|
||||
experimental_variable_policy: (
|
||||
VariablePolicy | Literal["expand_distributed_variables", "save_variable_devices"] | None
|
||||
) = None,
|
||||
experimental_load_function_aliases: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
class SaveOptions:
|
||||
__slots__ = (
|
||||
"namespace_whitelist",
|
||||
"save_debug_info",
|
||||
"function_aliases",
|
||||
"experimental_io_device",
|
||||
"experimental_variable_policy",
|
||||
"experimental_custom_gradients",
|
||||
"experimental_image_format",
|
||||
"experimental_skip_saver",
|
||||
)
|
||||
namespace_whitelist: list[str]
|
||||
save_debug_info: bool
|
||||
function_aliases: dict[str, tf.types.experimental.GenericFunction[..., object]]
|
||||
experimental_io_device: str
|
||||
experimental_variable_policy: VariablePolicy
|
||||
experimental_custom_gradients: bool
|
||||
experimental_image_format: bool
|
||||
experimental_skip_saver: bool
|
||||
def __init__(
|
||||
self,
|
||||
namespace_whitelist: list[str] | None = None,
|
||||
save_debug_info: bool = False,
|
||||
function_aliases: Mapping[str, tf.types.experimental.GenericFunction[..., object]] | None = None,
|
||||
experimental_io_device: str | None = None,
|
||||
experimental_variable_policy: str | VariablePolicy | None = None,
|
||||
experimental_custom_gradients: bool = True,
|
||||
experimental_image_format: bool = False,
|
||||
experimental_skip_saver: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
def contains_saved_model(export_dir: str | Path) -> bool: ...
|
||||
|
||||
class _LoadedAttributes(Generic[_P, _R]):
|
||||
signatures: Mapping[str, ConcreteFunction[_P, _R]]
|
||||
|
||||
class _LoadedModel(AutoTrackable, _LoadedAttributes[_P, _R]):
|
||||
variables: list[tf.Variable]
|
||||
trainable_variables: list[tf.Variable]
|
||||
# TF1 model artifact specific
|
||||
graph: tf.Graph
|
||||
|
||||
def load(
|
||||
export_dir: str, tags: str | Sequence[str] | None = None, options: LoadOptions | None = None
|
||||
) -> _LoadedModel[..., Any]: ...
|
||||
|
||||
_TF_Function: TypeAlias = ConcreteFunction[..., object] | GenericFunction[..., object]
|
||||
|
||||
def save(
|
||||
obj: tf.Module,
|
||||
export_dir: str,
|
||||
signatures: _TF_Function | Mapping[str, _TF_Function] | None = None,
|
||||
options: SaveOptions | None = None,
|
||||
) -> None: ...
|
||||
|
||||
ASSETS_DIRECTORY: str = "assets"
|
||||
ASSETS_KEY: str = "saved_model_assets"
|
||||
CLASSIFY_INPUTS: str = "inputs"
|
||||
CLASSIFY_METHOD_NAME: str = "tensorflow/serving/classify"
|
||||
CLASSIFY_OUTPUT_CLASSES: str = "classes"
|
||||
CLASSIFY_OUTPUT_SCORES: str = "scores"
|
||||
DEBUG_DIRECTORY: str = "debug"
|
||||
DEBUG_INFO_FILENAME_PB: str = "saved_model_debug_info.pb"
|
||||
DEFAULT_SERVING_SIGNATURE_DEF_KEY: str = "serving_default"
|
||||
GPU: str = "gpu"
|
||||
PREDICT_INPUTS: str = "inputs"
|
||||
PREDICT_METHOD_NAME: str = "tensorflow/serving/predict"
|
||||
PREDICT_OUTPUTS: str = "outputs"
|
||||
REGRESS_INPUTS: str = "inputs"
|
||||
REGRESS_METHOD_NAME: str = "tensorflow/serving/regress"
|
||||
REGRESS_OUTPUTS: str = "outputs"
|
||||
SAVED_MODEL_FILENAME_PB: str = "saved_model.pb"
|
||||
SAVED_MODEL_FILENAME_PBTXT: str = "saved_model.pbtxt"
|
||||
SAVED_MODEL_SCHEMA_VERSION: int = 1
|
||||
SERVING: str = "serve"
|
||||
TPU: str = "tpu"
|
||||
TRAINING: str = "train"
|
||||
VARIABLES_DIRECTORY: str = "variables"
|
||||
VARIABLES_FILENAME: str = "variables"
|
||||
40
stubs/tensorflow/tensorflow/saved_model/experimental.pyi
Normal file
40
stubs/tensorflow/tensorflow/saved_model/experimental.pyi
Normal file
@@ -0,0 +1,40 @@
|
||||
from _typeshed import Incomplete
|
||||
from enum import Enum
|
||||
from typing_extensions import Self
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow._aliases import Integer, TensorValue
|
||||
from tensorflow.python.trackable.resource import CapturableResource
|
||||
|
||||
class Fingerprint:
|
||||
saved_model_checksum: TensorValue | None
|
||||
graph_def_program_hash: TensorValue | None = None
|
||||
signature_def_hash: TensorValue | None = None
|
||||
saved_object_graph_hash: TensorValue | None = None
|
||||
checkpoint_hash: TensorValue | None = None
|
||||
version: TensorValue | None = None
|
||||
# In practice it seems like any type is accepted, but that might cause issues later on.
|
||||
def __init__(
|
||||
self,
|
||||
saved_model_checksum: Integer | None = None,
|
||||
graph_def_program_hash: Integer | None = None,
|
||||
signature_def_hash: Integer | None = None,
|
||||
saved_object_graph_hash: Integer | None = None,
|
||||
checkpoint_hash: Integer | None = None,
|
||||
version: Integer | None = None,
|
||||
) -> None: ...
|
||||
@classmethod
|
||||
def from_proto(cls, proto: Incomplete) -> Self: ...
|
||||
def singleprint(self) -> str: ...
|
||||
|
||||
class TrackableResource(CapturableResource):
|
||||
@property
|
||||
def resource_handle(self) -> tf.Tensor: ...
|
||||
def __init__(self, device: str = "") -> None: ...
|
||||
|
||||
class VariablePolicy(Enum):
|
||||
EXPAND_DISTRIBUTED_VARIABLES = "expand_distributed_variables"
|
||||
NONE = None # noqa: Y026
|
||||
SAVE_VARIABLE_DEVICES = "save_variable_devices"
|
||||
|
||||
def read_fingerprint(export_dir: str) -> Fingerprint: ...
|
||||
29
stubs/tensorflow/tensorflow/types/experimental.pyi
Normal file
29
stubs/tensorflow/tensorflow/types/experimental.pyi
Normal file
@@ -0,0 +1,29 @@
|
||||
import abc
|
||||
from _typeshed import Incomplete
|
||||
from typing import Any, Generic, TypeVar, overload
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow._aliases import ContainerGeneric
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R", covariant=True)
|
||||
|
||||
class Callable(Generic[_P, _R], metaclass=abc.ABCMeta):
|
||||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
|
||||
|
||||
class ConcreteFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
|
||||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
|
||||
|
||||
class GenericFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
|
||||
@overload
|
||||
@abc.abstractmethod
|
||||
def get_concrete_function(self, *args: _P.args, **kwargs: _P.kwargs) -> ConcreteFunction[_P, _R]: ...
|
||||
@overload
|
||||
@abc.abstractmethod
|
||||
def get_concrete_function(
|
||||
self, *args: ContainerGeneric[tf.TypeSpec[Any]], **kwargs: ContainerGeneric[tf.TypeSpec[Any]]
|
||||
) -> ConcreteFunction[_P, _R]: ...
|
||||
def experimental_get_compiler_ir(self, *args: Incomplete, **kwargs: Incomplete) -> Incomplete: ...
|
||||
|
||||
def __getattr__(name: str) -> Incomplete: ...
|
||||
Reference in New Issue
Block a user