diff --git a/stubs/tensorflow/tensorflow/autograph/__init__.pyi b/stubs/tensorflow/tensorflow/autograph/__init__.pyi new file mode 100644 index 000000000..4058de0c3 --- /dev/null +++ b/stubs/tensorflow/tensorflow/autograph/__init__.pyi @@ -0,0 +1,17 @@ +from collections.abc import Callable +from typing import Any, TypeVar + +from tensorflow.autograph.experimental import Feature + +_Type = TypeVar("_Type") + +def set_verbosity(level: int, alsologtostdout: bool = False) -> None: ... +def to_code( + entity: Callable[..., Any], + recursive: bool = True, + experimental_optional_features: None | Feature | tuple[Feature, ...] = None, +) -> str: ... +def to_graph( + entity: _Type, recursive: bool = True, experimental_optional_features: None | Feature | tuple[Feature, ...] = None +) -> _Type: ... +def trace(*args: Any) -> None: ... diff --git a/stubs/tensorflow/tensorflow/autograph/experimental.pyi b/stubs/tensorflow/tensorflow/autograph/experimental.pyi new file mode 100644 index 000000000..738a6802d --- /dev/null +++ b/stubs/tensorflow/tensorflow/autograph/experimental.pyi @@ -0,0 +1,30 @@ +from collections.abc import Callable, Iterable +from enum import Enum +from typing import TypeVar, overload +from typing_extensions import ParamSpec + +import tensorflow as tf +from tensorflow._aliases import Integer + +_Param = ParamSpec("_Param") +_RetType = TypeVar("_RetType") + +class Feature(Enum): + ALL = "ALL" + ASSERT_STATEMENTS = "ASSERT_STATEMENTS" + AUTO_CONTROL_DEPS = "AUTO_CONTROL_DEPS" + BUILTIN_FUNCTIONS = "BUILTIN_FUNCTIONS" + EQUALITY_OPERATORS = "EQUALITY_OPERATORS" + LISTS = "LISTS" + NAME_SCOPES = "NAME_SCOPES" + +@overload +def do_not_convert(func: Callable[_Param, _RetType]) -> Callable[_Param, _RetType]: ... +@overload +def do_not_convert(func: None = None) -> Callable[[Callable[_Param, _RetType]], Callable[_Param, _RetType]]: ... +def set_loop_options( + parallel_iterations: Integer = ..., + swap_memory: bool = ..., + maximum_iterations: Integer = ..., + shape_invariants: Iterable[tuple[tf.Tensor, tf.TensorShape]] = ..., +) -> None: ...