mirror of
https://github.com/davidhalter/typeshed.git
synced 2025-12-20 02:41:16 +08:00
tensorflow: add tensorflow.autodiff (#11442)
This commit is contained in:
@@ -20,6 +20,7 @@ tensorflow.python.feature_column.feature_column_v2.SharedEmbeddingColumnCreator.
|
||||
tensorflow.GradientTape.__getattr__
|
||||
tensorflow.data.Dataset.__getattr__
|
||||
tensorflow.experimental.Optional.__getattr__
|
||||
tensorflow.autodiff.GradientTape.__getattr__
|
||||
|
||||
# The Tensor methods below were removed in 2.14, however they are still defined for the
|
||||
# internal subclasses that are used at runtime/in practice.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from _typeshed import Incomplete, Unused
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from builtins import bool as _bool
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence
|
||||
from collections.abc import Callable, Iterable, Iterator, Sequence
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from types import TracebackType
|
||||
@@ -18,18 +18,8 @@ from tensorflow import (
|
||||
keras as keras,
|
||||
math as math,
|
||||
)
|
||||
from tensorflow._aliases import (
|
||||
AnyArray,
|
||||
ContainerGradients,
|
||||
ContainerTensors,
|
||||
ContainerTensorsLike,
|
||||
DTypeLike,
|
||||
Gradients,
|
||||
ShapeLike,
|
||||
Slice,
|
||||
TensorCompatible,
|
||||
TensorLike,
|
||||
)
|
||||
from tensorflow._aliases import AnyArray, DTypeLike, ShapeLike, Slice, TensorCompatible
|
||||
from tensorflow.autodiff import GradientTape as GradientTape
|
||||
from tensorflow.core.protobuf import struct_pb2
|
||||
|
||||
# Explicit import of DType is covered by the wildcard, but
|
||||
@@ -302,50 +292,6 @@ class UnconnectedGradients(Enum):
|
||||
NONE = "none"
|
||||
ZERO = "zero"
|
||||
|
||||
class GradientTape:
|
||||
def __init__(self, persistent: _bool = False, watch_accessed_variables: _bool = True) -> None: ...
|
||||
def __enter__(self) -> Self: ...
|
||||
def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ...
|
||||
# Higher kinded types would be nice here and these overloads are a way to simulate some of them.
|
||||
@overload
|
||||
def gradient(
|
||||
self,
|
||||
target: ContainerTensors,
|
||||
sources: TensorLike,
|
||||
output_gradients: list[Tensor] | None = None,
|
||||
unconnected_gradients: UnconnectedGradients = ...,
|
||||
) -> Gradients: ...
|
||||
@overload
|
||||
def gradient(
|
||||
self,
|
||||
target: ContainerTensors,
|
||||
sources: Sequence[Tensor],
|
||||
output_gradients: list[Tensor] | None = None,
|
||||
unconnected_gradients: UnconnectedGradients = ...,
|
||||
) -> list[Gradients]: ...
|
||||
@overload
|
||||
def gradient(
|
||||
self,
|
||||
target: ContainerTensors,
|
||||
sources: Mapping[str, Tensor],
|
||||
output_gradients: list[Tensor] | None = None,
|
||||
unconnected_gradients: UnconnectedGradients = ...,
|
||||
) -> dict[str, Gradients]: ...
|
||||
@overload
|
||||
def gradient(
|
||||
self,
|
||||
target: ContainerTensors,
|
||||
sources: ContainerTensors,
|
||||
output_gradients: list[Tensor] | None = None,
|
||||
unconnected_gradients: UnconnectedGradients = ...,
|
||||
) -> ContainerGradients: ...
|
||||
@contextmanager
|
||||
def stop_recording(self) -> Generator[None, None, None]: ...
|
||||
def reset(self) -> None: ...
|
||||
def watch(self, tensor: ContainerTensorsLike) -> None: ...
|
||||
def watched_variables(self) -> tuple[Variable, ...]: ...
|
||||
def __getattr__(self, name: str) -> Incomplete: ...
|
||||
|
||||
_SpecProto = TypeVar("_SpecProto", bound=Message)
|
||||
|
||||
class TypeSpec(ABC, Generic[_SpecProto]):
|
||||
|
||||
63
stubs/tensorflow/tensorflow/autodiff.pyi
Normal file
63
stubs/tensorflow/tensorflow/autodiff.pyi
Normal file
@@ -0,0 +1,63 @@
|
||||
from _typeshed import Incomplete
|
||||
from builtins import bool as _bool
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from contextlib import contextmanager
|
||||
from types import TracebackType
|
||||
from typing import overload
|
||||
from typing_extensions import Self
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow import Tensor, UnconnectedGradients, Variable
|
||||
from tensorflow._aliases import ContainerGradients, ContainerTensors, ContainerTensorsLike, Gradients, TensorLike
|
||||
|
||||
class ForwardAccumulator:
|
||||
def __init__(self, primals: Tensor, tangents: Tensor) -> None: ...
|
||||
def jvp(
|
||||
self, primals: Tensor, unconnected_gradients: tf.UnconnectedGradients = tf.UnconnectedGradients.NONE # noqa: Y011
|
||||
) -> Tensor | None: ...
|
||||
def __enter__(self) -> Self: ...
|
||||
def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ...
|
||||
|
||||
class GradientTape:
|
||||
def __init__(self, persistent: _bool = False, watch_accessed_variables: _bool = True) -> None: ...
|
||||
def __enter__(self) -> Self: ...
|
||||
def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ...
|
||||
# Higher kinded types would be nice here and these overloads are a way to simulate some of them.
|
||||
@overload
|
||||
def gradient(
|
||||
self,
|
||||
target: ContainerTensors,
|
||||
sources: TensorLike,
|
||||
output_gradients: list[Tensor] | None = None,
|
||||
unconnected_gradients: UnconnectedGradients = ...,
|
||||
) -> Gradients: ...
|
||||
@overload
|
||||
def gradient(
|
||||
self,
|
||||
target: ContainerTensors,
|
||||
sources: Sequence[Tensor],
|
||||
output_gradients: list[Tensor] | None = None,
|
||||
unconnected_gradients: UnconnectedGradients = ...,
|
||||
) -> list[Gradients]: ...
|
||||
@overload
|
||||
def gradient(
|
||||
self,
|
||||
target: ContainerTensors,
|
||||
sources: Mapping[str, Tensor],
|
||||
output_gradients: list[Tensor] | None = None,
|
||||
unconnected_gradients: UnconnectedGradients = ...,
|
||||
) -> dict[str, Gradients]: ...
|
||||
@overload
|
||||
def gradient(
|
||||
self,
|
||||
target: ContainerTensors,
|
||||
sources: ContainerTensors,
|
||||
output_gradients: list[Tensor] | None = None,
|
||||
unconnected_gradients: UnconnectedGradients = ...,
|
||||
) -> ContainerGradients: ...
|
||||
@contextmanager
|
||||
def stop_recording(self) -> Generator[None, None, None]: ...
|
||||
def reset(self) -> None: ...
|
||||
def watch(self, tensor: ContainerTensorsLike) -> None: ...
|
||||
def watched_variables(self) -> tuple[Variable, ...]: ...
|
||||
def __getattr__(self, name: str) -> Incomplete: ...
|
||||
Reference in New Issue
Block a user