tensorflow: add tensorflow.autograph (#11443)

This commit is contained in:
Hoël Bagard
2024-02-29 23:57:51 +09:00
committed by GitHub
parent 0ad004a776
commit d52c1f6783
2 changed files with 47 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
from collections.abc import Callable
from typing import Any, TypeVar
from tensorflow.autograph.experimental import Feature
_Type = TypeVar("_Type")
def set_verbosity(level: int, alsologtostdout: bool = False) -> None: ...
def to_code(
entity: Callable[..., Any],
recursive: bool = True,
experimental_optional_features: None | Feature | tuple[Feature, ...] = None,
) -> str: ...
def to_graph(
entity: _Type, recursive: bool = True, experimental_optional_features: None | Feature | tuple[Feature, ...] = None
) -> _Type: ...
def trace(*args: Any) -> None: ...

View File

@@ -0,0 +1,30 @@
from collections.abc import Callable, Iterable
from enum import Enum
from typing import TypeVar, overload
from typing_extensions import ParamSpec
import tensorflow as tf
from tensorflow._aliases import Integer
_Param = ParamSpec("_Param")
_RetType = TypeVar("_RetType")
class Feature(Enum):
ALL = "ALL"
ASSERT_STATEMENTS = "ASSERT_STATEMENTS"
AUTO_CONTROL_DEPS = "AUTO_CONTROL_DEPS"
BUILTIN_FUNCTIONS = "BUILTIN_FUNCTIONS"
EQUALITY_OPERATORS = "EQUALITY_OPERATORS"
LISTS = "LISTS"
NAME_SCOPES = "NAME_SCOPES"
@overload
def do_not_convert(func: Callable[_Param, _RetType]) -> Callable[_Param, _RetType]: ...
@overload
def do_not_convert(func: None = None) -> Callable[[Callable[_Param, _RetType]], Callable[_Param, _RetType]]: ...
def set_loop_options(
parallel_iterations: Integer = ...,
swap_memory: bool = ...,
maximum_iterations: Integer = ...,
shape_invariants: Iterable[tuple[tf.Tensor, tf.TensorShape]] = ...,
) -> None: ...