Bump tensorflow to ~=2.18.0 (#12916)

* Tensorflow proto script update

* Manual stubtest changes

* Use Path for arg type
This commit is contained in:
Avasam
2024-10-28 23:32:31 -04:00
committed by GitHub
parent e92f98ccda
commit 335cc91b70
24 changed files with 746 additions and 929 deletions

View File

@@ -53,17 +53,20 @@ TSL_IMPORT_PATTERN = re.compile(r"(\[|\s)tsl\.")
XLA_IMPORT_PATTERN = re.compile(r"(\[|\s)xla\.")
def move_tree(source: Path, destination: Path) -> None:
"""Move directory and merge if destination already exists.
Can't use shutil.move because it can't merge existing directories."""
print(f"Moving '{source}' to '{destination}'")
shutil.copytree(source, destination, dirs_exist_ok=True)
shutil.rmtree(source)
def post_creation() -> None:
"""Move third-party and fix imports"""
# Can't use shutil.move because it can't merge existing directories.
print()
print(f"Moving '{STUBS_FOLDER}/tsl' to '{STUBS_FOLDER}/tensorflow/tsl'")
shutil.copytree(f"{STUBS_FOLDER}/tsl", f"{STUBS_FOLDER}/tensorflow/tsl", dirs_exist_ok=True)
shutil.rmtree(f"{STUBS_FOLDER}/tsl")
print(f"Moving '{STUBS_FOLDER}/xla' to '{STUBS_FOLDER}/tensorflow/compiler/xla'")
shutil.copytree(f"{STUBS_FOLDER}/xla", f"{STUBS_FOLDER}/tensorflow/compiler/xla", dirs_exist_ok=True)
shutil.rmtree(f"{STUBS_FOLDER}/xla")
move_tree(STUBS_FOLDER / "tsl", STUBS_FOLDER / "tensorflow" / "tsl")
move_tree(STUBS_FOLDER / "xla", STUBS_FOLDER / "tensorflow" / "compiler" / "xla")
for path in STUBS_FOLDER.rglob("*_pb2.pyi"):
print(f"Fixing imports in '{path}'")
@@ -106,6 +109,7 @@ def main() -> None:
proto_globs=(
f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/xla/*.proto",
f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/xla/service/*.proto",
f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/xla/tsl/protobuf/*.proto",
f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/example/*.proto",
f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/framework/*.proto",
f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/protobuf/*.proto",

View File

@@ -1,4 +1,4 @@
# Using an exact number in the specifier for scripts/sync_proto/google_protobuf.py
# Using an exact number in the specifier for scripts/sync_protobuf/google_protobuf.py
version = "~=5.28.3"
upstream_repository = "https://github.com/protocolbuffers/protobuf"
extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 27.2 on [protobuf v28.3](https://github.com/protocolbuffers/protobuf/releases/tag/v28.3) (python `protobuf==5.28.3`)."

View File

@@ -1,5 +1,5 @@
# Whenever you update version here, PACKAGE_VERSION should be updated
# in scripts/sync_proto/s2clientprotocol.py and vice-versa.
# in scripts/sync_protobuf/s2clientprotocol.py and vice-versa.
version = "5.*"
upstream_repository = "https://github.com/Blizzard/s2client-proto"
requires = ["types-protobuf"]

View File

@@ -1,11 +1,11 @@
# Using an exact number in the specifier for scripts/sync_proto/tensorflow.py
version = "~=2.17.1"
# Using an exact number in the specifier for scripts/sync_protobuf/tensorflow.py
version = "~=2.18.0"
upstream_repository = "https://github.com/tensorflow/tensorflow"
# requires a version of numpy with a `py.typed` file
# see https://github.com/python/typeshed/issues/12551
# on why we need the upper bound for numpy
requires = ["numpy>=1.20,<2.1.0", "types-protobuf", "types-requests"]
extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 27.2 on `tensorflow==2.17.1`."
extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 27.2 on `tensorflow==2.18.0`."
partial_stub = true
[tool.stubtest]

View File

@@ -219,7 +219,7 @@ global___Kind = Kind
@typing.final
class HloInstructionProto(google.protobuf.message.Message):
"""Serialization of HloInstruction.
Next ID: 87
Next ID: 90
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -316,6 +316,9 @@ class HloInstructionProto(google.protobuf.message.Message):
LARGEST_FIELD_NUMBER: builtins.int
STATISTICS_VIZ_FIELD_NUMBER: builtins.int
DOT_SPARSITY_FIELD_NUMBER: builtins.int
COLLECTIVE_DEVICE_LIST_FIELD_NUMBER: builtins.int
ORIGINAL_VALUE_FIELD_NUMBER: builtins.int
IS_COMPOSITE_FIELD_NUMBER: builtins.int
name: builtins.str
opcode: builtins.str
parameter_number: builtins.int
@@ -433,6 +436,8 @@ class HloInstructionProto(google.protobuf.message.Message):
"""Represents the K value for top-k."""
largest: builtins.bool
"""Represents the largest flag for top-k."""
is_composite: builtins.bool
"""Specifies if a call instruction is a composite."""
@property
def shape(self) -> tensorflow.compiler.xla.xla_data_pb2.ShapeProto: ...
@property
@@ -497,7 +502,9 @@ class HloInstructionProto(google.protobuf.message.Message):
def sharding(self) -> tensorflow.compiler.xla.xla_data_pb2.OpSharding: ...
@property
def replica_groups(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[tensorflow.compiler.xla.xla_data_pb2.ReplicaGroup]:
"""Cross replica op fields."""
"""Deprecated, but keeping for backward compatibility.
Use collective_device_list. Cross replica op fields.
"""
@property
def scatter_dimension_numbers(self) -> tensorflow.compiler.xla.xla_data_pb2.ScatterDimensionNumbers: ...
@@ -549,6 +556,14 @@ class HloInstructionProto(google.protobuf.message.Message):
def dot_sparsity(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor]:
"""Sparsity descriptor for dot operation."""
@property
def collective_device_list(self) -> tensorflow.compiler.xla.xla_data_pb2.CollectiveDeviceListProto:
"""Represents the list of devices that participate in a collective operation."""
@property
def original_value(self) -> tensorflow.compiler.xla.xla_data_pb2.OriginalValueProto:
"""For HLO value tracking."""
def __init__(
self,
*,
@@ -623,9 +638,12 @@ class HloInstructionProto(google.protobuf.message.Message):
largest: builtins.bool | None = ...,
statistics_viz: tensorflow.compiler.xla.xla_data_pb2.StatisticsViz | None = ...,
dot_sparsity: collections.abc.Iterable[tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor] | None = ...,
collective_device_list: tensorflow.compiler.xla.xla_data_pb2.CollectiveDeviceListProto | None = ...,
original_value: tensorflow.compiler.xla.xla_data_pb2.OriginalValueProto | None = ...,
is_composite: builtins.bool | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "k", b"k", "largest", b"largest", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "output_operand_aliasing", b"output_operand_aliasing", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ...
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "collective_device_list", b"collective_device_list", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "original_value", b"original_value", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "collective_device_list", b"collective_device_list", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_composite", b"is_composite", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "k", b"k", "largest", b"largest", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "original_value", b"original_value", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "output_operand_aliasing", b"output_operand_aliasing", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index"]) -> typing.Literal["cross_program_prefetch_index"] | None: ...
global___HloInstructionProto = HloInstructionProto
@@ -980,6 +998,7 @@ class HloModuleProto(google.protobuf.message.Message):
FUSION: HloModuleProto._ProfileType.ValueType # 2
LAYOUT: HloModuleProto._ProfileType.ValueType # 3
DOT: HloModuleProto._ProfileType.ValueType # 4
FLAGNET: HloModuleProto._ProfileType.ValueType # 5
class ProfileType(_ProfileType, metaclass=_ProfileTypeEnumTypeWrapper):
"""The type of optimization profile in use for module-level optimizations."""
@@ -989,6 +1008,7 @@ class HloModuleProto(google.protobuf.message.Message):
FUSION: HloModuleProto.ProfileType.ValueType # 2
LAYOUT: HloModuleProto.ProfileType.ValueType # 3
DOT: HloModuleProto.ProfileType.ValueType # 4
FLAGNET: HloModuleProto.ProfileType.ValueType # 5
@typing.final
class ProfileInfo(google.protobuf.message.Message):
@@ -1604,35 +1624,3 @@ class HloPassMetadata(google.protobuf.message.Message):
def ClearField(self, field_name: typing.Literal["custom_metadata", b"custom_metadata", "dump_filenames", b"dump_filenames", "end_timestamp_usec", b"end_timestamp_usec", "module_changed", b"module_changed", "module_group_module_ids", b"module_group_module_ids", "module_id", b"module_id", "pass_id", b"pass_id", "pass_name", b"pass_name", "pipeline_name", b"pipeline_name", "start_timestamp_usec", b"start_timestamp_usec"]) -> None: ...
global___HloPassMetadata = HloPassMetadata
@typing.final
class XlaRuntimeExecutableProto(google.protobuf.message.Message):
"""Encodes the underlying Xla runtime executable compiled from the XLA module."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
HLO_MODULE_PROTO_FIELD_NUMBER: builtins.int
OBJ_FILE_FIELD_NUMBER: builtins.int
MLIR_MODULE_FIELD_NUMBER: builtins.int
obj_file: builtins.bytes
"""TODO(b/232263665)): Serialized executable has to know what APIs it has to
be linked with, including the version. For example Gpu executable must be
linked with a runtime layer that abstracts over CUDA.
Serialized object file compiled from the XLA module.
"""
mlir_module: builtins.str
"""Serialized MLIR module corresponding to compiled object file."""
@property
def hlo_module_proto(self) -> global___HloModuleProto: ...
def __init__(
self,
*,
hlo_module_proto: global___HloModuleProto | None = ...,
obj_file: builtins.bytes | None = ...,
mlir_module: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["hlo_module_proto", b"hlo_module_proto"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["hlo_module_proto", b"hlo_module_proto", "mlir_module", b"mlir_module", "obj_file", b"obj_file"]) -> None: ...
global___XlaRuntimeExecutableProto = XlaRuntimeExecutableProto

View File

@@ -61,6 +61,57 @@ class PassMetrics(google.protobuf.message.Message):
global___PassMetrics = PassMetrics
@typing.final
class JobInfo(google.protobuf.message.Message):
"""Defines compilation job information."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NAME_FIELD_NUMBER: builtins.int
CELL_FIELD_NUMBER: builtins.int
USER_FIELD_NUMBER: builtins.int
UID_FIELD_NUMBER: builtins.int
TASK_ID_FIELD_NUMBER: builtins.int
TASK_UID_FIELD_NUMBER: builtins.int
name: builtins.str
"""Name of the job running compilation."""
cell: builtins.str
"""Cell in which the job is running."""
user: builtins.str
"""User running the job."""
uid: builtins.int
"""Unique id when combined with user and cell field."""
task_id: builtins.int
"""Task index, which will not change across job restarts."""
task_uid: builtins.int
"""Task unique id, which may change across job restarts."""
def __init__(
self,
*,
name: builtins.str | None = ...,
cell: builtins.str | None = ...,
user: builtins.str | None = ...,
uid: builtins.int | None = ...,
task_id: builtins.int | None = ...,
task_uid: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_cell", b"_cell", "_name", b"_name", "_task_id", b"_task_id", "_task_uid", b"_task_uid", "_uid", b"_uid", "_user", b"_user", "cell", b"cell", "name", b"name", "task_id", b"task_id", "task_uid", b"task_uid", "uid", b"uid", "user", b"user"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_cell", b"_cell", "_name", b"_name", "_task_id", b"_task_id", "_task_uid", b"_task_uid", "_uid", b"_uid", "_user", b"_user", "cell", b"cell", "name", b"name", "task_id", b"task_id", "task_uid", b"task_uid", "uid", b"uid", "user", b"user"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_cell", b"_cell"]) -> typing.Literal["cell"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_name", b"_name"]) -> typing.Literal["name"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_task_id", b"_task_id"]) -> typing.Literal["task_id"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_task_uid", b"_task_uid"]) -> typing.Literal["task_uid"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_uid", b"_uid"]) -> typing.Literal["uid"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_user", b"_user"]) -> typing.Literal["user"] | None: ...
global___JobInfo = JobInfo
@typing.final
class CompilationLogEntry(google.protobuf.message.Message):
"""Defines XLA compilation metrics."""
@@ -94,11 +145,13 @@ class CompilationLogEntry(google.protobuf.message.Message):
TASK_INDEX_FIELD_NUMBER: builtins.int
PASS_METRICS_FIELD_NUMBER: builtins.int
MODULE_IDS_FIELD_NUMBER: builtins.int
JOB_INFO_FIELD_NUMBER: builtins.int
stage: global___CompilationLogEntry.CompilationStage.ValueType
"""Compilation stage recorded by this log entry."""
task_index: builtins.int
"""Task index from which this log entry was recorded or
-1 if the task index could not be fetched.
-1 if the task index could not be fetched. In the case task_index is not
equal to -1, it is guaranteed to match the task_id in job_info.
"""
@property
def timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp:
@@ -116,6 +169,10 @@ class CompilationLogEntry(google.protobuf.message.Message):
def module_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""IDs of modules on which the compilation stage was run."""
@property
def job_info(self) -> global___JobInfo:
"""Job information."""
def __init__(
self,
*,
@@ -125,8 +182,9 @@ class CompilationLogEntry(google.protobuf.message.Message):
task_index: builtins.int | None = ...,
pass_metrics: collections.abc.Iterable[global___PassMetrics] | None = ...,
module_ids: collections.abc.Iterable[builtins.int] | None = ...,
job_info: global___JobInfo | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["duration", b"duration", "timestamp", b"timestamp"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["duration", b"duration", "module_ids", b"module_ids", "pass_metrics", b"pass_metrics", "stage", b"stage", "task_index", b"task_index", "timestamp", b"timestamp"]) -> None: ...
def HasField(self, field_name: typing.Literal["duration", b"duration", "job_info", b"job_info", "timestamp", b"timestamp"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["duration", b"duration", "job_info", b"job_info", "module_ids", b"module_ids", "pass_metrics", b"pass_metrics", "stage", b"stage", "task_index", b"task_index", "timestamp", b"timestamp"]) -> None: ...
global___CompilationLogEntry = CompilationLogEntry

View File

@@ -115,7 +115,7 @@ class CompilationResult(google.protobuf.message.Message):
@property
def status(self) -> tensorflow.tsl.protobuf.status_pb2.StatusProto:
"""Always set when compilation fails; never set when compilation succeeds."""
"""Always set even when compilation succeeds."""
@property
def counters(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.int]:

View File

@@ -860,6 +860,7 @@ class OpMetadata(google.protobuf.message.Message):
DEDUPLICATED_NAME_FIELD_NUMBER: builtins.int
PRESERVE_LAYOUT_FIELD_NUMBER: builtins.int
STACK_FRAME_ID_FIELD_NUMBER: builtins.int
SCHEDULING_NAME_FIELD_NUMBER: builtins.int
op_type: builtins.str
"""The framework op name that generated this XLA op.
@@ -901,6 +902,8 @@ class OpMetadata(google.protobuf.message.Message):
"""1-based position of the frame in frames flat array.
Ids are 1-based to keep 0 value as representation of non-set property.
"""
scheduling_name: builtins.str
"""Instruction name available upon scheduling."""
@property
def profile_type(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[global___ProfileType.ValueType]:
"""Deprecated, use [ProfileInfo][profile_type] instead."""
@@ -923,9 +926,10 @@ class OpMetadata(google.protobuf.message.Message):
deduplicated_name: builtins.str | None = ...,
preserve_layout: builtins.bool | None = ...,
stack_frame_id: builtins.int | None = ...,
scheduling_name: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["profile_info", b"profile_info"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["deduplicated_name", b"deduplicated_name", "op_name", b"op_name", "op_type", b"op_type", "preserve_layout", b"preserve_layout", "profile_info", b"profile_info", "profile_type", b"profile_type", "size_of_generated_code_in_bytes", b"size_of_generated_code_in_bytes", "size_of_memory_working_set_in_bytes", b"size_of_memory_working_set_in_bytes", "source_file", b"source_file", "source_line", b"source_line", "stack_frame_id", b"stack_frame_id"]) -> None: ...
def ClearField(self, field_name: typing.Literal["deduplicated_name", b"deduplicated_name", "op_name", b"op_name", "op_type", b"op_type", "preserve_layout", b"preserve_layout", "profile_info", b"profile_info", "profile_type", b"profile_type", "scheduling_name", b"scheduling_name", "size_of_generated_code_in_bytes", b"size_of_generated_code_in_bytes", "size_of_memory_working_set_in_bytes", b"size_of_memory_working_set_in_bytes", "source_file", b"source_file", "source_line", b"source_line", "stack_frame_id", b"stack_frame_id"]) -> None: ...
global___OpMetadata = OpMetadata
@@ -1379,6 +1383,8 @@ class GatherDimensionNumbers(google.protobuf.message.Message):
COLLAPSED_SLICE_DIMS_FIELD_NUMBER: builtins.int
START_INDEX_MAP_FIELD_NUMBER: builtins.int
INDEX_VECTOR_DIM_FIELD_NUMBER: builtins.int
OPERAND_BATCHING_DIMS_FIELD_NUMBER: builtins.int
START_INDICES_BATCHING_DIMS_FIELD_NUMBER: builtins.int
index_vector_dim: builtins.int
"""The dimension in the start_indices input that contains the starting
indices.
@@ -1409,6 +1415,16 @@ class GatherDimensionNumbers(google.protobuf.message.Message):
the starting index in the input space.
"""
@property
def operand_batching_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""This is the batch dimensions in the operand."""
@property
def start_indices_batching_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""This is the batch dimensions in the index, and it should be the same size
as operand_batching_dims.
"""
def __init__(
self,
*,
@@ -1416,8 +1432,10 @@ class GatherDimensionNumbers(google.protobuf.message.Message):
collapsed_slice_dims: collections.abc.Iterable[builtins.int] | None = ...,
start_index_map: collections.abc.Iterable[builtins.int] | None = ...,
index_vector_dim: builtins.int | None = ...,
operand_batching_dims: collections.abc.Iterable[builtins.int] | None = ...,
start_indices_batching_dims: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["collapsed_slice_dims", b"collapsed_slice_dims", "index_vector_dim", b"index_vector_dim", "offset_dims", b"offset_dims", "start_index_map", b"start_index_map"]) -> None: ...
def ClearField(self, field_name: typing.Literal["collapsed_slice_dims", b"collapsed_slice_dims", "index_vector_dim", b"index_vector_dim", "offset_dims", b"offset_dims", "operand_batching_dims", b"operand_batching_dims", "start_index_map", b"start_index_map", "start_indices_batching_dims", b"start_indices_batching_dims"]) -> None: ...
global___GatherDimensionNumbers = GatherDimensionNumbers
@@ -1435,6 +1453,8 @@ class ScatterDimensionNumbers(google.protobuf.message.Message):
INSERTED_WINDOW_DIMS_FIELD_NUMBER: builtins.int
SCATTER_DIMS_TO_OPERAND_DIMS_FIELD_NUMBER: builtins.int
INDEX_VECTOR_DIM_FIELD_NUMBER: builtins.int
INPUT_BATCHING_DIMS_FIELD_NUMBER: builtins.int
SCATTER_INDICES_BATCHING_DIMS_FIELD_NUMBER: builtins.int
index_vector_dim: builtins.int
@property
def update_window_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
@@ -1446,6 +1466,14 @@ class ScatterDimensionNumbers(google.protobuf.message.Message):
@property
def scatter_dims_to_operand_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
@property
def input_batching_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""This is the batch dimension in the input."""
@property
def scatter_indices_batching_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""This is the batch dimension in the index."""
def __init__(
self,
*,
@@ -1453,8 +1481,10 @@ class ScatterDimensionNumbers(google.protobuf.message.Message):
inserted_window_dims: collections.abc.Iterable[builtins.int] | None = ...,
scatter_dims_to_operand_dims: collections.abc.Iterable[builtins.int] | None = ...,
index_vector_dim: builtins.int | None = ...,
input_batching_dims: collections.abc.Iterable[builtins.int] | None = ...,
scatter_indices_batching_dims: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["index_vector_dim", b"index_vector_dim", "inserted_window_dims", b"inserted_window_dims", "scatter_dims_to_operand_dims", b"scatter_dims_to_operand_dims", "update_window_dims", b"update_window_dims"]) -> None: ...
def ClearField(self, field_name: typing.Literal["index_vector_dim", b"index_vector_dim", "input_batching_dims", b"input_batching_dims", "inserted_window_dims", b"inserted_window_dims", "scatter_dims_to_operand_dims", b"scatter_dims_to_operand_dims", "scatter_indices_batching_dims", b"scatter_indices_batching_dims", "update_window_dims", b"update_window_dims"]) -> None: ...
global___ScatterDimensionNumbers = ScatterDimensionNumbers
@@ -1981,6 +2011,80 @@ class ReplicaGroup(google.protobuf.message.Message):
global___ReplicaGroup = ReplicaGroup
@typing.final
class IotaReplicaGroupListProto(google.protobuf.message.Message):
"""Represents a list of replica groups (a list of list of devices) with
reshaping and transposing an iota array (iota tile assignment). Can be used
to represent certain common patterns of device lists in a compact, scalable
format.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NUM_REPLICA_GROUPS_FIELD_NUMBER: builtins.int
NUM_DEVICES_PER_GROUP_FIELD_NUMBER: builtins.int
IOTA_RESHAPE_DIMS_FIELD_NUMBER: builtins.int
IOTA_TRANSPOSE_PERM_FIELD_NUMBER: builtins.int
num_replica_groups: builtins.int
"""Number of replica groups."""
num_devices_per_group: builtins.int
"""Number of devices per group."""
@property
def iota_reshape_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""The dimensions used to reshape the 1D iota array of device IDs."""
@property
def iota_transpose_perm(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""The dimension permutations to transposed the iota array reshaped to
iota_reshape_dims. This must have the same size as iota_reshape_dims.
"""
def __init__(
self,
*,
num_replica_groups: builtins.int | None = ...,
num_devices_per_group: builtins.int | None = ...,
iota_reshape_dims: collections.abc.Iterable[builtins.int] | None = ...,
iota_transpose_perm: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["iota_reshape_dims", b"iota_reshape_dims", "iota_transpose_perm", b"iota_transpose_perm", "num_devices_per_group", b"num_devices_per_group", "num_replica_groups", b"num_replica_groups"]) -> None: ...
global___IotaReplicaGroupListProto = IotaReplicaGroupListProto
@typing.final
class CollectiveDeviceListProto(google.protobuf.message.Message):
"""Represents a series of devices participating in a collective operation (e.g.,
all-reduce and all-to-all). While this directly translates to a list of
replica groups, it may be used to represent these lists in a compact form.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
REPLICA_GROUPS_FIELD_NUMBER: builtins.int
IOTA_REPLICA_GROUP_LIST_FIELD_NUMBER: builtins.int
@property
def replica_groups(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ReplicaGroup]:
"""ReplicaGroupV1: List of replica groups. Legacy way of representing device
lists.
"""
@property
def iota_replica_group_list(self) -> global___IotaReplicaGroupListProto:
"""ReplicaGroupV2: Represents a list of replica groups with reshaping and
transposing an iota array.
"""
def __init__(
self,
*,
replica_groups: collections.abc.Iterable[global___ReplicaGroup] | None = ...,
iota_replica_group_list: global___IotaReplicaGroupListProto | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["iota_replica_group_list", b"iota_replica_group_list"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["iota_replica_group_list", b"iota_replica_group_list", "replica_groups", b"replica_groups"]) -> None: ...
global___CollectiveDeviceListProto = CollectiveDeviceListProto
@typing.final
class SourceTarget(google.protobuf.message.Message):
"""Describes the source target pair in the collective permute op."""
@@ -2236,3 +2340,42 @@ class OutputOperandAliasing(google.protobuf.message.Message):
def ClearField(self, field_name: typing.Literal["operand_index", b"operand_index", "operand_shape_index", b"operand_shape_index", "output_shape_index", b"output_shape_index"]) -> None: ...
global___OutputOperandAliasing = OutputOperandAliasing
@typing.final
class OriginalArrayProto(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
LEAF_SHAPE_INDEX_FIELD_NUMBER: builtins.int
INSTRUCTION_NAME_FIELD_NUMBER: builtins.int
SHAPE_INDEX_FIELD_NUMBER: builtins.int
instruction_name: builtins.str
@property
def leaf_shape_index(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
@property
def shape_index(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
def __init__(
self,
*,
leaf_shape_index: collections.abc.Iterable[builtins.int] | None = ...,
instruction_name: builtins.str | None = ...,
shape_index: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["instruction_name", b"instruction_name", "leaf_shape_index", b"leaf_shape_index", "shape_index", b"shape_index"]) -> None: ...
global___OriginalArrayProto = OriginalArrayProto
@typing.final
class OriginalValueProto(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
LEAVES_FIELD_NUMBER: builtins.int
@property
def leaves(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___OriginalArrayProto]: ...
def __init__(
self,
*,
leaves: collections.abc.Iterable[global___OriginalArrayProto] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["leaves", b"leaves"]) -> None: ...
global___OriginalValueProto = OriginalValueProto

File diff suppressed because one or more lines are too long

View File

@@ -278,6 +278,25 @@ class OptimizationOptions(google.protobuf.message.Message):
global___OptimizationOptions = OptimizationOptions
@typing.final
class ServiceOptions(google.protobuf.message.Message):
"""next: 2"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
PINNED_FIELD_NUMBER: builtins.int
pinned: builtins.bool
def __init__(
self,
*,
pinned: builtins.bool | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["optional_pinned", b"optional_pinned", "pinned", b"pinned"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["optional_pinned", b"optional_pinned", "pinned", b"pinned"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["optional_pinned", b"optional_pinned"]) -> typing.Literal["pinned"] | None: ...
global___ServiceOptions = ServiceOptions
@typing.final
class ThreadingOptions(google.protobuf.message.Message):
"""next: 3"""
@@ -308,7 +327,7 @@ class Options(google.protobuf.message.Message):
"""Message stored with Dataset objects to control how datasets are processed and
optimized.
next: 12
next: 13
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -319,6 +338,7 @@ class Options(google.protobuf.message.Message):
AUTOTUNE_OPTIONS_FIELD_NUMBER: builtins.int
DISTRIBUTE_OPTIONS_FIELD_NUMBER: builtins.int
OPTIMIZATION_OPTIONS_FIELD_NUMBER: builtins.int
SERVICE_OPTIONS_FIELD_NUMBER: builtins.int
SLACK_FIELD_NUMBER: builtins.int
THREADING_OPTIONS_FIELD_NUMBER: builtins.int
EXTERNAL_STATE_POLICY_FIELD_NUMBER: builtins.int
@@ -346,6 +366,10 @@ class Options(google.protobuf.message.Message):
def optimization_options(self) -> global___OptimizationOptions:
"""The optimization options associated with the dataset."""
@property
def service_options(self) -> global___ServiceOptions:
"""The tf.data service options associated with the dataset."""
@property
def threading_options(self) -> global___ThreadingOptions:
"""The threading options associated with the dataset."""
@@ -359,14 +383,15 @@ class Options(google.protobuf.message.Message):
autotune_options: global___AutotuneOptions | None = ...,
distribute_options: global___DistributeOptions | None = ...,
optimization_options: global___OptimizationOptions | None = ...,
service_options: global___ServiceOptions | None = ...,
slack: builtins.bool | None = ...,
threading_options: global___ThreadingOptions | None = ...,
external_state_policy: global___ExternalStatePolicy.ValueType | None = ...,
symbolic_checkpoint: builtins.bool | None = ...,
warm_start: builtins.bool | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "dataset_name", b"dataset_name", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_dataset_name", b"optional_dataset_name", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "optional_warm_start", b"optional_warm_start", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options", "warm_start", b"warm_start"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "dataset_name", b"dataset_name", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "framework_type", b"framework_type", "optimization_options", b"optimization_options", "optional_dataset_name", b"optional_dataset_name", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "optional_warm_start", b"optional_warm_start", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options", "warm_start", b"warm_start"]) -> None: ...
def HasField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "dataset_name", b"dataset_name", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_dataset_name", b"optional_dataset_name", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "optional_warm_start", b"optional_warm_start", "service_options", b"service_options", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options", "warm_start", b"warm_start"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "dataset_name", b"dataset_name", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "framework_type", b"framework_type", "optimization_options", b"optimization_options", "optional_dataset_name", b"optional_dataset_name", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "optional_warm_start", b"optional_warm_start", "service_options", b"service_options", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options", "warm_start", b"warm_start"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_dataset_name", b"optional_dataset_name"]) -> typing.Literal["dataset_name"] | None: ...
@typing.overload

View File

@@ -4,7 +4,7 @@ isort:skip_file
"""
import google.protobuf.descriptor
from tensorflow.tsl.protobuf.bfc_memory_map_pb2 import (
from tensorflow.compiler.xla.tsl.protobuf.bfc_memory_map_pb2 import (
BinSummary as BinSummary,
MemAllocatorStats as MemAllocatorStats,
MemChunk as MemChunk,

View File

@@ -773,6 +773,7 @@ class ConfigProto(google.protobuf.message.Message):
XLA_FUSION_AUTOTUNER_THRESH_FIELD_NUMBER: builtins.int
USE_TFRT_FIELD_NUMBER: builtins.int
ENABLE_MULTI_HOST_FIELD_NUMBER: builtins.int
TFRT_USE_IFRT_FIELD_NUMBER: builtins.int
BACKEND_SERVER_PORT_FIELD_NUMBER: builtins.int
TARGET_TPU_FIELD_NUMBER: builtins.int
TARGET_GPU_FIELD_NUMBER: builtins.int
@@ -886,6 +887,10 @@ class ConfigProto(google.protobuf.message.Message):
"""Whether runtime execution uses TFRT."""
enable_multi_host: builtins.bool
"""If true, use Pathways with TFRT API for multi host support."""
tfrt_use_ifrt: builtins.bool
"""If true, use ifrt as the backend for TFRT. This is only used when
`use_tfrt` is true.
"""
backend_server_port: builtins.int
"""Port for the Pathways server. Ignored if enable_multi_host=false."""
target_tpu: builtins.bool
@@ -962,6 +967,7 @@ class ConfigProto(google.protobuf.message.Message):
xla_fusion_autotuner_thresh: builtins.int | None = ...,
use_tfrt: builtins.bool | None = ...,
enable_multi_host: builtins.bool | None = ...,
tfrt_use_ifrt: builtins.bool | None = ...,
backend_server_port: builtins.int | None = ...,
target_tpu: builtins.bool | None = ...,
target_gpu: builtins.bool | None = ...,
@@ -973,7 +979,7 @@ class ConfigProto(google.protobuf.message.Message):
disable_eager_executor_streaming_enqueue: builtins.bool | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["coordination_config", b"coordination_config", "session_metadata", b"session_metadata"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["backend_server_port", b"backend_server_port", "collective_deterministic_sequential_execution", b"collective_deterministic_sequential_execution", "collective_group_leader", b"collective_group_leader", "collective_nccl", b"collective_nccl", "coordination_config", b"coordination_config", "disable_eager_executor_streaming_enqueue", b"disable_eager_executor_streaming_enqueue", "disable_functional_ops_lowering", b"disable_functional_ops_lowering", "disable_optimize_for_static_graph", b"disable_optimize_for_static_graph", "disable_output_partition_graphs", b"disable_output_partition_graphs", "disable_thread_spinning", b"disable_thread_spinning", "enable_mlir_bridge", b"enable_mlir_bridge", "enable_mlir_graph_optimization", b"enable_mlir_graph_optimization", "enable_multi_host", b"enable_multi_host", "executor_type", b"executor_type", "mlir_bridge_rollout", b"mlir_bridge_rollout", "optimize_for_static_graph", b"optimize_for_static_graph", "recv_buf_max_chunk", b"recv_buf_max_chunk", "session_metadata", b"session_metadata", "share_cluster_devices_in_session", b"share_cluster_devices_in_session", "share_session_state_in_clusterspec_propagation", b"share_session_state_in_clusterspec_propagation", "stream_merge_threshold", b"stream_merge_threshold", "target_gpu", b"target_gpu", "target_tpu", b"target_tpu", "use_numa_affinity", b"use_numa_affinity", "use_tfrt", b"use_tfrt", "xla_fusion_autotuner_thresh", b"xla_fusion_autotuner_thresh", "xla_prefer_single_graph_cluster", b"xla_prefer_single_graph_cluster"]) -> None: ...
def ClearField(self, field_name: typing.Literal["backend_server_port", b"backend_server_port", "collective_deterministic_sequential_execution", b"collective_deterministic_sequential_execution", "collective_group_leader", b"collective_group_leader", "collective_nccl", b"collective_nccl", "coordination_config", b"coordination_config", "disable_eager_executor_streaming_enqueue", b"disable_eager_executor_streaming_enqueue", "disable_functional_ops_lowering", b"disable_functional_ops_lowering", "disable_optimize_for_static_graph", b"disable_optimize_for_static_graph", "disable_output_partition_graphs", b"disable_output_partition_graphs", "disable_thread_spinning", b"disable_thread_spinning", "enable_mlir_bridge", b"enable_mlir_bridge", "enable_mlir_graph_optimization", b"enable_mlir_graph_optimization", "enable_multi_host", b"enable_multi_host", "executor_type", b"executor_type", "mlir_bridge_rollout", b"mlir_bridge_rollout", "optimize_for_static_graph", b"optimize_for_static_graph", "recv_buf_max_chunk", b"recv_buf_max_chunk", "session_metadata", b"session_metadata", "share_cluster_devices_in_session", b"share_cluster_devices_in_session", "share_session_state_in_clusterspec_propagation", b"share_session_state_in_clusterspec_propagation", "stream_merge_threshold", b"stream_merge_threshold", "target_gpu", b"target_gpu", "target_tpu", b"target_tpu", "tfrt_use_ifrt", b"tfrt_use_ifrt", "use_numa_affinity", b"use_numa_affinity", "use_tfrt", b"use_tfrt", "xla_fusion_autotuner_thresh", b"xla_fusion_autotuner_thresh", "xla_prefer_single_graph_cluster", b"xla_prefer_single_graph_cluster"]) -> None: ...
DEVICE_COUNT_FIELD_NUMBER: builtins.int
INTRA_OP_PARALLELISM_THREADS_FIELD_NUMBER: builtins.int
@@ -983,6 +989,7 @@ class ConfigProto(google.protobuf.message.Message):
PLACEMENT_PERIOD_FIELD_NUMBER: builtins.int
DEVICE_FILTERS_FIELD_NUMBER: builtins.int
GPU_OPTIONS_FIELD_NUMBER: builtins.int
PLUGGABLE_DEVICE_OPTIONS_FIELD_NUMBER: builtins.int
ALLOW_SOFT_PLACEMENT_FIELD_NUMBER: builtins.int
LOG_DEVICE_PLACEMENT_FIELD_NUMBER: builtins.int
GRAPH_OPTIONS_FIELD_NUMBER: builtins.int
@@ -1104,6 +1111,10 @@ class ConfigProto(google.protobuf.message.Message):
def gpu_options(self) -> global___GPUOptions:
"""Options that apply to all GPUs."""
@property
def pluggable_device_options(self) -> global___GPUOptions:
"""Options that apply to pluggable devices."""
@property
def graph_options(self) -> global___GraphOptions:
"""Options that apply to all graphs."""
@@ -1129,6 +1140,7 @@ class ConfigProto(google.protobuf.message.Message):
placement_period: builtins.int | None = ...,
device_filters: collections.abc.Iterable[builtins.str] | None = ...,
gpu_options: global___GPUOptions | None = ...,
pluggable_device_options: global___GPUOptions | None = ...,
allow_soft_placement: builtins.bool | None = ...,
log_device_placement: builtins.bool | None = ...,
graph_options: global___GraphOptions | None = ...,
@@ -1139,8 +1151,8 @@ class ConfigProto(google.protobuf.message.Message):
share_cluster_devices_in_session: builtins.bool | None = ...,
experimental: global___ConfigProto.Experimental | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["cluster_def", b"cluster_def", "experimental", b"experimental", "gpu_options", b"gpu_options", "graph_options", b"graph_options", "rpc_options", b"rpc_options"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["allow_soft_placement", b"allow_soft_placement", "cluster_def", b"cluster_def", "device_count", b"device_count", "device_filters", b"device_filters", "experimental", b"experimental", "gpu_options", b"gpu_options", "graph_options", b"graph_options", "inter_op_parallelism_threads", b"inter_op_parallelism_threads", "intra_op_parallelism_threads", b"intra_op_parallelism_threads", "isolate_session_state", b"isolate_session_state", "log_device_placement", b"log_device_placement", "operation_timeout_in_ms", b"operation_timeout_in_ms", "placement_period", b"placement_period", "rpc_options", b"rpc_options", "session_inter_op_thread_pool", b"session_inter_op_thread_pool", "share_cluster_devices_in_session", b"share_cluster_devices_in_session", "use_per_session_threads", b"use_per_session_threads"]) -> None: ...
def HasField(self, field_name: typing.Literal["cluster_def", b"cluster_def", "experimental", b"experimental", "gpu_options", b"gpu_options", "graph_options", b"graph_options", "pluggable_device_options", b"pluggable_device_options", "rpc_options", b"rpc_options"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["allow_soft_placement", b"allow_soft_placement", "cluster_def", b"cluster_def", "device_count", b"device_count", "device_filters", b"device_filters", "experimental", b"experimental", "gpu_options", b"gpu_options", "graph_options", b"graph_options", "inter_op_parallelism_threads", b"inter_op_parallelism_threads", "intra_op_parallelism_threads", b"intra_op_parallelism_threads", "isolate_session_state", b"isolate_session_state", "log_device_placement", b"log_device_placement", "operation_timeout_in_ms", b"operation_timeout_in_ms", "placement_period", b"placement_period", "pluggable_device_options", b"pluggable_device_options", "rpc_options", b"rpc_options", "session_inter_op_thread_pool", b"session_inter_op_thread_pool", "share_cluster_devices_in_session", b"share_cluster_devices_in_session", "use_per_session_threads", b"use_per_session_threads"]) -> None: ...
global___ConfigProto = ConfigProto

View File

@@ -527,61 +527,6 @@ global___TensorInfo = TensorInfo
class SignatureDef(google.protobuf.message.Message):
"""SignatureDef defines the signature of a computation supported by a TensorFlow
graph.
For example, a model with two loss computations, sharing a single input,
might have the following signature_def map, in a MetaGraphDef message.
Note that across the two SignatureDefs "loss_A" and "loss_B", the input key,
output key, and method_name are identical, and will be used by system(s) that
implement or rely upon this particular loss method. The output tensor names
differ, demonstrating how different outputs can exist for the same method.
signature_def {
key: "loss_A"
value {
inputs {
key: "input"
value {
name: "input:0"
dtype: DT_STRING
tensor_shape: ...
}
}
outputs {
key: "loss_output"
value {
name: "loss_output_A:0"
dtype: DT_FLOAT
tensor_shape: ...
}
}
method_name: "some/package/compute_loss"
}
...
}
signature_def {
key: "loss_B"
value {
inputs {
key: "input"
value {
name: "input:0"
dtype: DT_STRING
tensor_shape: ...
}
}
outputs {
key: "loss_output"
value {
name: "loss_output_B:0"
dtype: DT_FLOAT
tensor_shape: ...
}
}
method_name: "some/package/compute_loss"
}
...
}
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -645,14 +590,13 @@ class SignatureDef(google.protobuf.message.Message):
METHOD_NAME_FIELD_NUMBER: builtins.int
DEFAULTS_FIELD_NUMBER: builtins.int
method_name: builtins.str
"""Extensible method_name information enabling third-party users to mark a
SignatureDef as supporting a particular method. This enables producers and
consumers of SignatureDefs, e.g. a model definition library and a serving
library to have a clear hand-off regarding the semantics of a computation.
"""Deprecated: TensorFlow 2 always sets this to a fixed value;
open-source TF Serving stopped checking by default since release 2.4.
Note that multiple SignatureDefs in a single MetaGraphDef may have the same
method_name. This is commonly used to support multi-headed computation,
where a single graph computation may return multiple results.
In TensorFlow 1, the method_name enabled users to mark a SignatureDef as
supporting a particular method. Multiple SignatureDefs in a single
MetaGraphDef could have the same method_name (e.g., to support multi-headed
computation).
"""
@property
def inputs(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___TensorInfo]:

View File

@@ -117,7 +117,7 @@ global___DispatcherConfig = DispatcherConfig
@typing.final
class WorkerConfig(google.protobuf.message.Message):
"""Configuration for a tf.data service WorkerServer.
Next id: 13
Next id: 14
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -130,6 +130,7 @@ class WorkerConfig(google.protobuf.message.Message):
HEARTBEAT_INTERVAL_MS_FIELD_NUMBER: builtins.int
DISPATCHER_TIMEOUT_MS_FIELD_NUMBER: builtins.int
DATA_TRANSFER_PROTOCOL_FIELD_NUMBER: builtins.int
DATA_TRANSFER_PORT_FIELD_NUMBER: builtins.int
DATA_TRANSFER_ADDRESS_FIELD_NUMBER: builtins.int
CROSS_TRAINER_CACHE_SIZE_BYTES_FIELD_NUMBER: builtins.int
SNAPSHOT_MAX_CHUNK_SIZE_BYTES_FIELD_NUMBER: builtins.int
@@ -157,11 +158,21 @@ class WorkerConfig(google.protobuf.message.Message):
runtime.
"""
data_transfer_protocol: builtins.str
"""The protocol for the worker to use when transferring data to clients."""
"""If set, the name of an alternative data transfer protocol for which the
worker starts an additional server ("data transfer server"); the trainer
can then get data from this server. If not set, no such server is started,
and the trainer can only get data from the regular worker server over
`protocol`.
"""
data_transfer_port: builtins.int
"""If `data_transfer_protocol` is set, the port to which the data transfer
server binds. If set to `0`, the server binds to any available port.
"""
data_transfer_address: builtins.str
"""The data transfer address of the worker server. The substring "%port%", if
specified, will be replaced with the worker's bound port. This is useful
when the port is set to `0`.
"""If `data_transfer_protocol` is set, the address of the data transfer
server. The substring "%dts_port%" can be used to represent -- and is
replaced with -- the bound port of the data transfer server; this is useful
when `data_transfer_port` is set to `0`.
"""
cross_trainer_cache_size_bytes: builtins.int
"""Maximum size of the cross-trainer cache in bytes. If enabled, make sure
@@ -195,11 +206,12 @@ class WorkerConfig(google.protobuf.message.Message):
heartbeat_interval_ms: builtins.int | None = ...,
dispatcher_timeout_ms: builtins.int | None = ...,
data_transfer_protocol: builtins.str | None = ...,
data_transfer_port: builtins.int | None = ...,
data_transfer_address: builtins.str | None = ...,
cross_trainer_cache_size_bytes: builtins.int | None = ...,
snapshot_max_chunk_size_bytes: builtins.int | None = ...,
shutdown_quiet_period_ms: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["cross_trainer_cache_size_bytes", b"cross_trainer_cache_size_bytes", "data_transfer_address", b"data_transfer_address", "data_transfer_protocol", b"data_transfer_protocol", "dispatcher_address", b"dispatcher_address", "dispatcher_timeout_ms", b"dispatcher_timeout_ms", "heartbeat_interval_ms", b"heartbeat_interval_ms", "port", b"port", "protocol", b"protocol", "shutdown_quiet_period_ms", b"shutdown_quiet_period_ms", "snapshot_max_chunk_size_bytes", b"snapshot_max_chunk_size_bytes", "worker_address", b"worker_address", "worker_tags", b"worker_tags"]) -> None: ...
def ClearField(self, field_name: typing.Literal["cross_trainer_cache_size_bytes", b"cross_trainer_cache_size_bytes", "data_transfer_address", b"data_transfer_address", "data_transfer_port", b"data_transfer_port", "data_transfer_protocol", b"data_transfer_protocol", "dispatcher_address", b"dispatcher_address", "dispatcher_timeout_ms", b"dispatcher_timeout_ms", "heartbeat_interval_ms", b"heartbeat_interval_ms", "port", b"port", "protocol", b"protocol", "shutdown_quiet_period_ms", b"shutdown_quiet_period_ms", "snapshot_max_chunk_size_bytes", b"snapshot_max_chunk_size_bytes", "worker_address", b"worker_address", "worker_tags", b"worker_tags"]) -> None: ...
global___WorkerConfig = WorkerConfig

View File

@@ -4,7 +4,7 @@ isort:skip_file
Protocol messages for describing the results of benchmarks and unit tests."""
import google.protobuf.descriptor
from tensorflow.tsl.protobuf.test_log_pb2 import (
from tensorflow.compiler.xla.tsl.protobuf.test_log_pb2 import (
AvailableDeviceInfo as AvailableDeviceInfo,
BenchmarkEntries as BenchmarkEntries,
BenchmarkEntry as BenchmarkEntry,

View File

@@ -121,7 +121,9 @@ class Dataset(ABC, Generic[_T1]):
self,
map_func: Callable[..., _T2],
num_parallel_calls: int | None = None,
deterministic: None | bool = None,
deterministic: bool | None = None,
synchronous: bool | None = None,
use_unbounded_threadpool: bool = False,
name: str | None = None,
) -> Dataset[_T2]: ...
def options(self) -> Options: ...

View File

@@ -42,7 +42,6 @@ class CallbackList(Callback):
model: Model[Any, Any] | None = None,
**params: Any,
) -> None: ...
def append(self, callback: Callback) -> None: ...
def set_params(self, params: dict[str, Any]) -> None: ...
def set_model(self, model: Model[Any, Any]) -> None: ...
def on_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...

View File

@@ -31,6 +31,7 @@ class InputSpec:
axes: dict[int, int | None] | None = None,
allow_last_axis_squeeze: bool = False,
name: str | None = None,
optional: bool = False,
) -> None: ...
def get_config(self) -> dict[str, Any]: ...
@classmethod
@@ -384,6 +385,7 @@ class MultiHeadAttention(Layer[Any, tf.Tensor]):
activity_regularizer: _Regularizer | None = None,
kernel_constraint: _Constraint | None = None,
bias_constraint: _Constraint | None = None,
seed: int | None = None,
*,
# **kwargs passed to Layer
trainable: bool = True,

View File

@@ -34,6 +34,7 @@ class BinaryCrossentropy(Loss):
axis: int = -1,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "binary_crossentropy",
dtype: Incomplete | None = None,
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
@@ -48,6 +49,7 @@ class BinaryFocalCrossentropy(Loss):
axis: int = -1,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "binary_focal_crossentropy",
dtype: Incomplete | None = None,
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
@@ -59,59 +61,100 @@ class CategoricalCrossentropy(Loss):
axis: int = -1,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "categorical_crossentropy",
dtype: Incomplete | None = None,
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class CategoricalHinge(Loss):
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "categorical_hinge") -> None: ...
def __init__(
self,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "categorical_hinge",
dtype: Incomplete | None = None,
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class CosineSimilarity(Loss):
def __init__(
self, axis: int = -1, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "cosine_similarity"
self,
axis: int = -1,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "cosine_similarity",
dtype: Incomplete | None = None,
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class Hinge(Loss):
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "hinge") -> None: ...
def __init__(
self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "hinge", dtype: Incomplete | None = None
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class Huber(Loss):
def __init__(
self, delta: float = 1.0, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "huber_loss"
self,
delta: float = 1.0,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "huber_loss",
dtype: Incomplete | None = None,
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class KLDivergence(Loss):
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "kl_divergence") -> None: ...
def __init__(
self,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "kl_divergence",
dtype: Incomplete | None = None,
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class LogCosh(Loss):
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "log_cosh") -> None: ...
def __init__(
self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "log_cosh", dtype: Incomplete | None = None
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class MeanAbsoluteError(Loss):
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_absolute_error") -> None: ...
def __init__(
self,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "mean_absolute_error",
dtype: Incomplete | None = None,
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class MeanAbsolutePercentageError(Loss):
def __init__(
self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_absolute_percentage_error"
self,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "mean_absolute_percentage_error",
dtype: Incomplete | None = None,
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class MeanSquaredError(Loss):
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_squared_error") -> None: ...
def __init__(
self,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "mean_squared_error",
dtype: Incomplete | None = None,
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class MeanSquaredLogarithmicError(Loss):
def __init__(
self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_squared_logarithmic_error"
self,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "mean_squared_logarithmic_error",
dtype: Incomplete | None = None,
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class Poisson(Loss):
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "poisson") -> None: ...
def __init__(
self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "poisson", dtype: Incomplete | None = None
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class SparseCategoricalCrossentropy(Loss):
@@ -121,11 +164,17 @@ class SparseCategoricalCrossentropy(Loss):
ignore_class: int | None = None,
reduction: _ReductionValues = "sum_over_batch_size",
name: str = "sparse_categorical_crossentropy",
dtype: Incomplete | None = None,
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class SquaredHinge(Loss):
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "squared_hinge") -> None: ...
def __init__(
self,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "squared_hinge",
dtype: Incomplete | None = None,
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class Reduction:

View File

@@ -66,6 +66,7 @@ class Model(Layer[_InputT, _OutputT]):
y: TensorCompatible | None = None,
y_pred: TensorCompatible | None = None,
sample_weight: Incomplete | None = None,
training: bool = True,
) -> tf.Tensor | None: ...
def compute_metrics(
self, x: TensorCompatible, y: TensorCompatible, y_pred: TensorCompatible, sample_weight: Incomplete | None = None
@@ -137,7 +138,7 @@ class Model(Layer[_InputT, _OutputT]):
@property
def non_trainable_weights(self) -> list[Variable]: ...
def get_weights(self): ...
def save(self, filepath: str | Path, overwrite: bool = True) -> None: ...
def save(self, filepath: str | Path, overwrite: bool = True, zipped: bool | None = None) -> None: ...
def save_weights(self, filepath: str | Path, overwrite: bool = True) -> None: ...
# kwargs are from keras.saving.saving_api.load_weights
def load_weights(self, filepath: str | Path, skip_mismatch: bool = False, *, by_name: bool = False) -> None: ...
@@ -161,6 +162,6 @@ class Model(Layer[_InputT, _OutputT]):
def get_layer(self, name: str | None = None, index: int | None = None) -> Layer[Incomplete, Incomplete]: ...
def get_compile_config(self) -> dict[str, Any]: ...
def compile_from_config(self, config: dict[str, Any]) -> Self: ...
def export(self, filepath: str | Path, format: str = "tf_saved_model") -> None: ...
def export(self, filepath: str | Path, format: str = "tf_saved_model", verbose: bool = True) -> None: ...
def __getattr__(name: str) -> Incomplete: ...

View File

@@ -54,6 +54,7 @@ class CoordinationServiceConfig(google.protobuf.message.Message):
RECOVERABLE_JOBS_FIELD_NUMBER: builtins.int
ALLOW_NEW_INCARNATION_TO_RECONNECT_FIELD_NUMBER: builtins.int
FORCE_DISABLE_FIELD_NUMBER: builtins.int
POLL_FOR_ERROR_FROM_SERVICE_AT_STARTUP_FIELD_NUMBER: builtins.int
service_type: builtins.str
"""Type of coordination service implementation to enable.
For example, setting the service type as "standalone" starts a service
@@ -95,6 +96,10 @@ class CoordinationServiceConfig(google.protobuf.message.Message):
not specify any config. This field allows users to explicitly disable
coordination service under all situations.
"""
poll_for_error_from_service_at_startup: builtins.bool
"""Use long polling to get error from coordination service as the error
propagation mechanism.
"""
@property
def coordinated_job_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___CoordinatedJob]: ...
@property
@@ -119,7 +124,8 @@ class CoordinationServiceConfig(google.protobuf.message.Message):
recoverable_jobs: collections.abc.Iterable[builtins.str] | None = ...,
allow_new_incarnation_to_reconnect: builtins.bool | None = ...,
force_disable: builtins.bool | None = ...,
poll_for_error_from_service_at_startup: builtins.bool | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["agent_destruction_without_shutdown", b"agent_destruction_without_shutdown", "allow_new_incarnation_to_reconnect", b"allow_new_incarnation_to_reconnect", "cluster_register_timeout_in_ms", b"cluster_register_timeout_in_ms", "coordinated_job_list", b"coordinated_job_list", "enable_health_check", b"enable_health_check", "force_disable", b"force_disable", "heartbeat_timeout_in_ms", b"heartbeat_timeout_in_ms", "recoverable_jobs", b"recoverable_jobs", "service_leader", b"service_leader", "service_type", b"service_type", "shutdown_barrier_timeout_in_ms", b"shutdown_barrier_timeout_in_ms"]) -> None: ...
def ClearField(self, field_name: typing.Literal["agent_destruction_without_shutdown", b"agent_destruction_without_shutdown", "allow_new_incarnation_to_reconnect", b"allow_new_incarnation_to_reconnect", "cluster_register_timeout_in_ms", b"cluster_register_timeout_in_ms", "coordinated_job_list", b"coordinated_job_list", "enable_health_check", b"enable_health_check", "force_disable", b"force_disable", "heartbeat_timeout_in_ms", b"heartbeat_timeout_in_ms", "poll_for_error_from_service_at_startup", b"poll_for_error_from_service_at_startup", "recoverable_jobs", b"recoverable_jobs", "service_leader", b"service_leader", "service_type", b"service_type", "shutdown_barrier_timeout_in_ms", b"shutdown_barrier_timeout_in_ms"]) -> None: ...
global___CoordinationServiceConfig = CoordinationServiceConfig

View File

@@ -234,6 +234,33 @@ class HeartbeatResponse(google.protobuf.message.Message):
global___HeartbeatResponse = HeartbeatResponse
@typing.final
class PollForErrorRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SOURCE_TASK_FIELD_NUMBER: builtins.int
@property
def source_task(self) -> global___CoordinatedTask: ...
def __init__(
self,
*,
source_task: global___CoordinatedTask | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["source_task", b"source_task"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["source_task", b"source_task"]) -> None: ...
global___PollForErrorRequest = PollForErrorRequest
@typing.final
class PollForErrorResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(
self,
) -> None: ...
global___PollForErrorResponse = PollForErrorResponse
@typing.final
class WaitForAllTasksRequest(google.protobuf.message.Message):
"""Request and response messages for waiting for all tasks."""