Bump tensorflow to 2.17.* (#12512)
@@ -7,7 +7,7 @@ set -euxo pipefail
# Whenever you update TENSORFLOW_VERSION here, version should be updated
# in stubs/tensorflow/METADATA.toml and vice-versa.
-TENSORFLOW_VERSION=2.16.1
+TENSORFLOW_VERSION=2.17.0
MYPY_PROTOBUF_VERSION=3.6.0

# brew install coreutils wget
@@ -1,10 +1,10 @@
# Whenever you update version here, TENSORFLOW_VERSION should be updated
# in scripts/sync_tensorflow_protobuf_stubs.sh and vice-versa.
-version = "2.16.*"
+version = "2.17.*"
upstream_repository = "https://github.com/tensorflow/tensorflow"
# requires a version of numpy with a `py.typed` file
requires = ["numpy>=1.20", "types-protobuf", "types-requests"]
-extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 25.1 on tensorflow==2.16.1 ."
+extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 26.1 on tensorflow==2.17.0 ."
partial_stub = true

[tool.stubtest]
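For downstream consumers, the bumped metadata above is what the published stub package tracks; a minimal sketch for checking the installed stub release, assuming the usual types-tensorflow distribution name used for typeshed's tensorflow stubs:

    from importlib import metadata

    # Assumption: the tensorflow stubs are installed from PyPI as "types-tensorflow".
    # After this bump, a 2.17.* stub release is the expected match for tensorflow==2.17.*.
    print(metadata.version("types-tensorflow"))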
@@ -546,7 +546,7 @@ class HloInstructionProto(google.protobuf.message.Message):
"""

@property
-def dot_sparsity(self) -> tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor:
+def dot_sparsity(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor]:
"""Sparsity descriptor for dot operation."""

def __init__(
@@ -622,9 +622,9 @@ class HloInstructionProto(google.protobuf.message.Message):
k: builtins.int | None = ...,
largest: builtins.bool | None = ...,
statistics_viz: tensorflow.compiler.xla.xla_data_pb2.StatisticsViz | None = ...,
-dot_sparsity: tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor | None = ...,
+dot_sparsity: collections.abc.Iterable[tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "k", b"k", "largest", b"largest", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "output_operand_aliasing", b"output_operand_aliasing", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index"]) -> typing.Literal["cross_program_prefetch_index"] | None: ...
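Since dot_sparsity is now a repeated field, the generated constructor accepts an iterable of SparsityDescriptor messages; a minimal sketch, assuming the generated hlo_pb2 and xla_data_pb2 modules are importable from an installed tensorflow at the paths used by these stubs:

    # Sketch only: import paths mirror the stub layout and are assumed to
    # exist in the installed tensorflow wheel.
    from tensorflow.compiler.xla import xla_data_pb2
    from tensorflow.compiler.xla.service import hlo_pb2

    instr = hlo_pb2.HloInstructionProto(
        name="dot",
        dot_sparsity=[xla_data_pb2.SparsityDescriptor()],  # repeated, so an iterable
    )
    print(len(instr.dot_sparsity))  # 1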
@@ -93,6 +93,7 @@ class CompilationLogEntry(google.protobuf.message.Message):
DURATION_FIELD_NUMBER: builtins.int
TASK_INDEX_FIELD_NUMBER: builtins.int
PASS_METRICS_FIELD_NUMBER: builtins.int
+MODULE_IDS_FIELD_NUMBER: builtins.int
stage: global___CompilationLogEntry.CompilationStage.ValueType
"""Compilation stage recorded by this log entry."""
task_index: builtins.int
@@ -111,6 +112,10 @@ class CompilationLogEntry(google.protobuf.message.Message):
def pass_metrics(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PassMetrics]:
"""Pass specific metrics."""

+@property
+def module_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
+"""IDs of modules on which the compilation stage was run."""
+
def __init__(
self,
*,
@@ -119,8 +124,9 @@ class CompilationLogEntry(google.protobuf.message.Message):
duration: google.protobuf.duration_pb2.Duration | None = ...,
task_index: builtins.int | None = ...,
pass_metrics: collections.abc.Iterable[global___PassMetrics] | None = ...,
+module_ids: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["duration", b"duration", "timestamp", b"timestamp"]) -> builtins.bool: ...
-def ClearField(self, field_name: typing.Literal["duration", b"duration", "pass_metrics", b"pass_metrics", "stage", b"stage", "task_index", b"task_index", "timestamp", b"timestamp"]) -> None: ...
+def ClearField(self, field_name: typing.Literal["duration", b"duration", "module_ids", b"module_ids", "pass_metrics", b"pass_metrics", "stage", b"stage", "task_index", b"task_index", "timestamp", b"timestamp"]) -> None: ...

global___CompilationLogEntry = CompilationLogEntry
@@ -44,14 +44,16 @@ class _PrimitiveTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._
"""Invalid primitive type to serve as default."""
PRED: _PrimitiveType.ValueType # 1
"""Predicates are two-state booleans."""
-S4: _PrimitiveType.ValueType # 21
+S2: _PrimitiveType.ValueType # 26
"""Signed integral values of fixed width."""
+S4: _PrimitiveType.ValueType # 21
S8: _PrimitiveType.ValueType # 2
S16: _PrimitiveType.ValueType # 3
S32: _PrimitiveType.ValueType # 4
S64: _PrimitiveType.ValueType # 5
-U4: _PrimitiveType.ValueType # 22
+U2: _PrimitiveType.ValueType # 27
"""Unsigned integral values of fixed width."""
+U4: _PrimitiveType.ValueType # 22
U8: _PrimitiveType.ValueType # 6
U16: _PrimitiveType.ValueType # 7
U32: _PrimitiveType.ValueType # 8
@@ -149,14 +151,16 @@ PRIMITIVE_TYPE_INVALID: PrimitiveType.ValueType # 0
"""Invalid primitive type to serve as default."""
PRED: PrimitiveType.ValueType # 1
"""Predicates are two-state booleans."""
-S4: PrimitiveType.ValueType # 21
+S2: PrimitiveType.ValueType # 26
"""Signed integral values of fixed width."""
+S4: PrimitiveType.ValueType # 21
S8: PrimitiveType.ValueType # 2
S16: PrimitiveType.ValueType # 3
S32: PrimitiveType.ValueType # 4
S64: PrimitiveType.ValueType # 5
-U4: PrimitiveType.ValueType # 22
+U2: PrimitiveType.ValueType # 27
"""Unsigned integral values of fixed width."""
+U4: PrimitiveType.ValueType # 22
U8: PrimitiveType.ValueType # 6
U16: PrimitiveType.ValueType # 7
U32: PrimitiveType.ValueType # 8
@@ -508,8 +512,8 @@ global___PaddingConfig = PaddingConfig
@typing.final
class TileProto(google.protobuf.message.Message):
"""Describes a tile used in tiling-based layout. Refer to
-g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
-details about tiling-based layout.
+g3doc/third_party/xla/docs/tiled_layout.md for details about tiling-based
+layout.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -532,6 +536,33 @@ class TileProto(google.protobuf.message.Message):
global___TileProto = TileProto

+@typing.final
+class SplitConfigProto(google.protobuf.message.Message):
+"""Describes how data should be split between different memories."""
+
+DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+DIMENSION_FIELD_NUMBER: builtins.int
+SPLIT_INDICES_FIELD_NUMBER: builtins.int
+dimension: builtins.int
+"""The dimension that is split."""
+@property
+def split_indices(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
+"""The indices where each split point occurs. For example, if the dimension
+size is 1024, a split_indices value of {512} indicates a two-way split of
+data through the middle.
+"""
+
+def __init__(
+self,
+*,
+dimension: builtins.int | None = ...,
+split_indices: collections.abc.Iterable[builtins.int] | None = ...,
+) -> None: ...
+def ClearField(self, field_name: typing.Literal["dimension", b"dimension", "split_indices", b"split_indices"]) -> None: ...
+
+global___SplitConfigProto = SplitConfigProto
+
@typing.final
class LayoutProto(google.protobuf.message.Message):
"""A layout describes how the array is placed in (1D) memory space. This
@@ -560,6 +591,7 @@ class LayoutProto(google.protobuf.message.Message):
POINTER_PRIMITIVE_TYPE_FIELD_NUMBER: builtins.int
PHYSICAL_SHAPE_FIELD_NUMBER: builtins.int
DYNAMIC_SHAPE_METADATA_PREFIX_BYTES_FIELD_NUMBER: builtins.int
+SPLIT_CONFIGS_FIELD_NUMBER: builtins.int
tail_padding_alignment_in_elements: builtins.int
"""The shape is padded at the end to multiple of, in terms of number of
elements. This is useful when tiling does not bring the shape to certain
@@ -632,6 +664,12 @@ class LayoutProto(google.protobuf.message.Message):
a physical shape.
"""

+@property
+def split_configs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___SplitConfigProto]:
+"""The split configurations which describe if/how the data is split between
+different memories.
+"""
+
def __init__(
self,
*,
@@ -647,9 +685,10 @@ class LayoutProto(google.protobuf.message.Message):
pointer_primitive_type: global___PrimitiveType.ValueType | None = ...,
physical_shape: global___ShapeProto | None = ...,
dynamic_shape_metadata_prefix_bytes: builtins.int | None = ...,
+split_configs: collections.abc.Iterable[global___SplitConfigProto] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["physical_shape", b"physical_shape"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["dim_level_types", b"dim_level_types", "dim_ordered", b"dim_ordered", "dim_unique", b"dim_unique", "dynamic_shape_metadata_prefix_bytes", b"dynamic_shape_metadata_prefix_bytes", "element_size_in_bits", b"element_size_in_bits", "index_primitive_type", b"index_primitive_type", "memory_space", b"memory_space", "minor_to_major", b"minor_to_major", "physical_shape", b"physical_shape", "pointer_primitive_type", b"pointer_primitive_type", "tail_padding_alignment_in_elements", b"tail_padding_alignment_in_elements", "tiles", b"tiles"]) -> None: ...
def ClearField(self, field_name: typing.Literal["dim_level_types", b"dim_level_types", "dim_ordered", b"dim_ordered", "dim_unique", b"dim_unique", "dynamic_shape_metadata_prefix_bytes", b"dynamic_shape_metadata_prefix_bytes", "element_size_in_bits", b"element_size_in_bits", "index_primitive_type", b"index_primitive_type", "memory_space", b"memory_space", "minor_to_major", b"minor_to_major", "physical_shape", b"physical_shape", "pointer_primitive_type", b"pointer_primitive_type", "split_configs", b"split_configs", "tail_padding_alignment_in_elements", b"tail_padding_alignment_in_elements", "tiles", b"tiles"]) -> None: ...

global___LayoutProto = LayoutProto
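The new SplitConfigProto carries the split example from its own docstring; a minimal sketch of building one and attaching it to a LayoutProto, assuming the generated xla_data_pb2 module is importable from the installed tensorflow package:

    # Sketch only: assumes tensorflow ships the generated xla_data_pb2 module.
    from tensorflow.compiler.xla import xla_data_pb2

    # A dimension of size 1024 split two ways through the middle, as in the docstring.
    split = xla_data_pb2.SplitConfigProto(dimension=0, split_indices=[512])

    layout = xla_data_pb2.LayoutProto()
    layout.split_configs.append(split)  # split_configs is the new repeated field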
@@ -815,8 +854,6 @@ class OpMetadata(google.protobuf.message.Message):
SOURCE_FILE_FIELD_NUMBER: builtins.int
SOURCE_LINE_FIELD_NUMBER: builtins.int
PROFILE_TYPE_FIELD_NUMBER: builtins.int
-CREATION_PASS_ID_FIELD_NUMBER: builtins.int
-LOGICAL_CREATION_PASS_ID_FIELD_NUMBER: builtins.int
SIZE_OF_GENERATED_CODE_IN_BYTES_FIELD_NUMBER: builtins.int
SIZE_OF_MEMORY_WORKING_SET_IN_BYTES_FIELD_NUMBER: builtins.int
PROFILE_INFO_FIELD_NUMBER: builtins.int
@@ -844,17 +881,6 @@ class OpMetadata(google.protobuf.message.Message):
e.g. it could be the file and line of user code that generated the op.
"""
source_line: builtins.int
-creation_pass_id: builtins.int
-"""HloPassMetadata.pass_id of the pass that created this HLO instruction
-object. Should never be copied between HLO instructions. Zero if unset and
--1 if the instruction was created before HLO passes began.
-"""
-logical_creation_pass_id: builtins.int
-"""HloPassMetadata.pass_id of the pass that created the logical functionality
-that this HLO instruction represents. Should be copied between HLO
-instructions that correspond across compilation passes. Zero if unset and
--1 if the instruction was created before HLO passes began.
-"""
size_of_generated_code_in_bytes: builtins.int
"""The footprint of the generated code for the instruction."""
size_of_memory_working_set_in_bytes: builtins.int
@@ -891,8 +917,6 @@ class OpMetadata(google.protobuf.message.Message):
source_file: builtins.str | None = ...,
source_line: builtins.int | None = ...,
profile_type: collections.abc.Iterable[global___ProfileType.ValueType] | None = ...,
-creation_pass_id: builtins.int | None = ...,
-logical_creation_pass_id: builtins.int | None = ...,
size_of_generated_code_in_bytes: builtins.int | None = ...,
size_of_memory_working_set_in_bytes: builtins.int | None = ...,
profile_info: global___OpMetadata.ProfileInfo | None = ...,
@@ -901,7 +925,7 @@ class OpMetadata(google.protobuf.message.Message):
stack_frame_id: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["profile_info", b"profile_info"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["creation_pass_id", b"creation_pass_id", "deduplicated_name", b"deduplicated_name", "logical_creation_pass_id", b"logical_creation_pass_id", "op_name", b"op_name", "op_type", b"op_type", "preserve_layout", b"preserve_layout", "profile_info", b"profile_info", "profile_type", b"profile_type", "size_of_generated_code_in_bytes", b"size_of_generated_code_in_bytes", "size_of_memory_working_set_in_bytes", b"size_of_memory_working_set_in_bytes", "source_file", b"source_file", "source_line", b"source_line", "stack_frame_id", b"stack_frame_id"]) -> None: ...
def ClearField(self, field_name: typing.Literal["deduplicated_name", b"deduplicated_name", "op_name", b"op_name", "op_type", b"op_type", "preserve_layout", b"preserve_layout", "profile_info", b"profile_info", "profile_type", b"profile_type", "size_of_generated_code_in_bytes", b"size_of_generated_code_in_bytes", "size_of_memory_working_set_in_bytes", b"size_of_memory_working_set_in_bytes", "source_file", b"source_file", "source_line", b"source_line", "stack_frame_id", b"stack_frame_id"]) -> None: ...

global___OpMetadata = OpMetadata
@@ -918,6 +942,7 @@ class ExecutionProfile(google.protobuf.message.Message):
COMPUTE_AND_TRANSFER_TIME_NS_FIELD_NUMBER: builtins.int
EXECUTABLE_SIZE_IN_BYTES_FIELD_NUMBER: builtins.int
PROFILE_CACHE_HIT_FIELD_NUMBER: builtins.int
+WARMUP_RUN_EXECUTED_FIELD_NUMBER: builtins.int
compilation_cache_hit: builtins.bool
"""Whether the executable was read from the compilation cache."""
compile_time_ms: builtins.int
@@ -944,6 +969,10 @@ class ExecutionProfile(google.protobuf.message.Message):
"""Whether this profile was drawn from a cache of profiles instead of from
execution on the hardware.
"""
+warmup_run_executed: builtins.bool
+"""Whether a warm-up run of the computation was executed before the
+measured execution.
+"""
def __init__(
self,
*,
@@ -954,8 +983,9 @@ class ExecutionProfile(google.protobuf.message.Message):
compute_and_transfer_time_ns: builtins.int | None = ...,
executable_size_in_bytes: builtins.int | None = ...,
profile_cache_hit: builtins.bool | None = ...,
+warmup_run_executed: builtins.bool | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["compilation_cache_hit", b"compilation_cache_hit", "compile_time_ms", b"compile_time_ms", "compute_and_transfer_time_ns", b"compute_and_transfer_time_ns", "compute_cycle_count", b"compute_cycle_count", "compute_time_ns", b"compute_time_ns", "executable_size_in_bytes", b"executable_size_in_bytes", "profile_cache_hit", b"profile_cache_hit"]) -> None: ...
def ClearField(self, field_name: typing.Literal["compilation_cache_hit", b"compilation_cache_hit", "compile_time_ms", b"compile_time_ms", "compute_and_transfer_time_ns", b"compute_and_transfer_time_ns", "compute_cycle_count", b"compute_cycle_count", "compute_time_ns", b"compute_time_ns", "executable_size_in_bytes", b"executable_size_in_bytes", "profile_cache_hit", b"profile_cache_hit", "warmup_run_executed", b"warmup_run_executed"]) -> None: ...

global___ExecutionProfile = ExecutionProfile
@@ -1139,9 +1169,11 @@ class LiteralProto(google.protobuf.message.Message):
SHAPE_FIELD_NUMBER: builtins.int
PREDS_FIELD_NUMBER: builtins.int
+S2S_FIELD_NUMBER: builtins.int
S4S_FIELD_NUMBER: builtins.int
-U4S_FIELD_NUMBER: builtins.int
S8S_FIELD_NUMBER: builtins.int
+U2S_FIELD_NUMBER: builtins.int
+U4S_FIELD_NUMBER: builtins.int
U8S_FIELD_NUMBER: builtins.int
S32S_FIELD_NUMBER: builtins.int
S64S_FIELD_NUMBER: builtins.int
@@ -1162,9 +1194,11 @@ class LiteralProto(google.protobuf.message.Message):
F8E5M2FNUZS_FIELD_NUMBER: builtins.int
F8E4M3FNUZS_FIELD_NUMBER: builtins.int
SPARSE_INDICES_FIELD_NUMBER: builtins.int
+s2s: builtins.bytes
s4s: builtins.bytes
-u4s: builtins.bytes
s8s: builtins.bytes
+u2s: builtins.bytes
+u4s: builtins.bytes
u8s: builtins.bytes
f16s: builtins.bytes
"""The F16s, BF16s, U16s and S16s are encoded in little endian byte order"""
@@ -1204,16 +1238,18 @@ class LiteralProto(google.protobuf.message.Message):
def tuple_literals(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___LiteralProto]: ...
@property
def sparse_indices(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
-"""Next = 26"""
+"""Next = 28"""

def __init__(
self,
*,
shape: global___ShapeProto | None = ...,
preds: collections.abc.Iterable[builtins.bool] | None = ...,
+s2s: builtins.bytes | None = ...,
s4s: builtins.bytes | None = ...,
-u4s: builtins.bytes | None = ...,
s8s: builtins.bytes | None = ...,
+u2s: builtins.bytes | None = ...,
+u4s: builtins.bytes | None = ...,
u8s: builtins.bytes | None = ...,
s32s: collections.abc.Iterable[builtins.int] | None = ...,
s64s: collections.abc.Iterable[builtins.int] | None = ...,
@@ -1236,7 +1272,7 @@ class LiteralProto(google.protobuf.message.Message):
sparse_indices: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["shape", b"shape"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["bf16s", b"bf16s", "c128s", b"c128s", "c64s", b"c64s", "f16s", b"f16s", "f32s", b"f32s", "f64s", b"f64s", "f8e4m3b11fnuzs", b"f8e4m3b11fnuzs", "f8e4m3fns", b"f8e4m3fns", "f8e4m3fnuzs", b"f8e4m3fnuzs", "f8e5m2fnuzs", b"f8e5m2fnuzs", "f8e5m2s", b"f8e5m2s", "preds", b"preds", "s16s", b"s16s", "s32s", b"s32s", "s4s", b"s4s", "s64s", b"s64s", "s8s", b"s8s", "shape", b"shape", "sparse_indices", b"sparse_indices", "tuple_literals", b"tuple_literals", "u16s", b"u16s", "u32s", b"u32s", "u4s", b"u4s", "u64s", b"u64s", "u8s", b"u8s"]) -> None: ...
def ClearField(self, field_name: typing.Literal["bf16s", b"bf16s", "c128s", b"c128s", "c64s", b"c64s", "f16s", b"f16s", "f32s", b"f32s", "f64s", b"f64s", "f8e4m3b11fnuzs", b"f8e4m3b11fnuzs", "f8e4m3fns", b"f8e4m3fns", "f8e4m3fnuzs", b"f8e4m3fnuzs", "f8e5m2fnuzs", b"f8e5m2fnuzs", "f8e5m2s", b"f8e5m2s", "preds", b"preds", "s16s", b"s16s", "s2s", b"s2s", "s32s", b"s32s", "s4s", b"s4s", "s64s", b"s64s", "s8s", b"s8s", "shape", b"shape", "sparse_indices", b"sparse_indices", "tuple_literals", b"tuple_literals", "u16s", b"u16s", "u2s", b"u2s", "u32s", b"u32s", "u4s", b"u4s", "u64s", b"u64s", "u8s", b"u8s"]) -> None: ...

global___LiteralProto = LiteralProto
@@ -1992,15 +2028,113 @@ class PrecisionConfig(google.protobuf.message.Message):
PACKED_NIBBLE: PrecisionConfig.Precision.ValueType # 3
"""Each U8/S8 value in a tensor actually represents 2 nibble values."""

+class _Algorithm:
+ValueType = typing.NewType("ValueType", builtins.int)
+V: typing_extensions.TypeAlias = ValueType
+
+class _AlgorithmEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[PrecisionConfig._Algorithm.ValueType], builtins.type):
+DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+ALG_UNSET: PrecisionConfig._Algorithm.ValueType # 0
+"""If the algorithm is `ALG_UNSET`, we will decide the algorithm based on
+the operand_precision values (for now).
+"""
+ALG_DOT_ANY_F8_ANY_F8_F32: PrecisionConfig._Algorithm.ValueType # 1
+"""The storage type can be any 8-bit floating point type."""
+ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: PrecisionConfig._Algorithm.ValueType # 2
+"""The storage type can be any 8-bit floating point type. Intermediate
+results will not periodically be promoted to a higher precision. This
+corresponds to CUBLASLT_MATMUL_DESC_FAST_ACCUM. Triton's
+maxNumImpreciseAcc=32 setting may be similar.
+"""
+ALG_DOT_F16_F16_F16: PrecisionConfig._Algorithm.ValueType # 3
+ALG_DOT_F16_F16_F32: PrecisionConfig._Algorithm.ValueType # 4
+ALG_DOT_BF16_BF16_BF16: PrecisionConfig._Algorithm.ValueType # 5
+ALG_DOT_BF16_BF16_F32: PrecisionConfig._Algorithm.ValueType # 6
+ALG_DOT_BF16_BF16_F32_X3: PrecisionConfig._Algorithm.ValueType # 7
+"""An algorithm which uses 3 BF16_BF16_F32 matmuls to achieve better
+precision.
+"""
+ALG_DOT_BF16_BF16_F32_X6: PrecisionConfig._Algorithm.ValueType # 8
+"""An algorithm which uses 6 BF16_BF16_F32 matmuls to achieve better
+precision (similar to F32).
+"""
+ALG_DOT_TF32_TF32_F32: PrecisionConfig._Algorithm.ValueType # 9
+ALG_DOT_TF32_TF32_F32_X3: PrecisionConfig._Algorithm.ValueType # 10
+"""An algorithm which uses 3 TF32_TF32_F32 matmuls to achieve better
+precision (similar to F32).
+"""
+ALG_DOT_F32_F32_F32: PrecisionConfig._Algorithm.ValueType # 11
+ALG_DOT_F64_F64_F64: PrecisionConfig._Algorithm.ValueType # 12
+
+class Algorithm(_Algorithm, metaclass=_AlgorithmEnumTypeWrapper):
+"""The algorithm used to evaluate the instruction.
+
+The naming convention for the dot instruction is
+ALG_DOT_{A_TYPE}_{B_TYPE}_{ACCUM_TYPE}[_X{NUM_OPS}] where A_TYPE, B_TYPE
+and ACCUM_TYPE correspond to the types in the "primitive dot operations"
+(such as TensorCore operations) and NUM_OPS is the number of such
+operations used per "primitive tile". When the NUM_OPS
+field is skipped, it is assumed to be 1. The types mentioned in the name
+are independent of the storage types.
+
+In general ATYPE and BTYPE are the precisions that the LHS and RHS of the
+operation are rounded to and ACCUMTYPE is the accumulation type. If a
+backend does not support the given algorithm, an error is raised. The
+Algorithm enum is intended to eventually replace the Precision enum.
+"""
+
+ALG_UNSET: PrecisionConfig.Algorithm.ValueType # 0
+"""If the algorithm is `ALG_UNSET`, we will decide the algorithm based on
+the operand_precision values (for now).
+"""
+ALG_DOT_ANY_F8_ANY_F8_F32: PrecisionConfig.Algorithm.ValueType # 1
+"""The storage type can be any 8-bit floating point type."""
+ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: PrecisionConfig.Algorithm.ValueType # 2
+"""The storage type can be any 8-bit floating point type. Intermediate
+results will not periodically be promoted to a higher precision. This
+corresponds to CUBLASLT_MATMUL_DESC_FAST_ACCUM. Triton's
+maxNumImpreciseAcc=32 setting may be similar.
+"""
+ALG_DOT_F16_F16_F16: PrecisionConfig.Algorithm.ValueType # 3
+ALG_DOT_F16_F16_F32: PrecisionConfig.Algorithm.ValueType # 4
+ALG_DOT_BF16_BF16_BF16: PrecisionConfig.Algorithm.ValueType # 5
+ALG_DOT_BF16_BF16_F32: PrecisionConfig.Algorithm.ValueType # 6
+ALG_DOT_BF16_BF16_F32_X3: PrecisionConfig.Algorithm.ValueType # 7
+"""An algorithm which uses 3 BF16_BF16_F32 matmuls to achieve better
+precision.
+"""
+ALG_DOT_BF16_BF16_F32_X6: PrecisionConfig.Algorithm.ValueType # 8
+"""An algorithm which uses 6 BF16_BF16_F32 matmuls to achieve better
+precision (similar to F32).
+"""
+ALG_DOT_TF32_TF32_F32: PrecisionConfig.Algorithm.ValueType # 9
+ALG_DOT_TF32_TF32_F32_X3: PrecisionConfig.Algorithm.ValueType # 10
+"""An algorithm which uses 3 TF32_TF32_F32 matmuls to achieve better
+precision (similar to F32).
+"""
+ALG_DOT_F32_F32_F32: PrecisionConfig.Algorithm.ValueType # 11
+ALG_DOT_F64_F64_F64: PrecisionConfig.Algorithm.ValueType # 12
+
OPERAND_PRECISION_FIELD_NUMBER: builtins.int
+ALGORITHM_FIELD_NUMBER: builtins.int
+algorithm: global___PrecisionConfig.Algorithm.ValueType
+"""Currently doesn't do anything, but we plan to support it for dot and
+possibly more instructions.
+
+TODO(b/316147294): Support this on GPU and add this to StableHLO as well.
+
+If this is set, then `operand_precision` should be set to DEFAULT and it
+will be ignored.
+"""
@property
def operand_precision(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[global___PrecisionConfig.Precision.ValueType]: ...
def __init__(
self,
*,
operand_precision: collections.abc.Iterable[global___PrecisionConfig.Precision.ValueType] | None = ...,
+algorithm: global___PrecisionConfig.Algorithm.ValueType | None = ...,
) -> None: ...
-def ClearField(self, field_name: typing.Literal["operand_precision", b"operand_precision"]) -> None: ...
+def ClearField(self, field_name: typing.Literal["algorithm", b"algorithm", "operand_precision", b"operand_precision"]) -> None: ...

global___PrecisionConfig = PrecisionConfig
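A minimal sketch of picking one of the new dot algorithms on a PrecisionConfig message, again assuming the generated xla_data_pb2 module is importable from the installed tensorflow package; the enum name follows the ALG_DOT_{A_TYPE}_{B_TYPE}_{ACCUM_TYPE}[_X{NUM_OPS}] convention described above:

    # Sketch only: assumes tensorflow ships the generated xla_data_pb2 module.
    from tensorflow.compiler.xla import xla_data_pb2

    # Six BF16xBF16 primitive dots accumulated in F32 per tile ("_X6").
    pc = xla_data_pb2.PrecisionConfig(
        algorithm=xla_data_pb2.PrecisionConfig.ALG_DOT_BF16_BF16_F32_X6,
    )
    print(xla_data_pb2.PrecisionConfig.Algorithm.Name(pc.algorithm))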
File diff suppressed because one or more lines are too long
@@ -93,7 +93,7 @@ global___ExternalStatePolicy = ExternalStatePolicy
@typing.final
class AutotuneOptions(google.protobuf.message.Message):
-"""next: 5"""
+"""next: 6"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -101,10 +101,12 @@ class AutotuneOptions(google.protobuf.message.Message):
CPU_BUDGET_FIELD_NUMBER: builtins.int
RAM_BUDGET_FIELD_NUMBER: builtins.int
AUTOTUNE_ALGORITHM_FIELD_NUMBER: builtins.int
+INITIAL_PARALLELISM_FIELD_NUMBER: builtins.int
enabled: builtins.bool
cpu_budget: builtins.int
ram_budget: builtins.int
autotune_algorithm: tensorflow.core.framework.model_pb2.AutotuneAlgorithm.ValueType
+initial_parallelism: builtins.int
def __init__(
self,
*,
@@ -112,9 +114,10 @@ class AutotuneOptions(google.protobuf.message.Message):
cpu_budget: builtins.int | None = ...,
ram_budget: builtins.int | None = ...,
autotune_algorithm: tensorflow.core.framework.model_pb2.AutotuneAlgorithm.ValueType | None = ...,
+initial_parallelism: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["autotune_algorithm", b"autotune_algorithm", "cpu_budget", b"cpu_budget", "enabled", b"enabled", "optional_autotune_algorithm", b"optional_autotune_algorithm", "optional_cpu_budget", b"optional_cpu_budget", "optional_enabled", b"optional_enabled", "optional_ram_budget", b"optional_ram_budget", "ram_budget", b"ram_budget"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["autotune_algorithm", b"autotune_algorithm", "cpu_budget", b"cpu_budget", "enabled", b"enabled", "optional_autotune_algorithm", b"optional_autotune_algorithm", "optional_cpu_budget", b"optional_cpu_budget", "optional_enabled", b"optional_enabled", "optional_ram_budget", b"optional_ram_budget", "ram_budget", b"ram_budget"]) -> None: ...
def HasField(self, field_name: typing.Literal["autotune_algorithm", b"autotune_algorithm", "cpu_budget", b"cpu_budget", "enabled", b"enabled", "initial_parallelism", b"initial_parallelism", "optional_autotune_algorithm", b"optional_autotune_algorithm", "optional_cpu_budget", b"optional_cpu_budget", "optional_enabled", b"optional_enabled", "optional_initial_parallelism", b"optional_initial_parallelism", "optional_ram_budget", b"optional_ram_budget", "ram_budget", b"ram_budget"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["autotune_algorithm", b"autotune_algorithm", "cpu_budget", b"cpu_budget", "enabled", b"enabled", "initial_parallelism", b"initial_parallelism", "optional_autotune_algorithm", b"optional_autotune_algorithm", "optional_cpu_budget", b"optional_cpu_budget", "optional_enabled", b"optional_enabled", "optional_initial_parallelism", b"optional_initial_parallelism", "optional_ram_budget", b"optional_ram_budget", "ram_budget", b"ram_budget"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_autotune_algorithm", b"optional_autotune_algorithm"]) -> typing.Literal["autotune_algorithm"] | None: ...
@typing.overload
@@ -122,6 +125,8 @@ class AutotuneOptions(google.protobuf.message.Message):
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_enabled", b"optional_enabled"]) -> typing.Literal["enabled"] | None: ...
+@typing.overload
+def WhichOneof(self, oneof_group: typing.Literal["optional_initial_parallelism", b"optional_initial_parallelism"]) -> typing.Literal["initial_parallelism"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_ram_budget", b"optional_ram_budget"]) -> typing.Literal["ram_budget"] | None: ...

global___AutotuneOptions = AutotuneOptions
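These proto fields back tf.data autotuning; a minimal usage sketch of the long-standing knobs through tf.data.Options (whether the new initial_parallelism field is surfaced on that Python API in 2.17 is not assumed here):

    import tensorflow as tf

    # Autotune knobs corresponding to the AutotuneOptions fields above.
    options = tf.data.Options()
    options.autotune.enabled = True
    options.autotune.cpu_budget = 4          # cores the autotuner may use
    options.autotune.ram_budget = 1 << 30    # bytes of host RAM to budget

    ds = tf.data.Dataset.range(1000).map(lambda x: x * 2).with_options(options)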
@@ -33,8 +33,11 @@ class ResourceHandleProto(google.protobuf.message.Message):
DTYPE_FIELD_NUMBER: builtins.int
SHAPE_FIELD_NUMBER: builtins.int
dtype: tensorflow.core.framework.types_pb2.DataType.ValueType
+"""Data type of the tensor."""
@property
-def shape(self) -> tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto: ...
+def shape(self) -> tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto:
+"""Shape of the tensor."""
+
def __init__(
self,
*,
@@ -41,6 +41,7 @@ class TensorProto(google.protobuf.message.Message):
UINT64_VAL_FIELD_NUMBER: builtins.int
FLOAT8_VAL_FIELD_NUMBER: builtins.int
dtype: tensorflow.core.framework.types_pb2.DataType.ValueType
+"""Data type of the tensor."""
version_number: builtins.int
"""Only one of the representations below is set, one of "tensor_contents" and
the "xxx_val" attributes. We are not using oneof because as oneofs cannot
@@ -90,6 +90,49 @@ class GPUOptions(google.protobuf.message.Message):
) -> None: ...
def ClearField(self, field_name: typing.Literal["device_ordinal", b"device_ordinal", "memory_limit_mb", b"memory_limit_mb", "priority", b"priority"]) -> None: ...

+@typing.final
+class StreamMergeOptions(google.protobuf.message.Message):
+"""Whether to merge data transfer streams into the compute stream in the
+same stream group. Stream merging helps reduce the overhead caused by
+stream synchronization, especially when data transfers are frequent. For
+example, setting "merge_host_to_device_stream = true" will make the
+compute stream responsible for both computation and host to device memory
+copy.
+"""
+
+DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+MERGE_HOST_TO_DEVICE_STREAM_FIELD_NUMBER: builtins.int
+MERGE_DEVICE_TO_HOST_STREAM_FIELD_NUMBER: builtins.int
+MERGE_DEVICE_TO_DEVICE_STREAM_FIELD_NUMBER: builtins.int
+merge_host_to_device_stream: builtins.bool
+"""If true, the compute stream will be used for host_to_device copy as
+well. It's no longer necessary to record an event before the copy to
+let the copy stream wait for the compute stream to finish. There is
+also no need to wait for the copy to complete before executing the
+callback function.
+"""
+merge_device_to_host_stream: builtins.bool
+"""If true, the compute stream will be used for device_to_host copy as
+well. It's no longer necessary to record an event before the copy to
+let the copy stream wait for the compute stream to finish.
+"""
+merge_device_to_device_stream: builtins.bool
+"""If true, the compute stream will be used for device_to_device copy as
+well. It's no longer necessary to record an event before the copy to
+let the copy stream wait for the compute stream of the sending device
+to finish. There is also no need to wait for the compute stream of the
+receiving device to finish if the copy is within the same device.
+"""
+def __init__(
+self,
+*,
+merge_host_to_device_stream: builtins.bool | None = ...,
+merge_device_to_host_stream: builtins.bool | None = ...,
+merge_device_to_device_stream: builtins.bool | None = ...,
+) -> None: ...
+def ClearField(self, field_name: typing.Literal["merge_device_to_device_stream", b"merge_device_to_device_stream", "merge_device_to_host_stream", b"merge_device_to_host_stream", "merge_host_to_device_stream", b"merge_host_to_device_stream"]) -> None: ...
+
VIRTUAL_DEVICES_FIELD_NUMBER: builtins.int
NUM_VIRTUAL_DEVICES_PER_GPU_FIELD_NUMBER: builtins.int
USE_UNIFIED_MEMORY_FIELD_NUMBER: builtins.int
@@ -105,6 +148,9 @@ class GPUOptions(google.protobuf.message.Message):
GPU_HOST_MEM_LIMIT_IN_MB_FIELD_NUMBER: builtins.int
GPU_HOST_MEM_DISALLOW_GROWTH_FIELD_NUMBER: builtins.int
GPU_SYSTEM_MEMORY_SIZE_IN_MB_FIELD_NUMBER: builtins.int
+POPULATE_PJRT_GPU_CLIENT_CREATION_INFO_FIELD_NUMBER: builtins.int
+NODE_ID_FIELD_NUMBER: builtins.int
+STREAM_MERGE_OPTIONS_FIELD_NUMBER: builtins.int
num_virtual_devices_per_gpu: builtins.int
"""The number of virtual devices to create on each visible GPU. The
available memory will be split equally among all virtual devices. If the
@@ -200,6 +246,14 @@ class GPUOptions(google.protobuf.message.Message):
system memory size for better resource estimation of multi-tenancy(one
gpu with multiple model) use case.
"""
+populate_pjrt_gpu_client_creation_info: builtins.bool
+"""If true, save information needed for created a PjRt GPU client for
+creating a client with remote devices.
+"""
+node_id: builtins.int
+"""node_id for use when creating a PjRt GPU client with remote devices,
+which enumerates jobs*tasks from a ServerDef.
+"""
@property
def virtual_devices(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___GPUOptions.Experimental.VirtualDevices]:
"""The multi virtual device settings. If empty (not set), it will create
@@ -242,6 +296,8 @@ class GPUOptions(google.protobuf.message.Message):
result in undefined behavior.
"""

+@property
+def stream_merge_options(self) -> global___GPUOptions.Experimental.StreamMergeOptions: ...
def __init__(
self,
*,
@@ -260,8 +316,12 @@ class GPUOptions(google.protobuf.message.Message):
gpu_host_mem_limit_in_mb: builtins.float | None = ...,
gpu_host_mem_disallow_growth: builtins.bool | None = ...,
gpu_system_memory_size_in_mb: builtins.int | None = ...,
+populate_pjrt_gpu_client_creation_info: builtins.bool | None = ...,
+node_id: builtins.int | None = ...,
+stream_merge_options: global___GPUOptions.Experimental.StreamMergeOptions | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["collective_ring_order", b"collective_ring_order", "disallow_retry_on_allocation_failure", b"disallow_retry_on_allocation_failure", "gpu_host_mem_disallow_growth", b"gpu_host_mem_disallow_growth", "gpu_host_mem_limit_in_mb", b"gpu_host_mem_limit_in_mb", "gpu_system_memory_size_in_mb", b"gpu_system_memory_size_in_mb", "internal_fragmentation_fraction", b"internal_fragmentation_fraction", "kernel_tracker_max_bytes", b"kernel_tracker_max_bytes", "kernel_tracker_max_interval", b"kernel_tracker_max_interval", "kernel_tracker_max_pending", b"kernel_tracker_max_pending", "num_dev_to_dev_copy_streams", b"num_dev_to_dev_copy_streams", "num_virtual_devices_per_gpu", b"num_virtual_devices_per_gpu", "timestamped_allocator", b"timestamped_allocator", "use_cuda_malloc_async", b"use_cuda_malloc_async", "use_unified_memory", b"use_unified_memory", "virtual_devices", b"virtual_devices"]) -> None: ...
def HasField(self, field_name: typing.Literal["stream_merge_options", b"stream_merge_options"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["collective_ring_order", b"collective_ring_order", "disallow_retry_on_allocation_failure", b"disallow_retry_on_allocation_failure", "gpu_host_mem_disallow_growth", b"gpu_host_mem_disallow_growth", "gpu_host_mem_limit_in_mb", b"gpu_host_mem_limit_in_mb", "gpu_system_memory_size_in_mb", b"gpu_system_memory_size_in_mb", "internal_fragmentation_fraction", b"internal_fragmentation_fraction", "kernel_tracker_max_bytes", b"kernel_tracker_max_bytes", "kernel_tracker_max_interval", b"kernel_tracker_max_interval", "kernel_tracker_max_pending", b"kernel_tracker_max_pending", "node_id", b"node_id", "num_dev_to_dev_copy_streams", b"num_dev_to_dev_copy_streams", "num_virtual_devices_per_gpu", b"num_virtual_devices_per_gpu", "populate_pjrt_gpu_client_creation_info", b"populate_pjrt_gpu_client_creation_info", "stream_merge_options", b"stream_merge_options", "timestamped_allocator", b"timestamped_allocator", "use_cuda_malloc_async", b"use_cuda_malloc_async", "use_unified_memory", b"use_unified_memory", "virtual_devices", b"virtual_devices"]) -> None: ...

PER_PROCESS_GPU_MEMORY_FRACTION_FIELD_NUMBER: builtins.int
ALLOW_GROWTH_FIELD_NUMBER: builtins.int
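A minimal sketch of enabling the new stream-merge options through a session ConfigProto, assuming the generated config_pb2 module at its usual import path:

    # Sketch only: assumes tensorflow ships the generated config_pb2 module.
    from tensorflow.core.protobuf import config_pb2

    config = config_pb2.ConfigProto()
    merge = config.gpu_options.experimental.stream_merge_options
    merge.merge_host_to_device_stream = True   # compute stream also does H2D copies
    merge.merge_device_to_host_stream = True   # and D2H copies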
@@ -333,8 +333,8 @@ class RewriterConfig(google.protobuf.message.Message):
(default is ON).
"""
auto_mixed_precision: global___RewriterConfig.Toggle.ValueType
-"""Optimize data types for CUDA (default is OFF).
-This will try to use float16 on GPU which is faster.
+"""Optimize data types for CUDA/oneDNN (default is OFF).
+This will try to use float16 on GPU/CPU which is faster.
Note that this can change the numerical stability of the graph and may
require the use of loss scaling to maintain model convergence.
"""
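The documented switch for this grappler pass is the optimizer's experimental options; a minimal sketch using tf.config.optimizer (which devices actually get rewritten follows the runtime, not this stub):

    import tensorflow as tf

    # Turn on the auto mixed precision rewrite described in the comment above.
    tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
    print(tf.config.optimizer.get_experimental_options())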
@@ -114,6 +114,7 @@ class Dataset(ABC, Generic[_T1]):
element_spec: ContainerGeneric[tf.TypeSpec[Any]] | None = None,
compression: _CompressionTypes = None,
reader_func: Callable[[Dataset[Dataset[Any]]], Dataset[Any]] | None = None,
+wait: bool = False,
) -> Dataset[Any]: ...
# PEP 646 could be used here for a more precise type when better supported.
def map(
@@ -196,7 +197,7 @@ class Dataset(ABC, Generic[_T1]):
self,
buffer_size: ScalarTensorCompatible,
seed: int | None = None,
-reshuffle_each_iteration: bool | None = None,
+reshuffle_each_iteration: bool = True,
name: str | None = None,
) -> Dataset[_T1]: ...
def skip(self, count: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1]: ...
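The stub now matches the runtime default of reshuffling on every iteration; a minimal usage sketch of the affected shuffle signature:

    import tensorflow as tf

    # reshuffle_each_iteration defaults to True, so each full pass over the
    # dataset sees a (generally) different order.
    ds = tf.data.Dataset.range(10).shuffle(buffer_size=10, seed=42)
    print([int(x) for x in ds])
    print([int(x) for x in ds])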
@@ -47,15 +47,19 @@ class SaveOptions:
"experimental_image_format",
"experimental_skip_saver",
+"experimental_sharding_callback",
+"extra_tags",
)
namespace_whitelist: list[str]
save_debug_info: bool
function_aliases: dict[str, PolymorphicFunction[..., object]]
experimental_debug_stripper: bool
experimental_io_device: str
experimental_variable_policy: VariablePolicy
experimental_custom_gradients: bool
experimental_image_format: bool
experimental_skip_saver: bool
+experimental_sharding_callback: Incomplete | None
+extra_tags: Incomplete | None
def __init__(
self,
namespace_whitelist: list[str] | None = None,
@@ -68,6 +72,7 @@ class SaveOptions:
experimental_image_format: bool = False,
experimental_skip_saver: bool = False,
experimental_sharding_callback: Incomplete | None = None,
+extra_tags: Incomplete | None = None,
) -> None: ...

def contains_saved_model(export_dir: str | Path) -> bool: ...
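These attributes mirror the public tf.saved_model.SaveOptions constructor; a minimal usage sketch with long-established options only (the newly listed extra_tags slot is left out, since its runtime behaviour is not described here):

    import tensorflow as tf

    module = tf.Module()
    options = tf.saved_model.SaveOptions(
        namespace_whitelist=[],
        experimental_custom_gradients=True,
    )
    tf.saved_model.save(module, "/tmp/demo_module", options=options)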
@@ -473,15 +473,18 @@ class InsertKeyValueRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

KV_FIELD_NUMBER: builtins.int
+ALLOW_OVERWRITE_FIELD_NUMBER: builtins.int
+allow_overwrite: builtins.bool
@property
def kv(self) -> global___KeyValueEntry: ...
def __init__(
self,
*,
kv: global___KeyValueEntry | None = ...,
+allow_overwrite: builtins.bool | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["kv", b"kv"]) -> builtins.bool: ...
-def ClearField(self, field_name: typing.Literal["kv", b"kv"]) -> None: ...
+def ClearField(self, field_name: typing.Literal["allow_overwrite", b"allow_overwrite", "kv", b"kv"]) -> None: ...

global___InsertKeyValueRequest = InsertKeyValueRequest
@@ -267,6 +267,24 @@ FORWARD_BIAS_ACTIVATION: ConvolutionKind.ValueType # 4
FORWARD_GRAPH: ConvolutionKind.ValueType # 5
global___ConvolutionKind = ConvolutionKind

+class _NormKind:
+ValueType = typing.NewType("ValueType", builtins.int)
+V: typing_extensions.TypeAlias = ValueType
+
+class _NormKindEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_NormKind.ValueType], builtins.type):
+DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+LAYER_FWD_INFER: _NormKind.ValueType # 0
+LAYER_FWD_TRAIN: _NormKind.ValueType # 1
+LAYER_BWD: _NormKind.ValueType # 2
+
+class NormKind(_NormKind, metaclass=_NormKindEnumTypeWrapper):
+"""NormKind kind"""
+
+LAYER_FWD_INFER: NormKind.ValueType # 0
+LAYER_FWD_TRAIN: NormKind.ValueType # 1
+LAYER_BWD: NormKind.ValueType # 2
+global___NormKind = NormKind
+
class _FusedMHAKind:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
@@ -285,6 +303,28 @@ BMM1_OUTPUT_INPUT_TYPE: FusedMHAKind.ValueType # 1
BMM1_OUTPUT_FLOAT: FusedMHAKind.ValueType # 2
global___FusedMHAKind = FusedMHAKind

+class _FMHAMaskKind:
+ValueType = typing.NewType("ValueType", builtins.int)
+V: typing_extensions.TypeAlias = ValueType
+
+class _FMHAMaskKindEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_FMHAMaskKind.ValueType], builtins.type):
+DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+NO_MASK: _FMHAMaskKind.ValueType # 0
+PADDING: _FMHAMaskKind.ValueType # 1
+CAUSAL: _FMHAMaskKind.ValueType # 2
+PADDING_CAUSAL: _FMHAMaskKind.ValueType # 3
+ALIBI: _FMHAMaskKind.ValueType # 4
+
+class FMHAMaskKind(_FMHAMaskKind, metaclass=_FMHAMaskKindEnumTypeWrapper):
+"""FusedMHAMaskKind kind"""
+
+NO_MASK: FMHAMaskKind.ValueType # 0
+PADDING: FMHAMaskKind.ValueType # 1
+CAUSAL: FMHAMaskKind.ValueType # 2
+PADDING_CAUSAL: FMHAMaskKind.ValueType # 3
+ALIBI: FMHAMaskKind.ValueType # 4
+global___FMHAMaskKind = FMHAMaskKind
+
@typing.final
class TensorDescriptorProto(google.protobuf.message.Message):
"""Generic tensor representation."""