tensorflow: add tensorflow.autodiff (#11442)

Hoël Bagard
2024-02-29 23:58:47 +09:00
committed by GitHub
parent d52c1f6783
commit 791dc9120a
3 changed files with 67 additions and 57 deletions

View File

@@ -20,6 +20,7 @@ tensorflow.python.feature_column.feature_column_v2.SharedEmbeddingColumnCreator.
 tensorflow.GradientTape.__getattr__
 tensorflow.data.Dataset.__getattr__
 tensorflow.experimental.Optional.__getattr__
+tensorflow.autodiff.GradientTape.__getattr__
 # The Tensor methods below were removed in 2.14, however they are still defined for the
 # internal subclasses that are used at runtime/in practice.

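Context for the new allowlist line: the stub for tensorflow.autodiff.GradientTape (added below) defines a __getattr__ fallback that the runtime class does not have, and this entry tells stubtest to ignore that deliberate mismatch, just like the existing GradientTape/Dataset/Optional entries above it. A minimal sketch of the stub-side pattern, assuming only what appears in this commit:

    # Stub-only fallback: attribute access on members that are not yet
    # typed resolves to Incomplete instead of being a type error.
    # Because the real class has no such method, stubtest reports it,
    # and the allowlist entry above suppresses that report.
    from _typeshed import Incomplete

    class GradientTape:
        def __getattr__(self, name: str) -> Incomplete: ...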
View File

@@ -1,7 +1,7 @@
 from _typeshed import Incomplete, Unused
 from abc import ABC, ABCMeta, abstractmethod
 from builtins import bool as _bool
-from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence
+from collections.abc import Callable, Iterable, Iterator, Sequence
 from contextlib import contextmanager
 from enum import Enum
 from types import TracebackType
@@ -18,18 +18,8 @@ from tensorflow import (
     keras as keras,
     math as math,
 )
-from tensorflow._aliases import (
-    AnyArray,
-    ContainerGradients,
-    ContainerTensors,
-    ContainerTensorsLike,
-    DTypeLike,
-    Gradients,
-    ShapeLike,
-    Slice,
-    TensorCompatible,
-    TensorLike,
-)
+from tensorflow._aliases import AnyArray, DTypeLike, ShapeLike, Slice, TensorCompatible
+from tensorflow.autodiff import GradientTape as GradientTape
 from tensorflow.core.protobuf import struct_pb2
 # Explicit import of DType is covered by the wildcard, but
@@ -302,50 +292,6 @@ class UnconnectedGradients(Enum):
     NONE = "none"
     ZERO = "zero"
 
-class GradientTape:
-    def __init__(self, persistent: _bool = False, watch_accessed_variables: _bool = True) -> None: ...
-    def __enter__(self) -> Self: ...
-    def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ...
-    # Higher kinded types would be nice here and these overloads are a way to simulate some of them.
-    @overload
-    def gradient(
-        self,
-        target: ContainerTensors,
-        sources: TensorLike,
-        output_gradients: list[Tensor] | None = None,
-        unconnected_gradients: UnconnectedGradients = ...,
-    ) -> Gradients: ...
-    @overload
-    def gradient(
-        self,
-        target: ContainerTensors,
-        sources: Sequence[Tensor],
-        output_gradients: list[Tensor] | None = None,
-        unconnected_gradients: UnconnectedGradients = ...,
-    ) -> list[Gradients]: ...
-    @overload
-    def gradient(
-        self,
-        target: ContainerTensors,
-        sources: Mapping[str, Tensor],
-        output_gradients: list[Tensor] | None = None,
-        unconnected_gradients: UnconnectedGradients = ...,
-    ) -> dict[str, Gradients]: ...
-    @overload
-    def gradient(
-        self,
-        target: ContainerTensors,
-        sources: ContainerTensors,
-        output_gradients: list[Tensor] | None = None,
-        unconnected_gradients: UnconnectedGradients = ...,
-    ) -> ContainerGradients: ...
-    @contextmanager
-    def stop_recording(self) -> Generator[None, None, None]: ...
-    def reset(self) -> None: ...
-    def watch(self, tensor: ContainerTensorsLike) -> None: ...
-    def watched_variables(self) -> tuple[Variable, ...]: ...
-    def __getattr__(self, name: str) -> Incomplete: ...
-
 _SpecProto = TypeVar("_SpecProto", bound=Message)
 
 class TypeSpec(ABC, Generic[_SpecProto]):

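The net effect on the top-level stub: GradientTape moves out of __init__.pyi, and the explicit `from tensorflow.autodiff import GradientTape as GradientTape` re-export keeps tf.GradientTape available. This matches the runtime, where both names point at the same class. A minimal usage sketch (the numeric values are illustrative only):

    import tensorflow as tf

    # Both names resolve to the same runtime class; the stubs now mirror
    # that by defining GradientTape in tensorflow.autodiff and
    # re-exporting it from the top-level package.
    assert tf.GradientTape is tf.autodiff.GradientTape

    x = tf.Variable(3.0)
    with tf.autodiff.GradientTape() as tape:
        y = x * x
    print(tape.gradient(y, x))  # tf.Tensor(6.0, ...)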
View File

@@ -0,0 +1,63 @@
+from _typeshed import Incomplete
+from builtins import bool as _bool
+from collections.abc import Generator, Mapping, Sequence
+from contextlib import contextmanager
+from types import TracebackType
+from typing import overload
+from typing_extensions import Self
+
+import tensorflow as tf
+from tensorflow import Tensor, UnconnectedGradients, Variable
+from tensorflow._aliases import ContainerGradients, ContainerTensors, ContainerTensorsLike, Gradients, TensorLike
+
+class ForwardAccumulator:
+    def __init__(self, primals: Tensor, tangents: Tensor) -> None: ...
+    def jvp(
+        self, primals: Tensor, unconnected_gradients: tf.UnconnectedGradients = tf.UnconnectedGradients.NONE  # noqa: Y011
+    ) -> Tensor | None: ...
+    def __enter__(self) -> Self: ...
+    def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ...
+
+class GradientTape:
+    def __init__(self, persistent: _bool = False, watch_accessed_variables: _bool = True) -> None: ...
+    def __enter__(self) -> Self: ...
+    def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ...
+    # Higher kinded types would be nice here and these overloads are a way to simulate some of them.
+    @overload
+    def gradient(
+        self,
+        target: ContainerTensors,
+        sources: TensorLike,
+        output_gradients: list[Tensor] | None = None,
+        unconnected_gradients: UnconnectedGradients = ...,
+    ) -> Gradients: ...
+    @overload
+    def gradient(
+        self,
+        target: ContainerTensors,
+        sources: Sequence[Tensor],
+        output_gradients: list[Tensor] | None = None,
+        unconnected_gradients: UnconnectedGradients = ...,
+    ) -> list[Gradients]: ...
+    @overload
+    def gradient(
+        self,
+        target: ContainerTensors,
+        sources: Mapping[str, Tensor],
+        output_gradients: list[Tensor] | None = None,
+        unconnected_gradients: UnconnectedGradients = ...,
+    ) -> dict[str, Gradients]: ...
+    @overload
+    def gradient(
+        self,
+        target: ContainerTensors,
+        sources: ContainerTensors,
+        output_gradients: list[Tensor] | None = None,
+        unconnected_gradients: UnconnectedGradients = ...,
+    ) -> ContainerGradients: ...
+    @contextmanager
+    def stop_recording(self) -> Generator[None, None, None]: ...
+    def reset(self) -> None: ...
+    def watch(self, tensor: ContainerTensorsLike) -> None: ...
+    def watched_variables(self) -> tuple[Variable, ...]: ...
+    def __getattr__(self, name: str) -> Incomplete: ...
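ForwardAccumulator is TensorFlow's forward-mode autodiff API: primals and tangents are declared when the accumulator is created, and Jacobian-vector products are read off results computed inside the context via jvp(). A minimal sketch of the API these stubs describe (the concrete numbers are illustrative):

    import tensorflow as tf

    x = tf.Variable(2.0)
    v = tf.constant(1.0)  # the tangent (direction) vector

    # Forward-mode autodiff: the accumulator watches `x` and pushes the
    # tangent through every op recorded inside the context.
    with tf.autodiff.ForwardAccumulator(primals=x, tangents=v) as acc:
        y = x * x
    print(acc.jvp(y))  # dy/dx * v = 2 * 2.0 * 1.0 = 4.0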
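The four gradient() overloads encode the runtime rule that the result mirrors the container structure of `sources`: a single tensor-like source yields one gradient, a Sequence yields a list, a Mapping yields a dict, and other containers fall back to ContainerGradients. A short sketch of the correspondence (persistent=True only because gradient() is called more than once; a non-persistent tape allows a single call):

    import tensorflow as tf

    a = tf.Variable(1.0)
    b = tf.Variable(2.0)

    with tf.autodiff.GradientTape(persistent=True) as tape:
        loss = a * b + b * b

    # The result follows the container shape of `sources`:
    g_tensor = tape.gradient(loss, a)               # single source -> Tensor
    g_list = tape.gradient(loss, [a, b])            # sequence      -> list of Tensors
    g_dict = tape.gradient(loss, {"a": a, "b": b})  # mapping       -> dict of Tensors
    del tape  # release resources held by a persistent tape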