mirror of
https://github.com/davidhalter/typeshed.git
synced 2026-05-04 12:35:49 +08:00
[tensorflow] Add __slots__ (#15459)
This commit is contained in:
@@ -222,6 +222,7 @@ class Operation:
|
||||
def __getattr__(self, name: str) -> Incomplete: ...
|
||||
|
||||
class TensorShape(metaclass=ABCMeta):
|
||||
__slots__ = ["_dims"]
|
||||
def __init__(self, dims: ShapeLike) -> None: ...
|
||||
@property
|
||||
def rank(self) -> int: ...
|
||||
@@ -308,6 +309,7 @@ class UnconnectedGradients(Enum):
|
||||
_SpecProto = TypeVar("_SpecProto", bound=Message)
|
||||
|
||||
class TypeSpec(ABC, Generic[_SpecProto]):
|
||||
__slots__ = ["_cached_cmp_key"]
|
||||
@property
|
||||
@abstractmethod
|
||||
def value_type(self) -> Any: ...
|
||||
@@ -323,6 +325,7 @@ class TypeSpec(ABC, Generic[_SpecProto]):
|
||||
def most_specific_compatible_type(self, other: Self) -> Self: ...
|
||||
|
||||
class TensorSpec(TypeSpec[struct_pb2.TensorSpecProto]):
|
||||
__slots__: list[str] = []
|
||||
def __init__(self, shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None) -> None: ...
|
||||
@property
|
||||
def value_type(self) -> Tensor: ...
|
||||
@@ -339,6 +342,7 @@ class TensorSpec(TypeSpec[struct_pb2.TensorSpecProto]):
|
||||
def is_compatible_with(self, spec_or_tensor: Self | TensorCompatible) -> _bool: ... # type: ignore[override]
|
||||
|
||||
class SparseTensorSpec(TypeSpec[struct_pb2.TypeSpecProto]):
|
||||
__slots__ = ["_shape", "_dtype"]
|
||||
def __init__(self, shape: ShapeLike | None = None, dtype: DTypeLike = ...) -> None: ...
|
||||
@property
|
||||
def value_type(self) -> SparseTensor: ...
|
||||
@@ -350,6 +354,7 @@ class SparseTensorSpec(TypeSpec[struct_pb2.TypeSpecProto]):
|
||||
def from_value(cls, value: SparseTensor) -> Self: ...
|
||||
|
||||
class RaggedTensorSpec(TypeSpec[struct_pb2.TypeSpecProto]):
|
||||
__slots__ = ["_shape", "_dtype", "_ragged_rank", "_row_splits_dtype", "_flat_values_spec"]
|
||||
def __init__(
|
||||
self,
|
||||
shape: ShapeLike | None = None,
|
||||
|
||||
@@ -10,6 +10,7 @@ from tensorflow.python.framework.dtypes import HandleData
|
||||
class _DTypeMeta(ABCMeta): ...
|
||||
|
||||
class DType(metaclass=_DTypeMeta):
|
||||
__slots__ = ["_handle_data"]
|
||||
def __init__(self, type_enum: int, handle_data: HandleData | None = None) -> None: ...
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
@@ -18,6 +18,13 @@ class Asset:
|
||||
def __init__(self, path: str | Path | tf.Tensor) -> None: ...
|
||||
|
||||
class LoadOptions:
|
||||
__slots__ = (
|
||||
"allow_partial_checkpoint",
|
||||
"experimental_io_device",
|
||||
"experimental_skip_checkpoint",
|
||||
"experimental_variable_policy",
|
||||
"experimental_load_function_aliases",
|
||||
)
|
||||
allow_partial_checkpoint: bool
|
||||
experimental_io_device: str | None
|
||||
experimental_skip_checkpoint: bool
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from _typeshed import Incomplete
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypeVar
|
||||
from typing_extensions import Self
|
||||
@@ -18,10 +19,20 @@ from tensorflow.python.trackable.base import Trackable
|
||||
from tensorflow.python.training.tracking.autotrackable import AutoTrackable
|
||||
|
||||
class CheckpointOptions:
|
||||
__slots__ = (
|
||||
"experimental_io_device",
|
||||
"experimental_enable_async_checkpoint",
|
||||
"experimental_write_callbacks",
|
||||
"enable_async",
|
||||
"experimental_sharding_callback",
|
||||
"experimental_skip_slot_variables",
|
||||
)
|
||||
experimental_io_device: None | str
|
||||
experimental_enable_async_checkpoint: bool
|
||||
experimental_write_callbacks: None | list[Callable[[str], object] | Callable[[], object]]
|
||||
enable_async: bool
|
||||
experimental_sharding_callback: Incomplete # should be ShardingCallback
|
||||
experimental_skip_slot_variables: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user