Mirror of https://github.com/davidhalter/typeshed.git, synced 2025-12-06 12:14:27 +08:00

Bump tensorflow to ~=2.18.0 (#12916)

* Tensorflow proto script update
* Manual stubtest changes
* Use Path for arg type
@@ -53,17 +53,20 @@ TSL_IMPORT_PATTERN = re.compile(r"(\[|\s)tsl\.")
 XLA_IMPORT_PATTERN = re.compile(r"(\[|\s)xla\.")
 
 
+def move_tree(source: Path, destination: Path) -> None:
+    """Move directory and merge if destination already exists.
+
+    Can't use shutil.move because it can't merge existing directories."""
+    print(f"Moving '{source}' to '{destination}'")
+    shutil.copytree(source, destination, dirs_exist_ok=True)
+    shutil.rmtree(source)
+
+
 def post_creation() -> None:
     """Move third-party and fix imports"""
-    # Can't use shutil.move because it can't merge existing directories.
     print()
-    print(f"Moving '{STUBS_FOLDER}/tsl' to '{STUBS_FOLDER}/tensorflow/tsl'")
-    shutil.copytree(f"{STUBS_FOLDER}/tsl", f"{STUBS_FOLDER}/tensorflow/tsl", dirs_exist_ok=True)
-    shutil.rmtree(f"{STUBS_FOLDER}/tsl")
-
-    print(f"Moving '{STUBS_FOLDER}/xla' to '{STUBS_FOLDER}/tensorflow/compiler/xla'")
-    shutil.copytree(f"{STUBS_FOLDER}/xla", f"{STUBS_FOLDER}/tensorflow/compiler/xla", dirs_exist_ok=True)
-    shutil.rmtree(f"{STUBS_FOLDER}/xla")
+    move_tree(STUBS_FOLDER / "tsl", STUBS_FOLDER / "tensorflow" / "tsl")
+    move_tree(STUBS_FOLDER / "xla", STUBS_FOLDER / "tensorflow" / "compiler" / "xla")
 
     for path in STUBS_FOLDER.rglob("*_pb2.pyi"):
         print(f"Fixing imports in '{path}'")
@@ -106,6 +109,7 @@ def main() -> None:
         proto_globs=(
             f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/xla/*.proto",
             f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/xla/service/*.proto",
+            f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/xla/tsl/protobuf/*.proto",
             f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/example/*.proto",
             f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/framework/*.proto",
             f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/protobuf/*.proto",
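
As an aside, here is a minimal sketch (hypothetical paths) of why the new move_tree helper copies and then deletes instead of calling shutil.move: shutil.move cannot merge into a destination directory that already exists, while shutil.copytree(..., dirs_exist_ok=True) can.

    from pathlib import Path
    import shutil

    src, dst = Path("stubs/tsl"), Path("stubs/tensorflow/tsl")
    dst.mkdir(parents=True, exist_ok=True)         # destination already exists
    # shutil.move(src, dst) would move src *inside* dst rather than merge it
    shutil.copytree(src, dst, dirs_exist_ok=True)  # merge src's contents into dst
    shutil.rmtree(src)                             # then remove the source tree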

@@ -1,4 +1,4 @@
-# Using an exact number in the specifier for scripts/sync_proto/google_protobuf.py
+# Using an exact number in the specifier for scripts/sync_protobuf/google_protobuf.py
 version = "~=5.28.3"
 upstream_repository = "https://github.com/protocolbuffers/protobuf"
 extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 27.2 on [protobuf v28.3](https://github.com/protocolbuffers/protobuf/releases/tag/v28.3) (python `protobuf==5.28.3`)."

@@ -1,5 +1,5 @@
 # Whenever you update version here, PACKAGE_VERSION should be updated
-# in scripts/sync_proto/s2clientprotocol.py and vice-versa.
+# in scripts/sync_protobuf/s2clientprotocol.py and vice-versa.
 version = "5.*"
 upstream_repository = "https://github.com/Blizzard/s2client-proto"
 requires = ["types-protobuf"]

@@ -1,11 +1,11 @@
-# Using an exact number in the specifier for scripts/sync_proto/tensorflow.py
-version = "~=2.17.1"
+# Using an exact number in the specifier for scripts/sync_protobuf/tensorflow.py
+version = "~=2.18.0"
 upstream_repository = "https://github.com/tensorflow/tensorflow"
 # requires a version of numpy with a `py.typed` file
 # see https://github.com/python/typeshed/issues/12551
 # on why we need the upper bound for numpy
 requires = ["numpy>=1.20,<2.1.0", "types-protobuf", "types-requests"]
-extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 27.2 on `tensorflow==2.17.1`."
+extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 27.2 on `tensorflow==2.18.0`."
 partial_stub = true
 
 [tool.stubtest]

@@ -219,7 +219,7 @@ global___Kind = Kind
 @typing.final
 class HloInstructionProto(google.protobuf.message.Message):
     """Serialization of HloInstruction.
-    Next ID: 87
+    Next ID: 90
     """

     DESCRIPTOR: google.protobuf.descriptor.Descriptor

@@ -316,6 +316,9 @@ class HloInstructionProto(google.protobuf.message.Message):
     LARGEST_FIELD_NUMBER: builtins.int
     STATISTICS_VIZ_FIELD_NUMBER: builtins.int
     DOT_SPARSITY_FIELD_NUMBER: builtins.int
+    COLLECTIVE_DEVICE_LIST_FIELD_NUMBER: builtins.int
+    ORIGINAL_VALUE_FIELD_NUMBER: builtins.int
+    IS_COMPOSITE_FIELD_NUMBER: builtins.int
     name: builtins.str
     opcode: builtins.str
     parameter_number: builtins.int

@@ -433,6 +436,8 @@ class HloInstructionProto(google.protobuf.message.Message):
     """Represents the K value for top-k."""
     largest: builtins.bool
     """Represents the largest flag for top-k."""
+    is_composite: builtins.bool
+    """Specifies if a call instruction is a composite."""
     @property
     def shape(self) -> tensorflow.compiler.xla.xla_data_pb2.ShapeProto: ...
     @property

@@ -497,7 +502,9 @@ class HloInstructionProto(google.protobuf.message.Message):
     def sharding(self) -> tensorflow.compiler.xla.xla_data_pb2.OpSharding: ...
     @property
     def replica_groups(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[tensorflow.compiler.xla.xla_data_pb2.ReplicaGroup]:
-        """Cross replica op fields."""
+        """Deprecated, but keeping for backward compatibility.
+        Use collective_device_list. Cross replica op fields.
+        """

     @property
     def scatter_dimension_numbers(self) -> tensorflow.compiler.xla.xla_data_pb2.ScatterDimensionNumbers: ...

@@ -549,6 +556,14 @@ class HloInstructionProto(google.protobuf.message.Message):
     def dot_sparsity(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor]:
         """Sparsity descriptor for dot operation."""

+    @property
+    def collective_device_list(self) -> tensorflow.compiler.xla.xla_data_pb2.CollectiveDeviceListProto:
+        """Represents the list of devices that participate in a collective operation."""
+
+    @property
+    def original_value(self) -> tensorflow.compiler.xla.xla_data_pb2.OriginalValueProto:
+        """For HLO value tracking."""
+
     def __init__(
         self,
         *,

@@ -623,9 +638,12 @@ class HloInstructionProto(google.protobuf.message.Message):
         largest: builtins.bool | None = ...,
         statistics_viz: tensorflow.compiler.xla.xla_data_pb2.StatisticsViz | None = ...,
         dot_sparsity: collections.abc.Iterable[tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor] | None = ...,
+        collective_device_list: tensorflow.compiler.xla.xla_data_pb2.CollectiveDeviceListProto | None = ...,
+        original_value: tensorflow.compiler.xla.xla_data_pb2.OriginalValueProto | None = ...,
+        is_composite: builtins.bool | None = ...,
     ) -> None: ...
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "k", b"k", "largest", b"largest", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "output_operand_aliasing", b"output_operand_aliasing", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ...
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "collective_device_list", b"collective_device_list", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "original_value", b"original_value", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "collective_device_list", b"collective_device_list", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_composite", b"is_composite", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "k", b"k", "largest", b"largest", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "original_value", b"original_value", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "output_operand_aliasing", b"output_operand_aliasing", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ...
     def WhichOneof(self, oneof_group: typing.Literal["optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index"]) -> typing.Literal["cross_program_prefetch_index"] | None: ...
 
 global___HloInstructionProto = HloInstructionProto

@@ -980,6 +998,7 @@ class HloModuleProto(google.protobuf.message.Message):
         FUSION: HloModuleProto._ProfileType.ValueType # 2
         LAYOUT: HloModuleProto._ProfileType.ValueType # 3
         DOT: HloModuleProto._ProfileType.ValueType # 4
+        FLAGNET: HloModuleProto._ProfileType.ValueType # 5
 
     class ProfileType(_ProfileType, metaclass=_ProfileTypeEnumTypeWrapper):
         """The type of optimization profile in use for module-level optimizations."""

@@ -989,6 +1008,7 @@ class HloModuleProto(google.protobuf.message.Message):
     FUSION: HloModuleProto.ProfileType.ValueType # 2
     LAYOUT: HloModuleProto.ProfileType.ValueType # 3
     DOT: HloModuleProto.ProfileType.ValueType # 4
+    FLAGNET: HloModuleProto.ProfileType.ValueType # 5
 
     @typing.final
     class ProfileInfo(google.protobuf.message.Message):

@@ -1604,35 +1624,3 @@ class HloPassMetadata(google.protobuf.message.Message):
     def ClearField(self, field_name: typing.Literal["custom_metadata", b"custom_metadata", "dump_filenames", b"dump_filenames", "end_timestamp_usec", b"end_timestamp_usec", "module_changed", b"module_changed", "module_group_module_ids", b"module_group_module_ids", "module_id", b"module_id", "pass_id", b"pass_id", "pass_name", b"pass_name", "pipeline_name", b"pipeline_name", "start_timestamp_usec", b"start_timestamp_usec"]) -> None: ...
 
 global___HloPassMetadata = HloPassMetadata
-
-@typing.final
-class XlaRuntimeExecutableProto(google.protobuf.message.Message):
-    """Encodes the underlying Xla runtime executable compiled from the XLA module."""
-
-    DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-    HLO_MODULE_PROTO_FIELD_NUMBER: builtins.int
-    OBJ_FILE_FIELD_NUMBER: builtins.int
-    MLIR_MODULE_FIELD_NUMBER: builtins.int
-    obj_file: builtins.bytes
-    """TODO(b/232263665)): Serialized executable has to know what APIs it has to
-    be linked with, including the version. For example Gpu executable must be
-    linked with a runtime layer that abstracts over CUDA.
-
-    Serialized object file compiled from the XLA module.
-    """
-    mlir_module: builtins.str
-    """Serialized MLIR module corresponding to compiled object file."""
-    @property
-    def hlo_module_proto(self) -> global___HloModuleProto: ...
-    def __init__(
-        self,
-        *,
-        hlo_module_proto: global___HloModuleProto | None = ...,
-        obj_file: builtins.bytes | None = ...,
-        mlir_module: builtins.str | None = ...,
-    ) -> None: ...
-    def HasField(self, field_name: typing.Literal["hlo_module_proto", b"hlo_module_proto"]) -> builtins.bool: ...
-    def ClearField(self, field_name: typing.Literal["hlo_module_proto", b"hlo_module_proto", "mlir_module", b"mlir_module", "obj_file", b"obj_file"]) -> None: ...
-
-global___XlaRuntimeExecutableProto = XlaRuntimeExecutableProto

@@ -61,6 +61,57 @@ class PassMetrics(google.protobuf.message.Message):
 
 global___PassMetrics = PassMetrics
 
+@typing.final
+class JobInfo(google.protobuf.message.Message):
+    """Defines compilation job information."""
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    NAME_FIELD_NUMBER: builtins.int
+    CELL_FIELD_NUMBER: builtins.int
+    USER_FIELD_NUMBER: builtins.int
+    UID_FIELD_NUMBER: builtins.int
+    TASK_ID_FIELD_NUMBER: builtins.int
+    TASK_UID_FIELD_NUMBER: builtins.int
+    name: builtins.str
+    """Name of the job running compilation."""
+    cell: builtins.str
+    """Cell in which the job is running."""
+    user: builtins.str
+    """User running the job."""
+    uid: builtins.int
+    """Unique id when combined with user and cell field."""
+    task_id: builtins.int
+    """Task index, which will not change across job restarts."""
+    task_uid: builtins.int
+    """Task unique id, which may change across job restarts."""
+    def __init__(
+        self,
+        *,
+        name: builtins.str | None = ...,
+        cell: builtins.str | None = ...,
+        user: builtins.str | None = ...,
+        uid: builtins.int | None = ...,
+        task_id: builtins.int | None = ...,
+        task_uid: builtins.int | None = ...,
+    ) -> None: ...
+    def HasField(self, field_name: typing.Literal["_cell", b"_cell", "_name", b"_name", "_task_id", b"_task_id", "_task_uid", b"_task_uid", "_uid", b"_uid", "_user", b"_user", "cell", b"cell", "name", b"name", "task_id", b"task_id", "task_uid", b"task_uid", "uid", b"uid", "user", b"user"]) -> builtins.bool: ...
+    def ClearField(self, field_name: typing.Literal["_cell", b"_cell", "_name", b"_name", "_task_id", b"_task_id", "_task_uid", b"_task_uid", "_uid", b"_uid", "_user", b"_user", "cell", b"cell", "name", b"name", "task_id", b"task_id", "task_uid", b"task_uid", "uid", b"uid", "user", b"user"]) -> None: ...
+    @typing.overload
+    def WhichOneof(self, oneof_group: typing.Literal["_cell", b"_cell"]) -> typing.Literal["cell"] | None: ...
+    @typing.overload
+    def WhichOneof(self, oneof_group: typing.Literal["_name", b"_name"]) -> typing.Literal["name"] | None: ...
+    @typing.overload
+    def WhichOneof(self, oneof_group: typing.Literal["_task_id", b"_task_id"]) -> typing.Literal["task_id"] | None: ...
+    @typing.overload
+    def WhichOneof(self, oneof_group: typing.Literal["_task_uid", b"_task_uid"]) -> typing.Literal["task_uid"] | None: ...
+    @typing.overload
+    def WhichOneof(self, oneof_group: typing.Literal["_uid", b"_uid"]) -> typing.Literal["uid"] | None: ...
+    @typing.overload
+    def WhichOneof(self, oneof_group: typing.Literal["_user", b"_user"]) -> typing.Literal["user"] | None: ...
+
+global___JobInfo = JobInfo
+
 @typing.final
 class CompilationLogEntry(google.protobuf.message.Message):
     """Defines XLA compilation metrics."""

@@ -94,11 +145,13 @@ class CompilationLogEntry(google.protobuf.message.Message):
     TASK_INDEX_FIELD_NUMBER: builtins.int
     PASS_METRICS_FIELD_NUMBER: builtins.int
     MODULE_IDS_FIELD_NUMBER: builtins.int
+    JOB_INFO_FIELD_NUMBER: builtins.int
     stage: global___CompilationLogEntry.CompilationStage.ValueType
     """Compilation stage recorded by this log entry."""
     task_index: builtins.int
     """Task index from which this log entry was recorded or
-    -1 if the task index could not be fetched.
+    -1 if the task index could not be fetched. In the case task_index is not
+    equal to -1, it is guaranteed to match the task_id in job_info.
     """
     @property
     def timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp:

@@ -116,6 +169,10 @@ class CompilationLogEntry(google.protobuf.message.Message):
     def module_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
         """IDs of modules on which the compilation stage was run."""

+    @property
+    def job_info(self) -> global___JobInfo:
+        """Job information."""
+
     def __init__(
         self,
         *,

@@ -125,8 +182,9 @@ class CompilationLogEntry(google.protobuf.message.Message):
         task_index: builtins.int | None = ...,
         pass_metrics: collections.abc.Iterable[global___PassMetrics] | None = ...,
         module_ids: collections.abc.Iterable[builtins.int] | None = ...,
+        job_info: global___JobInfo | None = ...,
     ) -> None: ...
-    def HasField(self, field_name: typing.Literal["duration", b"duration", "timestamp", b"timestamp"]) -> builtins.bool: ...
-    def ClearField(self, field_name: typing.Literal["duration", b"duration", "module_ids", b"module_ids", "pass_metrics", b"pass_metrics", "stage", b"stage", "task_index", b"task_index", "timestamp", b"timestamp"]) -> None: ...
+    def HasField(self, field_name: typing.Literal["duration", b"duration", "job_info", b"job_info", "timestamp", b"timestamp"]) -> builtins.bool: ...
+    def ClearField(self, field_name: typing.Literal["duration", b"duration", "job_info", b"job_info", "module_ids", b"module_ids", "pass_metrics", b"pass_metrics", "stage", b"stage", "task_index", b"task_index", "timestamp", b"timestamp"]) -> None: ...
 
 global___CompilationLogEntry = CompilationLogEntry
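
A quick illustration of the `_name`-style entries in the JobInfo stub above: every proto3 `optional` field is wrapped in a synthetic single-member oneof, which is why HasField/WhichOneof accept both the field name and its underscore-prefixed wrapper. A hypothetical sketch (the metrics_pb2 module path is assumed, not taken from this diff):

    from tensorflow.compiler.xla.service.metrics_pb2 import JobInfo  # path assumed

    job = JobInfo(name="trainer", task_id=3)
    job.HasField("name")        # True: presence is tracked for optional fields
    job.WhichOneof("_task_id")  # "task_id": the synthetic oneof is set
    job.WhichOneof("_task_uid") # None: task_uid was never set
    job.ClearField("name")      # clears the field and its presence bit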

@@ -115,7 +115,7 @@ class CompilationResult(google.protobuf.message.Message):
 
     @property
     def status(self) -> tensorflow.tsl.protobuf.status_pb2.StatusProto:
-        """Always set when compilation fails; never set when compilation succeeds."""
+        """Always set even when compilation succeeds."""
 
     @property
     def counters(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.int]:

@@ -860,6 +860,7 @@ class OpMetadata(google.protobuf.message.Message):
     DEDUPLICATED_NAME_FIELD_NUMBER: builtins.int
     PRESERVE_LAYOUT_FIELD_NUMBER: builtins.int
     STACK_FRAME_ID_FIELD_NUMBER: builtins.int
+    SCHEDULING_NAME_FIELD_NUMBER: builtins.int
     op_type: builtins.str
     """The framework op name that generated this XLA op.
 

@@ -901,6 +902,8 @@ class OpMetadata(google.protobuf.message.Message):
     """1-based position of the frame in frames flat array.
     Ids are 1-based to keep 0 value as representation of non-set property.
     """
+    scheduling_name: builtins.str
+    """Instruction name available upon scheduling."""
     @property
     def profile_type(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[global___ProfileType.ValueType]:
         """Deprecated, use [ProfileInfo][profile_type] instead."""

@@ -923,9 +926,10 @@ class OpMetadata(google.protobuf.message.Message):
         deduplicated_name: builtins.str | None = ...,
         preserve_layout: builtins.bool | None = ...,
         stack_frame_id: builtins.int | None = ...,
+        scheduling_name: builtins.str | None = ...,
     ) -> None: ...
     def HasField(self, field_name: typing.Literal["profile_info", b"profile_info"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["deduplicated_name", b"deduplicated_name", "op_name", b"op_name", "op_type", b"op_type", "preserve_layout", b"preserve_layout", "profile_info", b"profile_info", "profile_type", b"profile_type", "size_of_generated_code_in_bytes", b"size_of_generated_code_in_bytes", "size_of_memory_working_set_in_bytes", b"size_of_memory_working_set_in_bytes", "source_file", b"source_file", "source_line", b"source_line", "stack_frame_id", b"stack_frame_id"]) -> None: ...
def ClearField(self, field_name: typing.Literal["deduplicated_name", b"deduplicated_name", "op_name", b"op_name", "op_type", b"op_type", "preserve_layout", b"preserve_layout", "profile_info", b"profile_info", "profile_type", b"profile_type", "scheduling_name", b"scheduling_name", "size_of_generated_code_in_bytes", b"size_of_generated_code_in_bytes", "size_of_memory_working_set_in_bytes", b"size_of_memory_working_set_in_bytes", "source_file", b"source_file", "source_line", b"source_line", "stack_frame_id", b"stack_frame_id"]) -> None: ...
 
 global___OpMetadata = OpMetadata
 
@@ -1379,6 +1383,8 @@ class GatherDimensionNumbers(google.protobuf.message.Message):
     COLLAPSED_SLICE_DIMS_FIELD_NUMBER: builtins.int
     START_INDEX_MAP_FIELD_NUMBER: builtins.int
     INDEX_VECTOR_DIM_FIELD_NUMBER: builtins.int
+    OPERAND_BATCHING_DIMS_FIELD_NUMBER: builtins.int
+    START_INDICES_BATCHING_DIMS_FIELD_NUMBER: builtins.int
     index_vector_dim: builtins.int
     """The dimension in the start_indices input that contains the starting
     indices.

@@ -1409,6 +1415,16 @@ class GatherDimensionNumbers(google.protobuf.message.Message):
     the starting index in the input space.
     """
 
+    @property
+    def operand_batching_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
+        """This is the batch dimensions in the operand."""
+
+    @property
+    def start_indices_batching_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
+        """This is the batch dimensions in the index, and it should be the same size
+        as operand_batching_dims.
+        """
+
     def __init__(
         self,
         *,

@@ -1416,8 +1432,10 @@ class GatherDimensionNumbers(google.protobuf.message.Message):
         collapsed_slice_dims: collections.abc.Iterable[builtins.int] | None = ...,
         start_index_map: collections.abc.Iterable[builtins.int] | None = ...,
         index_vector_dim: builtins.int | None = ...,
+        operand_batching_dims: collections.abc.Iterable[builtins.int] | None = ...,
+        start_indices_batching_dims: collections.abc.Iterable[builtins.int] | None = ...,
     ) -> None: ...
-    def ClearField(self, field_name: typing.Literal["collapsed_slice_dims", b"collapsed_slice_dims", "index_vector_dim", b"index_vector_dim", "offset_dims", b"offset_dims", "start_index_map", b"start_index_map"]) -> None: ...
+    def ClearField(self, field_name: typing.Literal["collapsed_slice_dims", b"collapsed_slice_dims", "index_vector_dim", b"index_vector_dim", "offset_dims", b"offset_dims", "operand_batching_dims", b"operand_batching_dims", "start_index_map", b"start_index_map", "start_indices_batching_dims", b"start_indices_batching_dims"]) -> None: ...
 
 global___GatherDimensionNumbers = GatherDimensionNumbers
 
@@ -1435,6 +1453,8 @@ class ScatterDimensionNumbers(google.protobuf.message.Message):
     INSERTED_WINDOW_DIMS_FIELD_NUMBER: builtins.int
     SCATTER_DIMS_TO_OPERAND_DIMS_FIELD_NUMBER: builtins.int
     INDEX_VECTOR_DIM_FIELD_NUMBER: builtins.int
+    INPUT_BATCHING_DIMS_FIELD_NUMBER: builtins.int
+    SCATTER_INDICES_BATCHING_DIMS_FIELD_NUMBER: builtins.int
     index_vector_dim: builtins.int
     @property
     def update_window_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:

@@ -1446,6 +1466,14 @@ class ScatterDimensionNumbers(google.protobuf.message.Message):
 
     @property
     def scatter_dims_to_operand_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
+    @property
+    def input_batching_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
+        """This is the batch dimension in the input."""
+
+    @property
+    def scatter_indices_batching_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
+        """This is the batch dimension in the index."""
+
     def __init__(
         self,
         *,

@@ -1453,8 +1481,10 @@ class ScatterDimensionNumbers(google.protobuf.message.Message):
         inserted_window_dims: collections.abc.Iterable[builtins.int] | None = ...,
         scatter_dims_to_operand_dims: collections.abc.Iterable[builtins.int] | None = ...,
         index_vector_dim: builtins.int | None = ...,
+        input_batching_dims: collections.abc.Iterable[builtins.int] | None = ...,
+        scatter_indices_batching_dims: collections.abc.Iterable[builtins.int] | None = ...,
     ) -> None: ...
-    def ClearField(self, field_name: typing.Literal["index_vector_dim", b"index_vector_dim", "inserted_window_dims", b"inserted_window_dims", "scatter_dims_to_operand_dims", b"scatter_dims_to_operand_dims", "update_window_dims", b"update_window_dims"]) -> None: ...
+    def ClearField(self, field_name: typing.Literal["index_vector_dim", b"index_vector_dim", "input_batching_dims", b"input_batching_dims", "inserted_window_dims", b"inserted_window_dims", "scatter_dims_to_operand_dims", b"scatter_dims_to_operand_dims", "scatter_indices_batching_dims", b"scatter_indices_batching_dims", "update_window_dims", b"update_window_dims"]) -> None: ...
 
 global___ScatterDimensionNumbers = ScatterDimensionNumbers
 
@@ -1981,6 +2011,80 @@ class ReplicaGroup(google.protobuf.message.Message):
 
 global___ReplicaGroup = ReplicaGroup
 
+@typing.final
+class IotaReplicaGroupListProto(google.protobuf.message.Message):
+    """Represents a list of replica groups (a list of list of devices) with
+    reshaping and transposing an iota array (iota tile assignment). Can be used
+    to represent certain common patterns of device lists in a compact, scalable
+    format.
+    """
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    NUM_REPLICA_GROUPS_FIELD_NUMBER: builtins.int
+    NUM_DEVICES_PER_GROUP_FIELD_NUMBER: builtins.int
+    IOTA_RESHAPE_DIMS_FIELD_NUMBER: builtins.int
+    IOTA_TRANSPOSE_PERM_FIELD_NUMBER: builtins.int
+    num_replica_groups: builtins.int
+    """Number of replica groups."""
+    num_devices_per_group: builtins.int
+    """Number of devices per group."""
+    @property
+    def iota_reshape_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
+        """The dimensions used to reshape the 1D iota array of device IDs."""
+
+    @property
+    def iota_transpose_perm(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
+        """The dimension permutations to transposed the iota array reshaped to
+        iota_reshape_dims. This must have the same size as iota_reshape_dims.
+        """
+
+    def __init__(
+        self,
+        *,
+        num_replica_groups: builtins.int | None = ...,
+        num_devices_per_group: builtins.int | None = ...,
+        iota_reshape_dims: collections.abc.Iterable[builtins.int] | None = ...,
+        iota_transpose_perm: collections.abc.Iterable[builtins.int] | None = ...,
+    ) -> None: ...
+    def ClearField(self, field_name: typing.Literal["iota_reshape_dims", b"iota_reshape_dims", "iota_transpose_perm", b"iota_transpose_perm", "num_devices_per_group", b"num_devices_per_group", "num_replica_groups", b"num_replica_groups"]) -> None: ...
+
+global___IotaReplicaGroupListProto = IotaReplicaGroupListProto
+
+@typing.final
+class CollectiveDeviceListProto(google.protobuf.message.Message):
+    """Represents a series of devices participating in a collective operation (e.g.,
+    all-reduce and all-to-all). While this directly translates to a list of
+    replica groups, it may be used to represent these lists in a compact form.
+    """
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    REPLICA_GROUPS_FIELD_NUMBER: builtins.int
+    IOTA_REPLICA_GROUP_LIST_FIELD_NUMBER: builtins.int
+    @property
+    def replica_groups(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ReplicaGroup]:
+        """ReplicaGroupV1: List of replica groups. Legacy way of representing device
+        lists.
+        """
+
+    @property
+    def iota_replica_group_list(self) -> global___IotaReplicaGroupListProto:
+        """ReplicaGroupV2: Represents a list of replica groups with reshaping and
+        transposing an iota array.
+        """
+
+    def __init__(
+        self,
+        *,
+        replica_groups: collections.abc.Iterable[global___ReplicaGroup] | None = ...,
+        iota_replica_group_list: global___IotaReplicaGroupListProto | None = ...,
+    ) -> None: ...
+    def HasField(self, field_name: typing.Literal["iota_replica_group_list", b"iota_replica_group_list"]) -> builtins.bool: ...
+    def ClearField(self, field_name: typing.Literal["iota_replica_group_list", b"iota_replica_group_list", "replica_groups", b"replica_groups"]) -> None: ...
+
+global___CollectiveDeviceListProto = CollectiveDeviceListProto
+
 @typing.final
 class SourceTarget(google.protobuf.message.Message):
     """Describes the source target pair in the collective permute op."""
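
To make the iota encoding above concrete, here is an illustrative sketch (numpy used purely for demonstration; the values are hypothetical): the replica groups [[0, 2], [1, 3]] can be encoded as num_replica_groups=2, num_devices_per_group=2, iota_reshape_dims=[2, 2], iota_transpose_perm=[1, 0].

    import numpy as np

    # 1D iota array of num_replica_groups * num_devices_per_group device IDs
    devices = np.arange(2 * 2)                      # [0, 1, 2, 3]
    # reshape to iota_reshape_dims, then transpose by iota_transpose_perm
    groups = devices.reshape(2, 2).transpose(1, 0)  # [[0, 2], [1, 3]]
    # each row is one replica group of num_devices_per_group devices
    print(groups.tolist())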

@@ -2236,3 +2340,42 @@ class OutputOperandAliasing(google.protobuf.message.Message):
     def ClearField(self, field_name: typing.Literal["operand_index", b"operand_index", "operand_shape_index", b"operand_shape_index", "output_shape_index", b"output_shape_index"]) -> None: ...
 
 global___OutputOperandAliasing = OutputOperandAliasing
+
+@typing.final
+class OriginalArrayProto(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    LEAF_SHAPE_INDEX_FIELD_NUMBER: builtins.int
+    INSTRUCTION_NAME_FIELD_NUMBER: builtins.int
+    SHAPE_INDEX_FIELD_NUMBER: builtins.int
+    instruction_name: builtins.str
+    @property
+    def leaf_shape_index(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
+    @property
+    def shape_index(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
+    def __init__(
+        self,
+        *,
+        leaf_shape_index: collections.abc.Iterable[builtins.int] | None = ...,
+        instruction_name: builtins.str | None = ...,
+        shape_index: collections.abc.Iterable[builtins.int] | None = ...,
+    ) -> None: ...
+    def ClearField(self, field_name: typing.Literal["instruction_name", b"instruction_name", "leaf_shape_index", b"leaf_shape_index", "shape_index", b"shape_index"]) -> None: ...
+
+global___OriginalArrayProto = OriginalArrayProto
+
+@typing.final
+class OriginalValueProto(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    LEAVES_FIELD_NUMBER: builtins.int
+    @property
+    def leaves(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___OriginalArrayProto]: ...
+    def __init__(
+        self,
+        *,
+        leaves: collections.abc.Iterable[global___OriginalArrayProto] | None = ...,
+    ) -> None: ...
+    def ClearField(self, field_name: typing.Literal["leaves", b"leaves"]) -> None: ...
+
+global___OriginalValueProto = OriginalValueProto
File diff suppressed because one or more lines are too long
@@ -278,6 +278,25 @@ class OptimizationOptions(google.protobuf.message.Message):
 
 global___OptimizationOptions = OptimizationOptions
 
+@typing.final
+class ServiceOptions(google.protobuf.message.Message):
+    """next: 2"""
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    PINNED_FIELD_NUMBER: builtins.int
+    pinned: builtins.bool
+    def __init__(
+        self,
+        *,
+        pinned: builtins.bool | None = ...,
+    ) -> None: ...
+    def HasField(self, field_name: typing.Literal["optional_pinned", b"optional_pinned", "pinned", b"pinned"]) -> builtins.bool: ...
+    def ClearField(self, field_name: typing.Literal["optional_pinned", b"optional_pinned", "pinned", b"pinned"]) -> None: ...
+    def WhichOneof(self, oneof_group: typing.Literal["optional_pinned", b"optional_pinned"]) -> typing.Literal["pinned"] | None: ...
+
+global___ServiceOptions = ServiceOptions
+
 @typing.final
 class ThreadingOptions(google.protobuf.message.Message):
     """next: 3"""

@@ -308,7 +327,7 @@ class Options(google.protobuf.message.Message):
     """Message stored with Dataset objects to control how datasets are processed and
     optimized.
 
-    next: 12
+    next: 13
     """
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor

@@ -319,6 +338,7 @@ class Options(google.protobuf.message.Message):
     AUTOTUNE_OPTIONS_FIELD_NUMBER: builtins.int
     DISTRIBUTE_OPTIONS_FIELD_NUMBER: builtins.int
     OPTIMIZATION_OPTIONS_FIELD_NUMBER: builtins.int
+    SERVICE_OPTIONS_FIELD_NUMBER: builtins.int
     SLACK_FIELD_NUMBER: builtins.int
     THREADING_OPTIONS_FIELD_NUMBER: builtins.int
     EXTERNAL_STATE_POLICY_FIELD_NUMBER: builtins.int

@@ -346,6 +366,10 @@ class Options(google.protobuf.message.Message):
     def optimization_options(self) -> global___OptimizationOptions:
         """The optimization options associated with the dataset."""
 
+    @property
+    def service_options(self) -> global___ServiceOptions:
+        """The tf.data service options associated with the dataset."""
+
     @property
     def threading_options(self) -> global___ThreadingOptions:
         """The threading options associated with the dataset."""

@@ -359,14 +383,15 @@ class Options(google.protobuf.message.Message):
         autotune_options: global___AutotuneOptions | None = ...,
         distribute_options: global___DistributeOptions | None = ...,
         optimization_options: global___OptimizationOptions | None = ...,
+        service_options: global___ServiceOptions | None = ...,
         slack: builtins.bool | None = ...,
         threading_options: global___ThreadingOptions | None = ...,
         external_state_policy: global___ExternalStatePolicy.ValueType | None = ...,
         symbolic_checkpoint: builtins.bool | None = ...,
         warm_start: builtins.bool | None = ...,
     ) -> None: ...
def HasField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "dataset_name", b"dataset_name", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_dataset_name", b"optional_dataset_name", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "optional_warm_start", b"optional_warm_start", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options", "warm_start", b"warm_start"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "dataset_name", b"dataset_name", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "framework_type", b"framework_type", "optimization_options", b"optimization_options", "optional_dataset_name", b"optional_dataset_name", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "optional_warm_start", b"optional_warm_start", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options", "warm_start", b"warm_start"]) -> None: ...
def HasField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "dataset_name", b"dataset_name", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_dataset_name", b"optional_dataset_name", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "optional_warm_start", b"optional_warm_start", "service_options", b"service_options", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options", "warm_start", b"warm_start"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "dataset_name", b"dataset_name", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "framework_type", b"framework_type", "optimization_options", b"optimization_options", "optional_dataset_name", b"optional_dataset_name", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "optional_warm_start", b"optional_warm_start", "service_options", b"service_options", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options", "warm_start", b"warm_start"]) -> None: ...
     @typing.overload
     def WhichOneof(self, oneof_group: typing.Literal["optional_dataset_name", b"optional_dataset_name"]) -> typing.Literal["dataset_name"] | None: ...
     @typing.overload
 
@@ -4,7 +4,7 @@ isort:skip_file
 """
 
 import google.protobuf.descriptor
-from tensorflow.tsl.protobuf.bfc_memory_map_pb2 import (
+from tensorflow.compiler.xla.tsl.protobuf.bfc_memory_map_pb2 import (
     BinSummary as BinSummary,
     MemAllocatorStats as MemAllocatorStats,
     MemChunk as MemChunk,
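
The import change above reflects tsl moving under third_party/xla in tensorflow 2.18; the re-exporting alias module keeps downstream imports working. A hedged sketch of importing from the new canonical location (field choice assumed from bfc_memory_map.proto):

    # After the move, the generated module lives under compiler/xla:
    from tensorflow.compiler.xla.tsl.protobuf.bfc_memory_map_pb2 import MemChunk

    chunk = MemChunk(address=0, size=1024)  # fields assumed from bfc_memory_map.proto
    print(chunk.size)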

@@ -773,6 +773,7 @@ class ConfigProto(google.protobuf.message.Message):
         XLA_FUSION_AUTOTUNER_THRESH_FIELD_NUMBER: builtins.int
         USE_TFRT_FIELD_NUMBER: builtins.int
         ENABLE_MULTI_HOST_FIELD_NUMBER: builtins.int
+        TFRT_USE_IFRT_FIELD_NUMBER: builtins.int
         BACKEND_SERVER_PORT_FIELD_NUMBER: builtins.int
         TARGET_TPU_FIELD_NUMBER: builtins.int
         TARGET_GPU_FIELD_NUMBER: builtins.int

@@ -886,6 +887,10 @@ class ConfigProto(google.protobuf.message.Message):
         """Whether runtime execution uses TFRT."""
         enable_multi_host: builtins.bool
         """If true, use Pathways with TFRT API for multi host support."""
+        tfrt_use_ifrt: builtins.bool
+        """If true, use ifrt as the backend for TFRT. This is only used when
+        `use_tfrt` is true.
+        """
         backend_server_port: builtins.int
         """Port for the Pathways server. Ignored if enable_multi_host=false."""
         target_tpu: builtins.bool

@@ -962,6 +967,7 @@ class ConfigProto(google.protobuf.message.Message):
             xla_fusion_autotuner_thresh: builtins.int | None = ...,
             use_tfrt: builtins.bool | None = ...,
             enable_multi_host: builtins.bool | None = ...,
+            tfrt_use_ifrt: builtins.bool | None = ...,
             backend_server_port: builtins.int | None = ...,
             target_tpu: builtins.bool | None = ...,
             target_gpu: builtins.bool | None = ...,

@@ -973,7 +979,7 @@ class ConfigProto(google.protobuf.message.Message):
             disable_eager_executor_streaming_enqueue: builtins.bool | None = ...,
         ) -> None: ...
         def HasField(self, field_name: typing.Literal["coordination_config", b"coordination_config", "session_metadata", b"session_metadata"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["backend_server_port", b"backend_server_port", "collective_deterministic_sequential_execution", b"collective_deterministic_sequential_execution", "collective_group_leader", b"collective_group_leader", "collective_nccl", b"collective_nccl", "coordination_config", b"coordination_config", "disable_eager_executor_streaming_enqueue", b"disable_eager_executor_streaming_enqueue", "disable_functional_ops_lowering", b"disable_functional_ops_lowering", "disable_optimize_for_static_graph", b"disable_optimize_for_static_graph", "disable_output_partition_graphs", b"disable_output_partition_graphs", "disable_thread_spinning", b"disable_thread_spinning", "enable_mlir_bridge", b"enable_mlir_bridge", "enable_mlir_graph_optimization", b"enable_mlir_graph_optimization", "enable_multi_host", b"enable_multi_host", "executor_type", b"executor_type", "mlir_bridge_rollout", b"mlir_bridge_rollout", "optimize_for_static_graph", b"optimize_for_static_graph", "recv_buf_max_chunk", b"recv_buf_max_chunk", "session_metadata", b"session_metadata", "share_cluster_devices_in_session", b"share_cluster_devices_in_session", "share_session_state_in_clusterspec_propagation", b"share_session_state_in_clusterspec_propagation", "stream_merge_threshold", b"stream_merge_threshold", "target_gpu", b"target_gpu", "target_tpu", b"target_tpu", "use_numa_affinity", b"use_numa_affinity", "use_tfrt", b"use_tfrt", "xla_fusion_autotuner_thresh", b"xla_fusion_autotuner_thresh", "xla_prefer_single_graph_cluster", b"xla_prefer_single_graph_cluster"]) -> None: ...
def ClearField(self, field_name: typing.Literal["backend_server_port", b"backend_server_port", "collective_deterministic_sequential_execution", b"collective_deterministic_sequential_execution", "collective_group_leader", b"collective_group_leader", "collective_nccl", b"collective_nccl", "coordination_config", b"coordination_config", "disable_eager_executor_streaming_enqueue", b"disable_eager_executor_streaming_enqueue", "disable_functional_ops_lowering", b"disable_functional_ops_lowering", "disable_optimize_for_static_graph", b"disable_optimize_for_static_graph", "disable_output_partition_graphs", b"disable_output_partition_graphs", "disable_thread_spinning", b"disable_thread_spinning", "enable_mlir_bridge", b"enable_mlir_bridge", "enable_mlir_graph_optimization", b"enable_mlir_graph_optimization", "enable_multi_host", b"enable_multi_host", "executor_type", b"executor_type", "mlir_bridge_rollout", b"mlir_bridge_rollout", "optimize_for_static_graph", b"optimize_for_static_graph", "recv_buf_max_chunk", b"recv_buf_max_chunk", "session_metadata", b"session_metadata", "share_cluster_devices_in_session", b"share_cluster_devices_in_session", "share_session_state_in_clusterspec_propagation", b"share_session_state_in_clusterspec_propagation", "stream_merge_threshold", b"stream_merge_threshold", "target_gpu", b"target_gpu", "target_tpu", b"target_tpu", "tfrt_use_ifrt", b"tfrt_use_ifrt", "use_numa_affinity", b"use_numa_affinity", "use_tfrt", b"use_tfrt", "xla_fusion_autotuner_thresh", b"xla_fusion_autotuner_thresh", "xla_prefer_single_graph_cluster", b"xla_prefer_single_graph_cluster"]) -> None: ...
 
     DEVICE_COUNT_FIELD_NUMBER: builtins.int
     INTRA_OP_PARALLELISM_THREADS_FIELD_NUMBER: builtins.int

@@ -983,6 +989,7 @@ class ConfigProto(google.protobuf.message.Message):
     PLACEMENT_PERIOD_FIELD_NUMBER: builtins.int
     DEVICE_FILTERS_FIELD_NUMBER: builtins.int
     GPU_OPTIONS_FIELD_NUMBER: builtins.int
+    PLUGGABLE_DEVICE_OPTIONS_FIELD_NUMBER: builtins.int
     ALLOW_SOFT_PLACEMENT_FIELD_NUMBER: builtins.int
     LOG_DEVICE_PLACEMENT_FIELD_NUMBER: builtins.int
     GRAPH_OPTIONS_FIELD_NUMBER: builtins.int

@@ -1104,6 +1111,10 @@ class ConfigProto(google.protobuf.message.Message):
     def gpu_options(self) -> global___GPUOptions:
         """Options that apply to all GPUs."""
 
+    @property
+    def pluggable_device_options(self) -> global___GPUOptions:
+        """Options that apply to pluggable devices."""
+
     @property
     def graph_options(self) -> global___GraphOptions:
         """Options that apply to all graphs."""

@@ -1129,6 +1140,7 @@ class ConfigProto(google.protobuf.message.Message):
         placement_period: builtins.int | None = ...,
         device_filters: collections.abc.Iterable[builtins.str] | None = ...,
         gpu_options: global___GPUOptions | None = ...,
+        pluggable_device_options: global___GPUOptions | None = ...,
         allow_soft_placement: builtins.bool | None = ...,
         log_device_placement: builtins.bool | None = ...,
         graph_options: global___GraphOptions | None = ...,

@@ -1139,8 +1151,8 @@ class ConfigProto(google.protobuf.message.Message):
         share_cluster_devices_in_session: builtins.bool | None = ...,
         experimental: global___ConfigProto.Experimental | None = ...,
     ) -> None: ...
def HasField(self, field_name: typing.Literal["cluster_def", b"cluster_def", "experimental", b"experimental", "gpu_options", b"gpu_options", "graph_options", b"graph_options", "rpc_options", b"rpc_options"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["allow_soft_placement", b"allow_soft_placement", "cluster_def", b"cluster_def", "device_count", b"device_count", "device_filters", b"device_filters", "experimental", b"experimental", "gpu_options", b"gpu_options", "graph_options", b"graph_options", "inter_op_parallelism_threads", b"inter_op_parallelism_threads", "intra_op_parallelism_threads", b"intra_op_parallelism_threads", "isolate_session_state", b"isolate_session_state", "log_device_placement", b"log_device_placement", "operation_timeout_in_ms", b"operation_timeout_in_ms", "placement_period", b"placement_period", "rpc_options", b"rpc_options", "session_inter_op_thread_pool", b"session_inter_op_thread_pool", "share_cluster_devices_in_session", b"share_cluster_devices_in_session", "use_per_session_threads", b"use_per_session_threads"]) -> None: ...
def HasField(self, field_name: typing.Literal["cluster_def", b"cluster_def", "experimental", b"experimental", "gpu_options", b"gpu_options", "graph_options", b"graph_options", "pluggable_device_options", b"pluggable_device_options", "rpc_options", b"rpc_options"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["allow_soft_placement", b"allow_soft_placement", "cluster_def", b"cluster_def", "device_count", b"device_count", "device_filters", b"device_filters", "experimental", b"experimental", "gpu_options", b"gpu_options", "graph_options", b"graph_options", "inter_op_parallelism_threads", b"inter_op_parallelism_threads", "intra_op_parallelism_threads", b"intra_op_parallelism_threads", "isolate_session_state", b"isolate_session_state", "log_device_placement", b"log_device_placement", "operation_timeout_in_ms", b"operation_timeout_in_ms", "placement_period", b"placement_period", "pluggable_device_options", b"pluggable_device_options", "rpc_options", b"rpc_options", "session_inter_op_thread_pool", b"session_inter_op_thread_pool", "share_cluster_devices_in_session", b"share_cluster_devices_in_session", "use_per_session_threads", b"use_per_session_threads"]) -> None: ...
 
 global___ConfigProto = ConfigProto
 

@@ -527,61 +527,6 @@ global___TensorInfo = TensorInfo
 class SignatureDef(google.protobuf.message.Message):
     """SignatureDef defines the signature of a computation supported by a TensorFlow
     graph.
-
-    For example, a model with two loss computations, sharing a single input,
-    might have the following signature_def map, in a MetaGraphDef message.
-
-    Note that across the two SignatureDefs "loss_A" and "loss_B", the input key,
-    output key, and method_name are identical, and will be used by system(s) that
-    implement or rely upon this particular loss method. The output tensor names
-    differ, demonstrating how different outputs can exist for the same method.
-
-    signature_def {
-      key: "loss_A"
-      value {
-        inputs {
-          key: "input"
-          value {
-            name: "input:0"
-            dtype: DT_STRING
-            tensor_shape: ...
-          }
-        }
-        outputs {
-          key: "loss_output"
-          value {
-            name: "loss_output_A:0"
-            dtype: DT_FLOAT
-            tensor_shape: ...
-          }
-        }
-        method_name: "some/package/compute_loss"
-      }
-      ...
-    }
-    signature_def {
-      key: "loss_B"
-      value {
-        inputs {
-          key: "input"
-          value {
-            name: "input:0"
-            dtype: DT_STRING
-            tensor_shape: ...
-          }
-        }
-        outputs {
-          key: "loss_output"
-          value {
-            name: "loss_output_B:0"
-            dtype: DT_FLOAT
-            tensor_shape: ...
-          }
-        }
-        method_name: "some/package/compute_loss"
-      }
-      ...
-    }
     """
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -645,14 +590,13 @@ class SignatureDef(google.protobuf.message.Message):
|
||||
METHOD_NAME_FIELD_NUMBER: builtins.int
|
||||
DEFAULTS_FIELD_NUMBER: builtins.int
|
||||
method_name: builtins.str
|
||||
"""Extensible method_name information enabling third-party users to mark a
|
||||
SignatureDef as supporting a particular method. This enables producers and
|
||||
consumers of SignatureDefs, e.g. a model definition library and a serving
|
||||
library to have a clear hand-off regarding the semantics of a computation.
|
||||
"""Deprecated: TensorFlow 2 always sets this to a fixed value;
|
||||
open-source TF Serving stopped checking by default since release 2.4.
|
||||
|
||||
Note that multiple SignatureDefs in a single MetaGraphDef may have the same
|
||||
method_name. This is commonly used to support multi-headed computation,
|
||||
where a single graph computation may return multiple results.
|
||||
In TensorFlow 1, the method_name enabled users to mark a SignatureDef as
|
||||
supporting a particular method. Multiple SignatureDefs in a single
|
||||
MetaGraphDef could have the same method_name (e.g., to support multi-headed
|
||||
computation).
|
||||
"""
|
||||
@property
|
||||
def inputs(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___TensorInfo]:
|
||||
|
||||
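For context, a minimal Python sketch (not part of this diff) of walking the signature_def map described above, assuming a SavedModel on disk; the "my_model" path is hypothetical:

    from tensorflow.core.protobuf import saved_model_pb2

    saved_model = saved_model_pb2.SavedModel()
    with open("my_model/saved_model.pb", "rb") as f:  # hypothetical path
        saved_model.ParseFromString(f.read())
    for key, sig in saved_model.meta_graphs[0].signature_def.items():
        # inputs/outputs map tensor keys to TensorInfo; method_name is a plain str.
        print(key, list(sig.inputs), list(sig.outputs), sig.method_name)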
@@ -117,7 +117,7 @@ global___DispatcherConfig = DispatcherConfig
@typing.final
class WorkerConfig(google.protobuf.message.Message):
    """Configuration for a tf.data service WorkerServer.
    Next id: 13
    Next id: 14
    """

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

@@ -130,6 +130,7 @@ class WorkerConfig(google.protobuf.message.Message):
    HEARTBEAT_INTERVAL_MS_FIELD_NUMBER: builtins.int
    DISPATCHER_TIMEOUT_MS_FIELD_NUMBER: builtins.int
    DATA_TRANSFER_PROTOCOL_FIELD_NUMBER: builtins.int
    DATA_TRANSFER_PORT_FIELD_NUMBER: builtins.int
    DATA_TRANSFER_ADDRESS_FIELD_NUMBER: builtins.int
    CROSS_TRAINER_CACHE_SIZE_BYTES_FIELD_NUMBER: builtins.int
    SNAPSHOT_MAX_CHUNK_SIZE_BYTES_FIELD_NUMBER: builtins.int
@@ -157,11 +158,21 @@ class WorkerConfig(google.protobuf.message.Message):
    runtime.
    """
    data_transfer_protocol: builtins.str
    """The protocol for the worker to use when transferring data to clients."""
    """If set, the name of an alternative data transfer protocol for which the
    worker starts an additional server ("data transfer server"); the trainer
    can then get data from this server. If not set, no such server is started,
    and the trainer can only get data from the regular worker server over
    `protocol`.
    """
    data_transfer_port: builtins.int
    """If `data_transfer_protocol` is set, the port to which the data transfer
    server binds. If set to `0`, the server binds to any available port.
    """
    data_transfer_address: builtins.str
    """The data transfer address of the worker server. The substring "%port%", if
    specified, will be replaced with the worker's bound port. This is useful
    when the port is set to `0`.
    """If `data_transfer_protocol` is set, the address of the data transfer
    server. The substring "%dts_port%" can be used to represent -- and is
    replaced with -- the bound port of the data transfer server; this is useful
    when `data_transfer_port` is set to `0`.
    """
    cross_trainer_cache_size_bytes: builtins.int
    """Maximum size of the cross-trainer cache in bytes. If enabled, make sure
@@ -195,11 +206,12 @@ class WorkerConfig(google.protobuf.message.Message):
        heartbeat_interval_ms: builtins.int | None = ...,
        dispatcher_timeout_ms: builtins.int | None = ...,
        data_transfer_protocol: builtins.str | None = ...,
        data_transfer_port: builtins.int | None = ...,
        data_transfer_address: builtins.str | None = ...,
        cross_trainer_cache_size_bytes: builtins.int | None = ...,
        snapshot_max_chunk_size_bytes: builtins.int | None = ...,
        shutdown_quiet_period_ms: builtins.int | None = ...,
    ) -> None: ...
    def ClearField(self, field_name: typing.Literal["cross_trainer_cache_size_bytes", b"cross_trainer_cache_size_bytes", "data_transfer_address", b"data_transfer_address", "data_transfer_protocol", b"data_transfer_protocol", "dispatcher_address", b"dispatcher_address", "dispatcher_timeout_ms", b"dispatcher_timeout_ms", "heartbeat_interval_ms", b"heartbeat_interval_ms", "port", b"port", "protocol", b"protocol", "shutdown_quiet_period_ms", b"shutdown_quiet_period_ms", "snapshot_max_chunk_size_bytes", b"snapshot_max_chunk_size_bytes", "worker_address", b"worker_address", "worker_tags", b"worker_tags"]) -> None: ...
    def ClearField(self, field_name: typing.Literal["cross_trainer_cache_size_bytes", b"cross_trainer_cache_size_bytes", "data_transfer_address", b"data_transfer_address", "data_transfer_port", b"data_transfer_port", "data_transfer_protocol", b"data_transfer_protocol", "dispatcher_address", b"dispatcher_address", "dispatcher_timeout_ms", b"dispatcher_timeout_ms", "heartbeat_interval_ms", b"heartbeat_interval_ms", "port", b"port", "protocol", b"protocol", "shutdown_quiet_period_ms", b"shutdown_quiet_period_ms", "snapshot_max_chunk_size_bytes", b"snapshot_max_chunk_size_bytes", "worker_address", b"worker_address", "worker_tags", b"worker_tags"]) -> None: ...

global___WorkerConfig = WorkerConfig
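A hedged sketch of the new `data_transfer_port` field alongside its siblings; it assumes the generated module lives at tensorflow.core.protobuf.service_config_pb2, and the values shown are illustrative:

    from tensorflow.core.protobuf.service_config_pb2 import WorkerConfig

    config = WorkerConfig(
        protocol="grpc",
        data_transfer_protocol="grpc",                 # starts the extra data transfer server
        data_transfer_port=0,                          # 0 means bind to any available port
        data_transfer_address="localhost:%dts_port%",  # "%dts_port%" is replaced with the bound port
    )
    config.ClearField("data_transfer_port")  # the updated stub now accepts this field name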
@@ -4,7 +4,7 @@ isort:skip_file
Protocol messages for describing the results of benchmarks and unit tests."""

import google.protobuf.descriptor
from tensorflow.tsl.protobuf.test_log_pb2 import (
from tensorflow.compiler.xla.tsl.protobuf.test_log_pb2 import (
    AvailableDeviceInfo as AvailableDeviceInfo,
    BenchmarkEntries as BenchmarkEntries,
    BenchmarkEntry as BenchmarkEntry,
@@ -121,7 +121,9 @@ class Dataset(ABC, Generic[_T1]):
        self,
        map_func: Callable[..., _T2],
        num_parallel_calls: int | None = None,
        deterministic: None | bool = None,
        deterministic: bool | None = None,
        synchronous: bool | None = None,
        use_unbounded_threadpool: bool = False,
        name: str | None = None,
    ) -> Dataset[_T2]: ...
    def options(self) -> Options: ...
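Illustrative use of the widened Dataset.map signature; only arguments present in the stub above are passed, and the transformation itself is arbitrary:

    import tensorflow as tf

    ds = tf.data.Dataset.range(8).map(
        lambda x: x * 2,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False,             # annotation reordered to `bool | None`
        synchronous=None,                # new in TF 2.18
        use_unbounded_threadpool=False,  # new in TF 2.18
    )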
@@ -42,7 +42,6 @@ class CallbackList(Callback):
        model: Model[Any, Any] | None = None,
        **params: Any,
    ) -> None: ...
    def append(self, callback: Callback) -> None: ...
    def set_params(self, params: dict[str, Any]) -> None: ...
    def set_model(self, model: Model[Any, Any]) -> None: ...
    def on_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
@@ -31,6 +31,7 @@ class InputSpec:
        axes: dict[int, int | None] | None = None,
        allow_last_axis_squeeze: bool = False,
        name: str | None = None,
        optional: bool = False,
    ) -> None: ...
    def get_config(self) -> dict[str, Any]: ...
    @classmethod
@@ -384,6 +385,7 @@ class MultiHeadAttention(Layer[Any, tf.Tensor]):
        activity_regularizer: _Regularizer | None = None,
        kernel_constraint: _Constraint | None = None,
        bias_constraint: _Constraint | None = None,
        seed: int | None = None,
        *,
        # **kwargs passed to Layer
        trainable: bool = True,
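A minimal sketch of the new `seed` argument to MultiHeadAttention; the layer hyperparameters here are arbitrary:

    import tensorflow as tf

    # seed is new in this stub; it makes the attention-dropout RNG reproducible.
    mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=32, dropout=0.1, seed=7)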
@@ -34,6 +34,7 @@ class BinaryCrossentropy(Loss):
        axis: int = -1,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "binary_crossentropy",
        dtype: Incomplete | None = None,
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
@@ -48,6 +49,7 @@ class BinaryFocalCrossentropy(Loss):
        axis: int = -1,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "binary_focal_crossentropy",
        dtype: Incomplete | None = None,
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
@@ -59,59 +61,100 @@ class CategoricalCrossentropy(Loss):
        axis: int = -1,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "categorical_crossentropy",
        dtype: Incomplete | None = None,
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class CategoricalHinge(Loss):
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "categorical_hinge") -> None: ...
    def __init__(
        self,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "categorical_hinge",
        dtype: Incomplete | None = None,
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class CosineSimilarity(Loss):
    def __init__(
        self, axis: int = -1, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "cosine_similarity"
        self,
        axis: int = -1,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "cosine_similarity",
        dtype: Incomplete | None = None,
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class Hinge(Loss):
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "hinge") -> None: ...
    def __init__(
        self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "hinge", dtype: Incomplete | None = None
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class Huber(Loss):
    def __init__(
        self, delta: float = 1.0, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "huber_loss"
        self,
        delta: float = 1.0,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "huber_loss",
        dtype: Incomplete | None = None,
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class KLDivergence(Loss):
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "kl_divergence") -> None: ...
    def __init__(
        self,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "kl_divergence",
        dtype: Incomplete | None = None,
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class LogCosh(Loss):
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "log_cosh") -> None: ...
    def __init__(
        self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "log_cosh", dtype: Incomplete | None = None
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class MeanAbsoluteError(Loss):
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_absolute_error") -> None: ...
    def __init__(
        self,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "mean_absolute_error",
        dtype: Incomplete | None = None,
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class MeanAbsolutePercentageError(Loss):
    def __init__(
        self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_absolute_percentage_error"
        self,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "mean_absolute_percentage_error",
        dtype: Incomplete | None = None,
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class MeanSquaredError(Loss):
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_squared_error") -> None: ...
    def __init__(
        self,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "mean_squared_error",
        dtype: Incomplete | None = None,
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class MeanSquaredLogarithmicError(Loss):
    def __init__(
        self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_squared_logarithmic_error"
        self,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "mean_squared_logarithmic_error",
        dtype: Incomplete | None = None,
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class Poisson(Loss):
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "poisson") -> None: ...
    def __init__(
        self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "poisson", dtype: Incomplete | None = None
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class SparseCategoricalCrossentropy(Loss):
@@ -121,11 +164,17 @@ class SparseCategoricalCrossentropy(Loss):
        ignore_class: int | None = None,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str = "sparse_categorical_crossentropy",
        dtype: Incomplete | None = None,
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class SquaredHinge(Loss):
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "squared_hinge") -> None: ...
    def __init__(
        self,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "squared_hinge",
        dtype: Incomplete | None = None,
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class Reduction:
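Every loss class above gains an optional dtype argument; a minimal sketch with one of them, using illustrative tensors:

    import tensorflow as tf

    mse = tf.keras.losses.MeanSquaredError(
        reduction="sum_over_batch_size",
        name="mean_squared_error",
        dtype="float32",  # new: compute dtype for the loss
    )
    value = mse(tf.constant([1.0, 2.0]), tf.constant([1.5, 1.5]))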
@@ -66,6 +66,7 @@ class Model(Layer[_InputT, _OutputT]):
        y: TensorCompatible | None = None,
        y_pred: TensorCompatible | None = None,
        sample_weight: Incomplete | None = None,
        training: bool = True,
    ) -> tf.Tensor | None: ...
    def compute_metrics(
        self, x: TensorCompatible, y: TensorCompatible, y_pred: TensorCompatible, sample_weight: Incomplete | None = None
@@ -137,7 +138,7 @@ class Model(Layer[_InputT, _OutputT]):
    @property
    def non_trainable_weights(self) -> list[Variable]: ...
    def get_weights(self): ...
    def save(self, filepath: str | Path, overwrite: bool = True) -> None: ...
    def save(self, filepath: str | Path, overwrite: bool = True, zipped: bool | None = None) -> None: ...
    def save_weights(self, filepath: str | Path, overwrite: bool = True) -> None: ...
    # kwargs are from keras.saving.saving_api.load_weights
    def load_weights(self, filepath: str | Path, skip_mismatch: bool = False, *, by_name: bool = False) -> None: ...
@@ -161,6 +162,6 @@ class Model(Layer[_InputT, _OutputT]):
    def get_layer(self, name: str | None = None, index: int | None = None) -> Layer[Incomplete, Incomplete]: ...
    def get_compile_config(self) -> dict[str, Any]: ...
    def compile_from_config(self, config: dict[str, Any]) -> Self: ...
    def export(self, filepath: str | Path, format: str = "tf_saved_model") -> None: ...
    def export(self, filepath: str | Path, format: str = "tf_saved_model", verbose: bool = True) -> None: ...

def __getattr__(name: str) -> Incomplete: ...
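A hedged sketch of the two extended Model methods (save gains zipped, export gains verbose); the model and file names are arbitrary:

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
    model.save("model.keras", zipped=True)  # zipped is new in this stub
    model.export("exported_model", format="tf_saved_model", verbose=False)  # verbose is new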
@@ -54,6 +54,7 @@ class CoordinationServiceConfig(google.protobuf.message.Message):
    RECOVERABLE_JOBS_FIELD_NUMBER: builtins.int
    ALLOW_NEW_INCARNATION_TO_RECONNECT_FIELD_NUMBER: builtins.int
    FORCE_DISABLE_FIELD_NUMBER: builtins.int
    POLL_FOR_ERROR_FROM_SERVICE_AT_STARTUP_FIELD_NUMBER: builtins.int
    service_type: builtins.str
    """Type of coordination service implementation to enable.
    For example, setting the service type as "standalone" starts a service
@@ -95,6 +96,10 @@ class CoordinationServiceConfig(google.protobuf.message.Message):
    not specify any config. This field allows users to explicitly disable
    coordination service under all situations.
    """
    poll_for_error_from_service_at_startup: builtins.bool
    """Use long polling to get error from coordination service as the error
    propagation mechanism.
    """
    @property
    def coordinated_job_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___CoordinatedJob]: ...
    @property
@@ -119,7 +124,8 @@ class CoordinationServiceConfig(google.protobuf.message.Message):
        recoverable_jobs: collections.abc.Iterable[builtins.str] | None = ...,
        allow_new_incarnation_to_reconnect: builtins.bool | None = ...,
        force_disable: builtins.bool | None = ...,
        poll_for_error_from_service_at_startup: builtins.bool | None = ...,
    ) -> None: ...
    def ClearField(self, field_name: typing.Literal["agent_destruction_without_shutdown", b"agent_destruction_without_shutdown", "allow_new_incarnation_to_reconnect", b"allow_new_incarnation_to_reconnect", "cluster_register_timeout_in_ms", b"cluster_register_timeout_in_ms", "coordinated_job_list", b"coordinated_job_list", "enable_health_check", b"enable_health_check", "force_disable", b"force_disable", "heartbeat_timeout_in_ms", b"heartbeat_timeout_in_ms", "recoverable_jobs", b"recoverable_jobs", "service_leader", b"service_leader", "service_type", b"service_type", "shutdown_barrier_timeout_in_ms", b"shutdown_barrier_timeout_in_ms"]) -> None: ...
    def ClearField(self, field_name: typing.Literal["agent_destruction_without_shutdown", b"agent_destruction_without_shutdown", "allow_new_incarnation_to_reconnect", b"allow_new_incarnation_to_reconnect", "cluster_register_timeout_in_ms", b"cluster_register_timeout_in_ms", "coordinated_job_list", b"coordinated_job_list", "enable_health_check", b"enable_health_check", "force_disable", b"force_disable", "heartbeat_timeout_in_ms", b"heartbeat_timeout_in_ms", "poll_for_error_from_service_at_startup", b"poll_for_error_from_service_at_startup", "recoverable_jobs", b"recoverable_jobs", "service_leader", b"service_leader", "service_type", b"service_type", "shutdown_barrier_timeout_in_ms", b"shutdown_barrier_timeout_in_ms"]) -> None: ...

global___CoordinationServiceConfig = CoordinationServiceConfig
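A sketch of the new field in use; the generated-module path is an assumption based on this diff's tsl layout and may differ after the move into tensorflow/compiler/xla:

    # Assumed module path; the stubs relocate tsl protos in this change.
    from tensorflow.tsl.protobuf import coordination_config_pb2

    cfg = coordination_config_pb2.CoordinationServiceConfig(
        service_type="standalone",
        poll_for_error_from_service_at_startup=True,  # new field
    )
    cfg.ClearField("poll_for_error_from_service_at_startup")  # now a valid key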
@@ -234,6 +234,33 @@ class HeartbeatResponse(google.protobuf.message.Message):
global___HeartbeatResponse = HeartbeatResponse

@typing.final
class PollForErrorRequest(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    SOURCE_TASK_FIELD_NUMBER: builtins.int
    @property
    def source_task(self) -> global___CoordinatedTask: ...
    def __init__(
        self,
        *,
        source_task: global___CoordinatedTask | None = ...,
    ) -> None: ...
    def HasField(self, field_name: typing.Literal["source_task", b"source_task"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing.Literal["source_task", b"source_task"]) -> None: ...

global___PollForErrorRequest = PollForErrorRequest

@typing.final
class PollForErrorResponse(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    def __init__(
        self,
    ) -> None: ...

global___PollForErrorResponse = PollForErrorResponse

@typing.final
class WaitForAllTasksRequest(google.protobuf.message.Message):
    """Request and response messages for waiting for all tasks."""
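A hedged sketch of the new PollForError message pair; the module path is an assumption, and the job name and task id are illustrative:

    from tensorflow.tsl.protobuf import coordination_service_pb2  # assumed path

    req = coordination_service_pb2.PollForErrorRequest(
        source_task=coordination_service_pb2.CoordinatedTask(job_name="worker", task_id=0)
    )
    assert req.HasField("source_task")
    resp = coordination_service_pb2.PollForErrorResponse()  # intentionally has no fields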