Use Generator instead of Iterator for 3rd-party context managers (#12481)

This commit is contained in:
Max Muoto
2024-08-12 07:26:18 -05:00
committed by GitHub
parent 0b6f15c2ff
commit 37807d753a
14 changed files with 22 additions and 22 deletions

View File

@@ -2,7 +2,7 @@ import abc
from _typeshed import Incomplete, Unused
from abc import ABC, ABCMeta, abstractmethod
from builtins import bool as _bool
from collections.abc import Callable, Iterable, Iterator, Sequence
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
from contextlib import contextmanager
from enum import Enum
from types import TracebackType
@@ -228,7 +228,7 @@ class Graph:
def add_to_collection(self, name: str, value: object) -> None: ...
def add_to_collections(self, names: Iterable[str] | str, value: object) -> None: ...
@contextmanager
def as_default(self) -> Iterator[Self]: ...
def as_default(self) -> Generator[Self]: ...
def finalize(self) -> None: ...
def get_tensor_by_name(self, name: str) -> Tensor: ...
def get_operation_by_name(self, name: str) -> Operation: ...

View File

@@ -56,7 +56,7 @@ class GradientTape:
unconnected_gradients: UnconnectedGradients = ...,
) -> ContainerGradients: ...
@contextmanager
def stop_recording(self) -> Generator[None, None, None]: ...
def stop_recording(self) -> Generator[None]: ...
def reset(self) -> None: ...
def watch(self, tensor: ContainerTensorsLike) -> None: ...
def watched_variables(self) -> tuple[Variable, ...]: ...

View File

@@ -1,6 +1,6 @@
import abc
from _typeshed import Incomplete
from collections.abc import Callable, Iterator
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Literal
from typing_extensions import Self
@@ -49,7 +49,7 @@ def image(
description: str | None = None,
) -> bool: ...
@contextmanager
def record_if(condition: bool | tf.Tensor | Callable[[], bool]) -> Iterator[None]: ...
def record_if(condition: bool | tf.Tensor | Callable[[], bool]) -> Generator[None]: ...
def scalar(name: str, data: float | tf.Tensor, step: int | tf.Tensor | None = None, description: str | None = None) -> bool: ...
def should_record_summaries() -> bool: ...
def text(name: str, data: str | tf.Tensor, step: int | tf.Tensor | None = None, description: str | None = None) -> bool: ...