[tensorflow] Add __slots__ (#15459)

This commit is contained in:
Semyon Moroz
2026-02-23 19:22:06 +00:00
committed by GitHub
parent 2b36c453e8
commit 33b3568a93
4 changed files with 24 additions and 0 deletions
+5
View File
@@ -222,6 +222,7 @@ class Operation:
def __getattr__(self, name: str) -> Incomplete: ...
class TensorShape(metaclass=ABCMeta):
__slots__ = ["_dims"]
def __init__(self, dims: ShapeLike) -> None: ...
@property
def rank(self) -> int: ...
@@ -308,6 +309,7 @@ class UnconnectedGradients(Enum):
_SpecProto = TypeVar("_SpecProto", bound=Message)
class TypeSpec(ABC, Generic[_SpecProto]):
__slots__ = ["_cached_cmp_key"]
@property
@abstractmethod
def value_type(self) -> Any: ...
@@ -323,6 +325,7 @@ class TypeSpec(ABC, Generic[_SpecProto]):
def most_specific_compatible_type(self, other: Self) -> Self: ...
class TensorSpec(TypeSpec[struct_pb2.TensorSpecProto]):
__slots__: list[str] = []
def __init__(self, shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None) -> None: ...
@property
def value_type(self) -> Tensor: ...
@@ -339,6 +342,7 @@ class TensorSpec(TypeSpec[struct_pb2.TensorSpecProto]):
def is_compatible_with(self, spec_or_tensor: Self | TensorCompatible) -> _bool: ... # type: ignore[override]
class SparseTensorSpec(TypeSpec[struct_pb2.TypeSpecProto]):
__slots__ = ["_shape", "_dtype"]
def __init__(self, shape: ShapeLike | None = None, dtype: DTypeLike = ...) -> None: ...
@property
def value_type(self) -> SparseTensor: ...
@@ -350,6 +354,7 @@ class SparseTensorSpec(TypeSpec[struct_pb2.TypeSpecProto]):
def from_value(cls, value: SparseTensor) -> Self: ...
class RaggedTensorSpec(TypeSpec[struct_pb2.TypeSpecProto]):
__slots__ = ["_shape", "_dtype", "_ragged_rank", "_row_splits_dtype", "_flat_values_spec"]
def __init__(
self,
shape: ShapeLike | None = None,
+1
View File
@@ -10,6 +10,7 @@ from tensorflow.python.framework.dtypes import HandleData
class _DTypeMeta(ABCMeta): ...
class DType(metaclass=_DTypeMeta):
__slots__ = ["_handle_data"]
def __init__(self, type_enum: int, handle_data: HandleData | None = None) -> None: ...
@property
def name(self) -> str: ...
@@ -18,6 +18,13 @@ class Asset:
def __init__(self, path: str | Path | tf.Tensor) -> None: ...
class LoadOptions:
__slots__ = (
"allow_partial_checkpoint",
"experimental_io_device",
"experimental_skip_checkpoint",
"experimental_variable_policy",
"experimental_load_function_aliases",
)
allow_partial_checkpoint: bool
experimental_io_device: str | None
experimental_skip_checkpoint: bool
@@ -1,3 +1,4 @@
from _typeshed import Incomplete
from collections.abc import Callable
from typing import Any, TypeVar
from typing_extensions import Self
@@ -18,10 +19,20 @@ from tensorflow.python.trackable.base import Trackable
from tensorflow.python.training.tracking.autotrackable import AutoTrackable
class CheckpointOptions:
__slots__ = (
"experimental_io_device",
"experimental_enable_async_checkpoint",
"experimental_write_callbacks",
"enable_async",
"experimental_sharding_callback",
"experimental_skip_slot_variables",
)
experimental_io_device: None | str
experimental_enable_async_checkpoint: bool
experimental_write_callbacks: None | list[Callable[[str], object] | Callable[[], object]]
enable_async: bool
experimental_sharding_callback: Incomplete # should be ShardingCallback
experimental_skip_slot_variables: bool
def __init__(
self,