Bump tensorflow to 2.17.* (#12512)

This commit is contained in:
Avasam
2024-08-12 07:39:34 -04:00
committed by GitHub
parent dc0b63fd68
commit d3ca513ddb
15 changed files with 543 additions and 128 deletions

View File

@@ -7,7 +7,7 @@ set -euxo pipefail
# Whenever you update TENSORFLOW_VERSION here, version should be updated
# in stubs/tensorflow/METADATA.toml and vice-versa.
TENSORFLOW_VERSION=2.16.1
TENSORFLOW_VERSION=2.17.0
MYPY_PROTOBUF_VERSION=3.6.0
# brew install coreutils wget

View File

@@ -1,10 +1,10 @@
# Whenever you update version here, TENSORFLOW_VERSION should be updated
# in scripts/sync_tensorflow_protobuf_stubs.sh and vice-versa.
version = "2.16.*"
version = "2.17.*"
upstream_repository = "https://github.com/tensorflow/tensorflow"
# requires a version of numpy with a `py.typed` file
requires = ["numpy>=1.20", "types-protobuf", "types-requests"]
extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 25.1 on tensorflow==2.16.1 ."
extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 26.1 on tensorflow==2.17.0 ."
partial_stub = true
[tool.stubtest]

View File

@@ -546,7 +546,7 @@ class HloInstructionProto(google.protobuf.message.Message):
"""
@property
def dot_sparsity(self) -> tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor:
def dot_sparsity(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor]:
"""Sparsity descriptor for dot operation."""
def __init__(
@@ -622,9 +622,9 @@ class HloInstructionProto(google.protobuf.message.Message):
k: builtins.int | None = ...,
largest: builtins.bool | None = ...,
statistics_viz: tensorflow.compiler.xla.xla_data_pb2.StatisticsViz | None = ...,
dot_sparsity: tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor | None = ...,
dot_sparsity: collections.abc.Iterable[tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "k", b"k", "largest", b"largest", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "output_operand_aliasing", b"output_operand_aliasing", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index"]) -> typing.Literal["cross_program_prefetch_index"] | None: ...

View File

@@ -93,6 +93,7 @@ class CompilationLogEntry(google.protobuf.message.Message):
DURATION_FIELD_NUMBER: builtins.int
TASK_INDEX_FIELD_NUMBER: builtins.int
PASS_METRICS_FIELD_NUMBER: builtins.int
MODULE_IDS_FIELD_NUMBER: builtins.int
stage: global___CompilationLogEntry.CompilationStage.ValueType
"""Compilation stage recorded by this log entry."""
task_index: builtins.int
@@ -111,6 +112,10 @@ class CompilationLogEntry(google.protobuf.message.Message):
def pass_metrics(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PassMetrics]:
"""Pass specific metrics."""
@property
def module_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""IDs of modules on which the compilation stage was run."""
def __init__(
self,
*,
@@ -119,8 +124,9 @@ class CompilationLogEntry(google.protobuf.message.Message):
duration: google.protobuf.duration_pb2.Duration | None = ...,
task_index: builtins.int | None = ...,
pass_metrics: collections.abc.Iterable[global___PassMetrics] | None = ...,
module_ids: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["duration", b"duration", "timestamp", b"timestamp"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["duration", b"duration", "pass_metrics", b"pass_metrics", "stage", b"stage", "task_index", b"task_index", "timestamp", b"timestamp"]) -> None: ...
def ClearField(self, field_name: typing.Literal["duration", b"duration", "module_ids", b"module_ids", "pass_metrics", b"pass_metrics", "stage", b"stage", "task_index", b"task_index", "timestamp", b"timestamp"]) -> None: ...
global___CompilationLogEntry = CompilationLogEntry

View File

@@ -44,14 +44,16 @@ class _PrimitiveTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._
"""Invalid primitive type to serve as default."""
PRED: _PrimitiveType.ValueType # 1
"""Predicates are two-state booleans."""
S4: _PrimitiveType.ValueType # 21
S2: _PrimitiveType.ValueType # 26
"""Signed integral values of fixed width."""
S4: _PrimitiveType.ValueType # 21
S8: _PrimitiveType.ValueType # 2
S16: _PrimitiveType.ValueType # 3
S32: _PrimitiveType.ValueType # 4
S64: _PrimitiveType.ValueType # 5
U4: _PrimitiveType.ValueType # 22
U2: _PrimitiveType.ValueType # 27
"""Unsigned integral values of fixed width."""
U4: _PrimitiveType.ValueType # 22
U8: _PrimitiveType.ValueType # 6
U16: _PrimitiveType.ValueType # 7
U32: _PrimitiveType.ValueType # 8
@@ -149,14 +151,16 @@ PRIMITIVE_TYPE_INVALID: PrimitiveType.ValueType # 0
"""Invalid primitive type to serve as default."""
PRED: PrimitiveType.ValueType # 1
"""Predicates are two-state booleans."""
S4: PrimitiveType.ValueType # 21
S2: PrimitiveType.ValueType # 26
"""Signed integral values of fixed width."""
S4: PrimitiveType.ValueType # 21
S8: PrimitiveType.ValueType # 2
S16: PrimitiveType.ValueType # 3
S32: PrimitiveType.ValueType # 4
S64: PrimitiveType.ValueType # 5
U4: PrimitiveType.ValueType # 22
U2: PrimitiveType.ValueType # 27
"""Unsigned integral values of fixed width."""
U4: PrimitiveType.ValueType # 22
U8: PrimitiveType.ValueType # 6
U16: PrimitiveType.ValueType # 7
U32: PrimitiveType.ValueType # 8
@@ -508,8 +512,8 @@ global___PaddingConfig = PaddingConfig
@typing.final
class TileProto(google.protobuf.message.Message):
"""Describes a tile used in tiling-based layout. Refer to
g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
details about tiling-based layout.
g3doc/third_party/xla/docs/tiled_layout.md for details about tiling-based
layout.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -532,6 +536,33 @@ class TileProto(google.protobuf.message.Message):
global___TileProto = TileProto
@typing.final
class SplitConfigProto(google.protobuf.message.Message):
"""Describes how data should be split between different memories."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
DIMENSION_FIELD_NUMBER: builtins.int
SPLIT_INDICES_FIELD_NUMBER: builtins.int
dimension: builtins.int
"""The dimension that is split."""
@property
def split_indices(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""The indices where each split point occurs. For example, if the dimension
size is 1024, a split_indices value of {512} indicates a two-way split of
data through the middle.
"""
def __init__(
self,
*,
dimension: builtins.int | None = ...,
split_indices: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["dimension", b"dimension", "split_indices", b"split_indices"]) -> None: ...
global___SplitConfigProto = SplitConfigProto
@typing.final
class LayoutProto(google.protobuf.message.Message):
"""A layout describes how the array is placed in (1D) memory space. This
@@ -560,6 +591,7 @@ class LayoutProto(google.protobuf.message.Message):
POINTER_PRIMITIVE_TYPE_FIELD_NUMBER: builtins.int
PHYSICAL_SHAPE_FIELD_NUMBER: builtins.int
DYNAMIC_SHAPE_METADATA_PREFIX_BYTES_FIELD_NUMBER: builtins.int
SPLIT_CONFIGS_FIELD_NUMBER: builtins.int
tail_padding_alignment_in_elements: builtins.int
"""The shape is padded at the end to multiple of, in terms of number of
elements. This is useful when tiling does not bring the shape to certain
@@ -632,6 +664,12 @@ class LayoutProto(google.protobuf.message.Message):
a physical shape.
"""
@property
def split_configs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___SplitConfigProto]:
"""The split configurations which describe if/how the data is split between
different memories.
"""
def __init__(
self,
*,
@@ -647,9 +685,10 @@ class LayoutProto(google.protobuf.message.Message):
pointer_primitive_type: global___PrimitiveType.ValueType | None = ...,
physical_shape: global___ShapeProto | None = ...,
dynamic_shape_metadata_prefix_bytes: builtins.int | None = ...,
split_configs: collections.abc.Iterable[global___SplitConfigProto] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["physical_shape", b"physical_shape"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["dim_level_types", b"dim_level_types", "dim_ordered", b"dim_ordered", "dim_unique", b"dim_unique", "dynamic_shape_metadata_prefix_bytes", b"dynamic_shape_metadata_prefix_bytes", "element_size_in_bits", b"element_size_in_bits", "index_primitive_type", b"index_primitive_type", "memory_space", b"memory_space", "minor_to_major", b"minor_to_major", "physical_shape", b"physical_shape", "pointer_primitive_type", b"pointer_primitive_type", "tail_padding_alignment_in_elements", b"tail_padding_alignment_in_elements", "tiles", b"tiles"]) -> None: ...
def ClearField(self, field_name: typing.Literal["dim_level_types", b"dim_level_types", "dim_ordered", b"dim_ordered", "dim_unique", b"dim_unique", "dynamic_shape_metadata_prefix_bytes", b"dynamic_shape_metadata_prefix_bytes", "element_size_in_bits", b"element_size_in_bits", "index_primitive_type", b"index_primitive_type", "memory_space", b"memory_space", "minor_to_major", b"minor_to_major", "physical_shape", b"physical_shape", "pointer_primitive_type", b"pointer_primitive_type", "split_configs", b"split_configs", "tail_padding_alignment_in_elements", b"tail_padding_alignment_in_elements", "tiles", b"tiles"]) -> None: ...
global___LayoutProto = LayoutProto
@@ -815,8 +854,6 @@ class OpMetadata(google.protobuf.message.Message):
SOURCE_FILE_FIELD_NUMBER: builtins.int
SOURCE_LINE_FIELD_NUMBER: builtins.int
PROFILE_TYPE_FIELD_NUMBER: builtins.int
CREATION_PASS_ID_FIELD_NUMBER: builtins.int
LOGICAL_CREATION_PASS_ID_FIELD_NUMBER: builtins.int
SIZE_OF_GENERATED_CODE_IN_BYTES_FIELD_NUMBER: builtins.int
SIZE_OF_MEMORY_WORKING_SET_IN_BYTES_FIELD_NUMBER: builtins.int
PROFILE_INFO_FIELD_NUMBER: builtins.int
@@ -844,17 +881,6 @@ class OpMetadata(google.protobuf.message.Message):
e.g. it could be the file and line of user code that generated the op.
"""
source_line: builtins.int
creation_pass_id: builtins.int
"""HloPassMetadata.pass_id of the pass that created this HLO instruction
object. Should never be copied between HLO instructions. Zero if unset and
-1 if the instruction was created before HLO passes began.
"""
logical_creation_pass_id: builtins.int
"""HloPassMetadata.pass_id of the pass that created the logical functionality
that this HLO instruction represents. Should be copied between HLO
instructions that correspond across compilation passes. Zero if unset and
-1 if the instruction was created before HLO passes began.
"""
size_of_generated_code_in_bytes: builtins.int
"""The footprint of the generated code for the instruction."""
size_of_memory_working_set_in_bytes: builtins.int
@@ -891,8 +917,6 @@ class OpMetadata(google.protobuf.message.Message):
source_file: builtins.str | None = ...,
source_line: builtins.int | None = ...,
profile_type: collections.abc.Iterable[global___ProfileType.ValueType] | None = ...,
creation_pass_id: builtins.int | None = ...,
logical_creation_pass_id: builtins.int | None = ...,
size_of_generated_code_in_bytes: builtins.int | None = ...,
size_of_memory_working_set_in_bytes: builtins.int | None = ...,
profile_info: global___OpMetadata.ProfileInfo | None = ...,
@@ -901,7 +925,7 @@ class OpMetadata(google.protobuf.message.Message):
stack_frame_id: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["profile_info", b"profile_info"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["creation_pass_id", b"creation_pass_id", "deduplicated_name", b"deduplicated_name", "logical_creation_pass_id", b"logical_creation_pass_id", "op_name", b"op_name", "op_type", b"op_type", "preserve_layout", b"preserve_layout", "profile_info", b"profile_info", "profile_type", b"profile_type", "size_of_generated_code_in_bytes", b"size_of_generated_code_in_bytes", "size_of_memory_working_set_in_bytes", b"size_of_memory_working_set_in_bytes", "source_file", b"source_file", "source_line", b"source_line", "stack_frame_id", b"stack_frame_id"]) -> None: ...
def ClearField(self, field_name: typing.Literal["deduplicated_name", b"deduplicated_name", "op_name", b"op_name", "op_type", b"op_type", "preserve_layout", b"preserve_layout", "profile_info", b"profile_info", "profile_type", b"profile_type", "size_of_generated_code_in_bytes", b"size_of_generated_code_in_bytes", "size_of_memory_working_set_in_bytes", b"size_of_memory_working_set_in_bytes", "source_file", b"source_file", "source_line", b"source_line", "stack_frame_id", b"stack_frame_id"]) -> None: ...
global___OpMetadata = OpMetadata
@@ -918,6 +942,7 @@ class ExecutionProfile(google.protobuf.message.Message):
COMPUTE_AND_TRANSFER_TIME_NS_FIELD_NUMBER: builtins.int
EXECUTABLE_SIZE_IN_BYTES_FIELD_NUMBER: builtins.int
PROFILE_CACHE_HIT_FIELD_NUMBER: builtins.int
WARMUP_RUN_EXECUTED_FIELD_NUMBER: builtins.int
compilation_cache_hit: builtins.bool
"""Whether the executable was read from the compilation cache."""
compile_time_ms: builtins.int
@@ -944,6 +969,10 @@ class ExecutionProfile(google.protobuf.message.Message):
"""Whether this profile was drawn from a cache of profiles instead of from
execution on the hardware.
"""
warmup_run_executed: builtins.bool
"""Whether a warm-up run of the computation was executed before the
measured execution.
"""
def __init__(
self,
*,
@@ -954,8 +983,9 @@ class ExecutionProfile(google.protobuf.message.Message):
compute_and_transfer_time_ns: builtins.int | None = ...,
executable_size_in_bytes: builtins.int | None = ...,
profile_cache_hit: builtins.bool | None = ...,
warmup_run_executed: builtins.bool | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["compilation_cache_hit", b"compilation_cache_hit", "compile_time_ms", b"compile_time_ms", "compute_and_transfer_time_ns", b"compute_and_transfer_time_ns", "compute_cycle_count", b"compute_cycle_count", "compute_time_ns", b"compute_time_ns", "executable_size_in_bytes", b"executable_size_in_bytes", "profile_cache_hit", b"profile_cache_hit"]) -> None: ...
def ClearField(self, field_name: typing.Literal["compilation_cache_hit", b"compilation_cache_hit", "compile_time_ms", b"compile_time_ms", "compute_and_transfer_time_ns", b"compute_and_transfer_time_ns", "compute_cycle_count", b"compute_cycle_count", "compute_time_ns", b"compute_time_ns", "executable_size_in_bytes", b"executable_size_in_bytes", "profile_cache_hit", b"profile_cache_hit", "warmup_run_executed", b"warmup_run_executed"]) -> None: ...
global___ExecutionProfile = ExecutionProfile
@@ -1139,9 +1169,11 @@ class LiteralProto(google.protobuf.message.Message):
SHAPE_FIELD_NUMBER: builtins.int
PREDS_FIELD_NUMBER: builtins.int
S2S_FIELD_NUMBER: builtins.int
S4S_FIELD_NUMBER: builtins.int
U4S_FIELD_NUMBER: builtins.int
S8S_FIELD_NUMBER: builtins.int
U2S_FIELD_NUMBER: builtins.int
U4S_FIELD_NUMBER: builtins.int
U8S_FIELD_NUMBER: builtins.int
S32S_FIELD_NUMBER: builtins.int
S64S_FIELD_NUMBER: builtins.int
@@ -1162,9 +1194,11 @@ class LiteralProto(google.protobuf.message.Message):
F8E5M2FNUZS_FIELD_NUMBER: builtins.int
F8E4M3FNUZS_FIELD_NUMBER: builtins.int
SPARSE_INDICES_FIELD_NUMBER: builtins.int
s2s: builtins.bytes
s4s: builtins.bytes
u4s: builtins.bytes
s8s: builtins.bytes
u2s: builtins.bytes
u4s: builtins.bytes
u8s: builtins.bytes
f16s: builtins.bytes
"""The F16s, BF16s, U16s and S16s are encoded in little endian byte order"""
@@ -1204,16 +1238,18 @@ class LiteralProto(google.protobuf.message.Message):
def tuple_literals(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___LiteralProto]: ...
@property
def sparse_indices(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""Next = 26"""
"""Next = 28"""
def __init__(
self,
*,
shape: global___ShapeProto | None = ...,
preds: collections.abc.Iterable[builtins.bool] | None = ...,
s2s: builtins.bytes | None = ...,
s4s: builtins.bytes | None = ...,
u4s: builtins.bytes | None = ...,
s8s: builtins.bytes | None = ...,
u2s: builtins.bytes | None = ...,
u4s: builtins.bytes | None = ...,
u8s: builtins.bytes | None = ...,
s32s: collections.abc.Iterable[builtins.int] | None = ...,
s64s: collections.abc.Iterable[builtins.int] | None = ...,
@@ -1236,7 +1272,7 @@ class LiteralProto(google.protobuf.message.Message):
sparse_indices: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["shape", b"shape"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["bf16s", b"bf16s", "c128s", b"c128s", "c64s", b"c64s", "f16s", b"f16s", "f32s", b"f32s", "f64s", b"f64s", "f8e4m3b11fnuzs", b"f8e4m3b11fnuzs", "f8e4m3fns", b"f8e4m3fns", "f8e4m3fnuzs", b"f8e4m3fnuzs", "f8e5m2fnuzs", b"f8e5m2fnuzs", "f8e5m2s", b"f8e5m2s", "preds", b"preds", "s16s", b"s16s", "s32s", b"s32s", "s4s", b"s4s", "s64s", b"s64s", "s8s", b"s8s", "shape", b"shape", "sparse_indices", b"sparse_indices", "tuple_literals", b"tuple_literals", "u16s", b"u16s", "u32s", b"u32s", "u4s", b"u4s", "u64s", b"u64s", "u8s", b"u8s"]) -> None: ...
def ClearField(self, field_name: typing.Literal["bf16s", b"bf16s", "c128s", b"c128s", "c64s", b"c64s", "f16s", b"f16s", "f32s", b"f32s", "f64s", b"f64s", "f8e4m3b11fnuzs", b"f8e4m3b11fnuzs", "f8e4m3fns", b"f8e4m3fns", "f8e4m3fnuzs", b"f8e4m3fnuzs", "f8e5m2fnuzs", b"f8e5m2fnuzs", "f8e5m2s", b"f8e5m2s", "preds", b"preds", "s16s", b"s16s", "s2s", b"s2s", "s32s", b"s32s", "s4s", b"s4s", "s64s", b"s64s", "s8s", b"s8s", "shape", b"shape", "sparse_indices", b"sparse_indices", "tuple_literals", b"tuple_literals", "u16s", b"u16s", "u2s", b"u2s", "u32s", b"u32s", "u4s", b"u4s", "u64s", b"u64s", "u8s", b"u8s"]) -> None: ...
global___LiteralProto = LiteralProto
@@ -1992,15 +2028,113 @@ class PrecisionConfig(google.protobuf.message.Message):
PACKED_NIBBLE: PrecisionConfig.Precision.ValueType # 3
"""Each U8/S8 value in a tensor actually represents 2 nibble values."""
class _Algorithm:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _AlgorithmEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[PrecisionConfig._Algorithm.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
ALG_UNSET: PrecisionConfig._Algorithm.ValueType # 0
"""If the algorithm is `ALG_UNSET`, we will decide the algorithm based on
the operand_precision values (for now).
"""
ALG_DOT_ANY_F8_ANY_F8_F32: PrecisionConfig._Algorithm.ValueType # 1
"""The storage type can be any 8-bit floating point type."""
ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: PrecisionConfig._Algorithm.ValueType # 2
"""The storage type can be any 8-bit floating point type. Intermediate
results will not periodically be promoted to a higher precision. This
corresponds to CUBLASLT_MATMUL_DESC_FAST_ACCUM. Triton's
maxNumImpreciseAcc=32 setting may be similar.
"""
ALG_DOT_F16_F16_F16: PrecisionConfig._Algorithm.ValueType # 3
ALG_DOT_F16_F16_F32: PrecisionConfig._Algorithm.ValueType # 4
ALG_DOT_BF16_BF16_BF16: PrecisionConfig._Algorithm.ValueType # 5
ALG_DOT_BF16_BF16_F32: PrecisionConfig._Algorithm.ValueType # 6
ALG_DOT_BF16_BF16_F32_X3: PrecisionConfig._Algorithm.ValueType # 7
"""An algorithm which uses 3 BF16_BF16_F32 matmuls to achieve better
precision.
"""
ALG_DOT_BF16_BF16_F32_X6: PrecisionConfig._Algorithm.ValueType # 8
"""An algorithm which uses 6 BF16_BF16_F32 matmuls to achieve better
precision (similar to F32).
"""
ALG_DOT_TF32_TF32_F32: PrecisionConfig._Algorithm.ValueType # 9
ALG_DOT_TF32_TF32_F32_X3: PrecisionConfig._Algorithm.ValueType # 10
"""An algorithm which uses 3 TF32_TF32_F32 matmuls to achieve better
precision (similar to F32).
"""
ALG_DOT_F32_F32_F32: PrecisionConfig._Algorithm.ValueType # 11
ALG_DOT_F64_F64_F64: PrecisionConfig._Algorithm.ValueType # 12
class Algorithm(_Algorithm, metaclass=_AlgorithmEnumTypeWrapper):
"""The algorithm used to evaluate the instruction.
The naming convention for the dot instruction is
ALG_DOT_{A_TYPE}_{B_TYPE}_{ACCUM_TYPE}[_X{NUM_OPS}] where A_TYPE, B_TYPE
and ACCUM_TYPE correspond to the types in the "primitive dot operations"
(such as TensorCore operations) and NUM_OPS is the number of such
operations used per "primitive tile". When the NUM_OPS
field is skipped, it is assumed to be 1. The types mentioned in the name
are independent of the storage types.
In general ATYPE and BTYPE are the precisions that the LHS and RHS of the
operation are rounded to and ACCUMTYPE is the accumulation type. If a
backend does not support the given algorithm, an error is raised. The
Algorithm enum is intended to eventually replace the Precision enum.
"""
ALG_UNSET: PrecisionConfig.Algorithm.ValueType # 0
"""If the algorithm is `ALG_UNSET`, we will decide the algorithm based on
the operand_precision values (for now).
"""
ALG_DOT_ANY_F8_ANY_F8_F32: PrecisionConfig.Algorithm.ValueType # 1
"""The storage type can be any 8-bit floating point type."""
ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: PrecisionConfig.Algorithm.ValueType # 2
"""The storage type can be any 8-bit floating point type. Intermediate
results will not periodically be promoted to a higher precision. This
corresponds to CUBLASLT_MATMUL_DESC_FAST_ACCUM. Triton's
maxNumImpreciseAcc=32 setting may be similar.
"""
ALG_DOT_F16_F16_F16: PrecisionConfig.Algorithm.ValueType # 3
ALG_DOT_F16_F16_F32: PrecisionConfig.Algorithm.ValueType # 4
ALG_DOT_BF16_BF16_BF16: PrecisionConfig.Algorithm.ValueType # 5
ALG_DOT_BF16_BF16_F32: PrecisionConfig.Algorithm.ValueType # 6
ALG_DOT_BF16_BF16_F32_X3: PrecisionConfig.Algorithm.ValueType # 7
"""An algorithm which uses 3 BF16_BF16_F32 matmuls to achieve better
precision.
"""
ALG_DOT_BF16_BF16_F32_X6: PrecisionConfig.Algorithm.ValueType # 8
"""An algorithm which uses 6 BF16_BF16_F32 matmuls to achieve better
precision (similar to F32).
"""
ALG_DOT_TF32_TF32_F32: PrecisionConfig.Algorithm.ValueType # 9
ALG_DOT_TF32_TF32_F32_X3: PrecisionConfig.Algorithm.ValueType # 10
"""An algorithm which uses 3 TF32_TF32_F32 matmuls to achieve better
precision (similar to F32).
"""
ALG_DOT_F32_F32_F32: PrecisionConfig.Algorithm.ValueType # 11
ALG_DOT_F64_F64_F64: PrecisionConfig.Algorithm.ValueType # 12
OPERAND_PRECISION_FIELD_NUMBER: builtins.int
ALGORITHM_FIELD_NUMBER: builtins.int
algorithm: global___PrecisionConfig.Algorithm.ValueType
"""Currently doesn't do anything, but we plan to support it for dot and
possibly more instructions.
TODO(b/316147294): Support this on GPU and add this to StableHLO as well.
If this is set, then `operand_precision` should be set to DEFAULT and it
will be ignored.
"""
@property
def operand_precision(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[global___PrecisionConfig.Precision.ValueType]: ...
def __init__(
self,
*,
operand_precision: collections.abc.Iterable[global___PrecisionConfig.Precision.ValueType] | None = ...,
algorithm: global___PrecisionConfig.Algorithm.ValueType | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["operand_precision", b"operand_precision"]) -> None: ...
def ClearField(self, field_name: typing.Literal["algorithm", b"algorithm", "operand_precision", b"operand_precision"]) -> None: ...
global___PrecisionConfig = PrecisionConfig

File diff suppressed because one or more lines are too long

View File

@@ -93,7 +93,7 @@ global___ExternalStatePolicy = ExternalStatePolicy
@typing.final
class AutotuneOptions(google.protobuf.message.Message):
"""next: 5"""
"""next: 6"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -101,10 +101,12 @@ class AutotuneOptions(google.protobuf.message.Message):
CPU_BUDGET_FIELD_NUMBER: builtins.int
RAM_BUDGET_FIELD_NUMBER: builtins.int
AUTOTUNE_ALGORITHM_FIELD_NUMBER: builtins.int
INITIAL_PARALLELISM_FIELD_NUMBER: builtins.int
enabled: builtins.bool
cpu_budget: builtins.int
ram_budget: builtins.int
autotune_algorithm: tensorflow.core.framework.model_pb2.AutotuneAlgorithm.ValueType
initial_parallelism: builtins.int
def __init__(
self,
*,
@@ -112,9 +114,10 @@ class AutotuneOptions(google.protobuf.message.Message):
cpu_budget: builtins.int | None = ...,
ram_budget: builtins.int | None = ...,
autotune_algorithm: tensorflow.core.framework.model_pb2.AutotuneAlgorithm.ValueType | None = ...,
initial_parallelism: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["autotune_algorithm", b"autotune_algorithm", "cpu_budget", b"cpu_budget", "enabled", b"enabled", "optional_autotune_algorithm", b"optional_autotune_algorithm", "optional_cpu_budget", b"optional_cpu_budget", "optional_enabled", b"optional_enabled", "optional_ram_budget", b"optional_ram_budget", "ram_budget", b"ram_budget"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["autotune_algorithm", b"autotune_algorithm", "cpu_budget", b"cpu_budget", "enabled", b"enabled", "optional_autotune_algorithm", b"optional_autotune_algorithm", "optional_cpu_budget", b"optional_cpu_budget", "optional_enabled", b"optional_enabled", "optional_ram_budget", b"optional_ram_budget", "ram_budget", b"ram_budget"]) -> None: ...
def HasField(self, field_name: typing.Literal["autotune_algorithm", b"autotune_algorithm", "cpu_budget", b"cpu_budget", "enabled", b"enabled", "initial_parallelism", b"initial_parallelism", "optional_autotune_algorithm", b"optional_autotune_algorithm", "optional_cpu_budget", b"optional_cpu_budget", "optional_enabled", b"optional_enabled", "optional_initial_parallelism", b"optional_initial_parallelism", "optional_ram_budget", b"optional_ram_budget", "ram_budget", b"ram_budget"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["autotune_algorithm", b"autotune_algorithm", "cpu_budget", b"cpu_budget", "enabled", b"enabled", "initial_parallelism", b"initial_parallelism", "optional_autotune_algorithm", b"optional_autotune_algorithm", "optional_cpu_budget", b"optional_cpu_budget", "optional_enabled", b"optional_enabled", "optional_initial_parallelism", b"optional_initial_parallelism", "optional_ram_budget", b"optional_ram_budget", "ram_budget", b"ram_budget"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_autotune_algorithm", b"optional_autotune_algorithm"]) -> typing.Literal["autotune_algorithm"] | None: ...
@typing.overload
@@ -122,6 +125,8 @@ class AutotuneOptions(google.protobuf.message.Message):
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_enabled", b"optional_enabled"]) -> typing.Literal["enabled"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_initial_parallelism", b"optional_initial_parallelism"]) -> typing.Literal["initial_parallelism"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_ram_budget", b"optional_ram_budget"]) -> typing.Literal["ram_budget"] | None: ...
global___AutotuneOptions = AutotuneOptions

View File

@@ -33,8 +33,11 @@ class ResourceHandleProto(google.protobuf.message.Message):
DTYPE_FIELD_NUMBER: builtins.int
SHAPE_FIELD_NUMBER: builtins.int
dtype: tensorflow.core.framework.types_pb2.DataType.ValueType
"""Data type of the tensor."""
@property
def shape(self) -> tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto: ...
def shape(self) -> tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto:
"""Shape of the tensor."""
def __init__(
self,
*,

View File

@@ -41,6 +41,7 @@ class TensorProto(google.protobuf.message.Message):
UINT64_VAL_FIELD_NUMBER: builtins.int
FLOAT8_VAL_FIELD_NUMBER: builtins.int
dtype: tensorflow.core.framework.types_pb2.DataType.ValueType
"""Data type of the tensor."""
version_number: builtins.int
"""Only one of the representations below is set, one of "tensor_contents" and
the "xxx_val" attributes. We are not using oneof because as oneofs cannot

View File

@@ -90,6 +90,49 @@ class GPUOptions(google.protobuf.message.Message):
) -> None: ...
def ClearField(self, field_name: typing.Literal["device_ordinal", b"device_ordinal", "memory_limit_mb", b"memory_limit_mb", "priority", b"priority"]) -> None: ...
@typing.final
class StreamMergeOptions(google.protobuf.message.Message):
"""Whether to merge data transfer streams into the compute stream in the
same stream group. Stream merging helps reduce the overhead caused by
stream synchronization, especially when data transfers are frequent. For
example, setting "merge_host_to_device_stream = true" will make the
compute stream responsible for both computation and host to device memory
copy.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
MERGE_HOST_TO_DEVICE_STREAM_FIELD_NUMBER: builtins.int
MERGE_DEVICE_TO_HOST_STREAM_FIELD_NUMBER: builtins.int
MERGE_DEVICE_TO_DEVICE_STREAM_FIELD_NUMBER: builtins.int
merge_host_to_device_stream: builtins.bool
"""If true, the compute stream will be used for host_to_device copy as
well. It's no longer necessary to record an event before the copy to
let the copy stream wait for the compute stream to finish. There is
also no need to wait for the copy to complete before executing the
callback function.
"""
merge_device_to_host_stream: builtins.bool
"""If true, the compute stream will be used for device_to_host copy as
well. It's no longer necessary to record an event before the copy to
let the copy stream wait for the compute stream to finish.
"""
merge_device_to_device_stream: builtins.bool
"""If true, the compute stream will be used for device_to_device copy as
well. It's no longer necessary to record an event before the copy to
let the copy stream wait for the compute stream of the sending device
to finish. There is also no need to wait for the compute stream of the
receiving device to finish if the copy is within the same device.
"""
def __init__(
self,
*,
merge_host_to_device_stream: builtins.bool | None = ...,
merge_device_to_host_stream: builtins.bool | None = ...,
merge_device_to_device_stream: builtins.bool | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["merge_device_to_device_stream", b"merge_device_to_device_stream", "merge_device_to_host_stream", b"merge_device_to_host_stream", "merge_host_to_device_stream", b"merge_host_to_device_stream"]) -> None: ...
VIRTUAL_DEVICES_FIELD_NUMBER: builtins.int
NUM_VIRTUAL_DEVICES_PER_GPU_FIELD_NUMBER: builtins.int
USE_UNIFIED_MEMORY_FIELD_NUMBER: builtins.int
@@ -105,6 +148,9 @@ class GPUOptions(google.protobuf.message.Message):
GPU_HOST_MEM_LIMIT_IN_MB_FIELD_NUMBER: builtins.int
GPU_HOST_MEM_DISALLOW_GROWTH_FIELD_NUMBER: builtins.int
GPU_SYSTEM_MEMORY_SIZE_IN_MB_FIELD_NUMBER: builtins.int
POPULATE_PJRT_GPU_CLIENT_CREATION_INFO_FIELD_NUMBER: builtins.int
NODE_ID_FIELD_NUMBER: builtins.int
STREAM_MERGE_OPTIONS_FIELD_NUMBER: builtins.int
num_virtual_devices_per_gpu: builtins.int
"""The number of virtual devices to create on each visible GPU. The
available memory will be split equally among all virtual devices. If the
@@ -200,6 +246,14 @@ class GPUOptions(google.protobuf.message.Message):
system memory size for better resource estimation of multi-tenancy(one
gpu with multiple model) use case.
"""
populate_pjrt_gpu_client_creation_info: builtins.bool
"""If true, save information needed for created a PjRt GPU client for
creating a client with remote devices.
"""
node_id: builtins.int
"""node_id for use when creating a PjRt GPU client with remote devices,
which enumerates jobs*tasks from a ServerDef.
"""
@property
def virtual_devices(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___GPUOptions.Experimental.VirtualDevices]:
"""The multi virtual device settings. If empty (not set), it will create
@@ -242,6 +296,8 @@ class GPUOptions(google.protobuf.message.Message):
result in undefined behavior.
"""
@property
def stream_merge_options(self) -> global___GPUOptions.Experimental.StreamMergeOptions: ...
def __init__(
self,
*,
@@ -260,8 +316,12 @@ class GPUOptions(google.protobuf.message.Message):
gpu_host_mem_limit_in_mb: builtins.float | None = ...,
gpu_host_mem_disallow_growth: builtins.bool | None = ...,
gpu_system_memory_size_in_mb: builtins.int | None = ...,
populate_pjrt_gpu_client_creation_info: builtins.bool | None = ...,
node_id: builtins.int | None = ...,
stream_merge_options: global___GPUOptions.Experimental.StreamMergeOptions | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["collective_ring_order", b"collective_ring_order", "disallow_retry_on_allocation_failure", b"disallow_retry_on_allocation_failure", "gpu_host_mem_disallow_growth", b"gpu_host_mem_disallow_growth", "gpu_host_mem_limit_in_mb", b"gpu_host_mem_limit_in_mb", "gpu_system_memory_size_in_mb", b"gpu_system_memory_size_in_mb", "internal_fragmentation_fraction", b"internal_fragmentation_fraction", "kernel_tracker_max_bytes", b"kernel_tracker_max_bytes", "kernel_tracker_max_interval", b"kernel_tracker_max_interval", "kernel_tracker_max_pending", b"kernel_tracker_max_pending", "num_dev_to_dev_copy_streams", b"num_dev_to_dev_copy_streams", "num_virtual_devices_per_gpu", b"num_virtual_devices_per_gpu", "timestamped_allocator", b"timestamped_allocator", "use_cuda_malloc_async", b"use_cuda_malloc_async", "use_unified_memory", b"use_unified_memory", "virtual_devices", b"virtual_devices"]) -> None: ...
def HasField(self, field_name: typing.Literal["stream_merge_options", b"stream_merge_options"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["collective_ring_order", b"collective_ring_order", "disallow_retry_on_allocation_failure", b"disallow_retry_on_allocation_failure", "gpu_host_mem_disallow_growth", b"gpu_host_mem_disallow_growth", "gpu_host_mem_limit_in_mb", b"gpu_host_mem_limit_in_mb", "gpu_system_memory_size_in_mb", b"gpu_system_memory_size_in_mb", "internal_fragmentation_fraction", b"internal_fragmentation_fraction", "kernel_tracker_max_bytes", b"kernel_tracker_max_bytes", "kernel_tracker_max_interval", b"kernel_tracker_max_interval", "kernel_tracker_max_pending", b"kernel_tracker_max_pending", "node_id", b"node_id", "num_dev_to_dev_copy_streams", b"num_dev_to_dev_copy_streams", "num_virtual_devices_per_gpu", b"num_virtual_devices_per_gpu", "populate_pjrt_gpu_client_creation_info", b"populate_pjrt_gpu_client_creation_info", "stream_merge_options", b"stream_merge_options", "timestamped_allocator", b"timestamped_allocator", "use_cuda_malloc_async", b"use_cuda_malloc_async", "use_unified_memory", b"use_unified_memory", "virtual_devices", b"virtual_devices"]) -> None: ...
PER_PROCESS_GPU_MEMORY_FRACTION_FIELD_NUMBER: builtins.int
ALLOW_GROWTH_FIELD_NUMBER: builtins.int

View File

@@ -333,8 +333,8 @@ class RewriterConfig(google.protobuf.message.Message):
(default is ON).
"""
auto_mixed_precision: global___RewriterConfig.Toggle.ValueType
"""Optimize data types for CUDA (default is OFF).
This will try to use float16 on GPU which is faster.
"""Optimize data types for CUDA/oneDNN (default is OFF).
This will try to use float16 on GPU/CPU which is faster.
Note that this can change the numerical stability of the graph and may
require the use of loss scaling to maintain model convergence.
"""

View File

@@ -114,6 +114,7 @@ class Dataset(ABC, Generic[_T1]):
element_spec: ContainerGeneric[tf.TypeSpec[Any]] | None = None,
compression: _CompressionTypes = None,
reader_func: Callable[[Dataset[Dataset[Any]]], Dataset[Any]] | None = None,
wait: bool = False,
) -> Dataset[Any]: ...
# PEP 646 could be used here for a more precise type when better supported.
def map(
@@ -196,7 +197,7 @@ class Dataset(ABC, Generic[_T1]):
self,
buffer_size: ScalarTensorCompatible,
seed: int | None = None,
reshuffle_each_iteration: bool | None = None,
reshuffle_each_iteration: bool = True,
name: str | None = None,
) -> Dataset[_T1]: ...
def skip(self, count: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1]: ...

View File

@@ -47,15 +47,19 @@ class SaveOptions:
"experimental_image_format",
"experimental_skip_saver",
"experimental_sharding_callback",
"extra_tags",
)
namespace_whitelist: list[str]
save_debug_info: bool
function_aliases: dict[str, PolymorphicFunction[..., object]]
experimental_debug_stripper: bool
experimental_io_device: str
experimental_variable_policy: VariablePolicy
experimental_custom_gradients: bool
experimental_image_format: bool
experimental_skip_saver: bool
experimental_sharding_callback: Incomplete | None
extra_tags: Incomplete | None
def __init__(
self,
namespace_whitelist: list[str] | None = None,
@@ -68,6 +72,7 @@ class SaveOptions:
experimental_image_format: bool = False,
experimental_skip_saver: bool = False,
experimental_sharding_callback: Incomplete | None = None,
extra_tags: Incomplete | None = None,
) -> None: ...
def contains_saved_model(export_dir: str | Path) -> bool: ...

View File

@@ -473,15 +473,18 @@ class InsertKeyValueRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KV_FIELD_NUMBER: builtins.int
ALLOW_OVERWRITE_FIELD_NUMBER: builtins.int
allow_overwrite: builtins.bool
@property
def kv(self) -> global___KeyValueEntry: ...
def __init__(
self,
*,
kv: global___KeyValueEntry | None = ...,
allow_overwrite: builtins.bool | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["kv", b"kv"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["kv", b"kv"]) -> None: ...
def ClearField(self, field_name: typing.Literal["allow_overwrite", b"allow_overwrite", "kv", b"kv"]) -> None: ...
global___InsertKeyValueRequest = InsertKeyValueRequest

View File

@@ -267,6 +267,24 @@ FORWARD_BIAS_ACTIVATION: ConvolutionKind.ValueType # 4
FORWARD_GRAPH: ConvolutionKind.ValueType # 5
global___ConvolutionKind = ConvolutionKind
class _NormKind:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _NormKindEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_NormKind.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
LAYER_FWD_INFER: _NormKind.ValueType # 0
LAYER_FWD_TRAIN: _NormKind.ValueType # 1
LAYER_BWD: _NormKind.ValueType # 2
class NormKind(_NormKind, metaclass=_NormKindEnumTypeWrapper):
"""NormKind kind"""
LAYER_FWD_INFER: NormKind.ValueType # 0
LAYER_FWD_TRAIN: NormKind.ValueType # 1
LAYER_BWD: NormKind.ValueType # 2
global___NormKind = NormKind
class _FusedMHAKind:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
@@ -285,6 +303,28 @@ BMM1_OUTPUT_INPUT_TYPE: FusedMHAKind.ValueType # 1
BMM1_OUTPUT_FLOAT: FusedMHAKind.ValueType # 2
global___FusedMHAKind = FusedMHAKind
class _FMHAMaskKind:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _FMHAMaskKindEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_FMHAMaskKind.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
NO_MASK: _FMHAMaskKind.ValueType # 0
PADDING: _FMHAMaskKind.ValueType # 1
CAUSAL: _FMHAMaskKind.ValueType # 2
PADDING_CAUSAL: _FMHAMaskKind.ValueType # 3
ALIBI: _FMHAMaskKind.ValueType # 4
class FMHAMaskKind(_FMHAMaskKind, metaclass=_FMHAMaskKindEnumTypeWrapper):
"""FusedMHAMaskKind kind"""
NO_MASK: FMHAMaskKind.ValueType # 0
PADDING: FMHAMaskKind.ValueType # 1
CAUSAL: FMHAMaskKind.ValueType # 2
PADDING_CAUSAL: FMHAMaskKind.ValueType # 3
ALIBI: FMHAMaskKind.ValueType # 4
global___FMHAMaskKind = FMHAMaskKind
@typing.final
class TensorDescriptorProto(google.protobuf.message.Message):
"""Generic tensor representation."""