Bump tensorflow to 2.16.* (#11696)

This commit is contained in:
Avasam
2024-04-22 10:37:33 -04:00
committed by GitHub
parent 17f1c4628a
commit b3bfdadb45
56 changed files with 4060 additions and 916 deletions

View File

@@ -49,15 +49,9 @@ tensorflow._aliases
# tf.initializers at runtime is <module 'keras.api._v2.keras.initializers' from '...'>
tensorflow.initializers
# Other cursed import magic similar to the one above.
tensorflow.keras.layers.preprocessing
tensorflow.keras.layers.preprocessing.index_lookup
# Another cursed import magic similar to the one above.
tensorflow.distribute.coordinator
tensorflow.distribute.experimental.coordinator
# Layer constructor's always have **kwargs, but only allow a few specific values. PEP 692
# would allow us to specify this with **kwargs and remove the need for these exceptions.
tensorflow.keras.layers.*.__init__
# __call__ in tensorflow classes often allow keyword usage, but
# when you subclass those classes it is not expected to handle keyword case. As an example,
# class MyLayer(tf.keras.layers.Layer):

View File

@@ -1,8 +1,10 @@
version = "2.15.*"
# Whenever you update version here, TENSORFLOW_VERSION should be updated
# in scripts/sync_tensorflow_protobuf_stubs.sh and vice-versa.
version = "2.16.*"
upstream_repository = "https://github.com/tensorflow/tensorflow"
# requires a version of numpy with a `py.typed` file
requires = ["numpy>=1.20", "types-protobuf", "types-requests"]
extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 25.1 on tensorflow==2.12.1."
extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 25.1 on tensorflow==2.16.1 ."
partial_stub = true
[tool.stubtest]

View File

@@ -1,3 +1,4 @@
import abc
from _typeshed import Incomplete, Unused
from abc import ABC, ABCMeta, abstractmethod
from builtins import bool as _bool
@@ -17,6 +18,7 @@ from tensorflow import (
io as io,
keras as keras,
math as math,
types as types,
)
from tensorflow._aliases import AnyArray, DTypeLike, ShapeLike, Slice, TensorCompatible
from tensorflow.autodiff import GradientTape as GradientTape
@@ -257,7 +259,7 @@ class IndexedSlices(metaclass=ABCMeta):
def __neg__(self) -> IndexedSlices: ...
def consumers(self) -> list[Operation]: ...
class name_scope:
class name_scope(metaclass=abc.ABCMeta):
def __init__(self, name: str) -> None: ...
def __enter__(self) -> str: ...
def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ...

View File

@@ -1,92 +0,0 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
"""
import builtins
import collections.abc
import typing
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import tensorflow.tsl.protobuf.autotuning_pb2
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class AutotuneResults(google.protobuf.message.Message):
"""A collection of algorithms for particular dot/convs. Usually this is "the
best" algorithm for the particular dot/conv, although that's not strictly
required.
Users don't interact with this proto directly. It's used internally to
facilitate ahead-of-time autotuning -- The string used by
xla::{Serialize,Load}AutotuneResults is, internally, a serialization of this
proto.
LINT.IfChange
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@typing.final
class Entry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
DEVICE_FIELD_NUMBER: builtins.int
HLO_FIELD_NUMBER: builtins.int
RESULT_FIELD_NUMBER: builtins.int
device: builtins.str
hlo: builtins.str
@property
def result(self) -> tensorflow.tsl.protobuf.autotuning_pb2.AutotuneResult:
"""nb: These results are always tied to a particular version of
cublas/cudnn, but this is *especially* true for cublasLt results. For
cublasLt gemms, the result is an index into the list of candidate
algorithms returned by cublasLt. Different version of cublasLt ->
different list of algos -> different interpretation of results!
"""
def __init__(
self,
*,
device: builtins.str | None = ...,
hlo: builtins.str | None = ...,
result: tensorflow.tsl.protobuf.autotuning_pb2.AutotuneResult | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["result", b"result"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["device", b"device", "hlo", b"hlo", "result", b"result"]) -> None: ...
VERSION_FIELD_NUMBER: builtins.int
DOTS_FIELD_NUMBER: builtins.int
CONVS_FIELD_NUMBER: builtins.int
version: builtins.int
@property
def dots(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___AutotuneResults.Entry]: ...
@property
def convs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___AutotuneResults.Entry]: ...
def __init__(
self,
*,
version: builtins.int | None = ...,
dots: collections.abc.Iterable[global___AutotuneResults.Entry] | None = ...,
convs: collections.abc.Iterable[global___AutotuneResults.Entry] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["convs", b"convs", "dots", b"dots", "version", b"version"]) -> None: ...
global___AutotuneResults = AutotuneResults

View File

@@ -20,6 +20,7 @@ import collections.abc
import sys
import typing
import google.protobuf.any_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
@@ -218,7 +219,7 @@ global___Kind = Kind
@typing.final
class HloInstructionProto(google.protobuf.message.Message):
"""Serialization of HloInstruction.
Next ID: 81
Next ID: 87
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -310,8 +311,11 @@ class HloInstructionProto(google.protobuf.message.Message):
CROSS_PROGRAM_PREFETCH_INDEX_FIELD_NUMBER: builtins.int
PADDING_TYPE_FIELD_NUMBER: builtins.int
CUSTOM_CALL_API_VERSION_FIELD_NUMBER: builtins.int
ASYNC_GROUP_ID_FIELD_NUMBER: builtins.int
ASYNC_EXECUTION_THREAD_FIELD_NUMBER: builtins.int
K_FIELD_NUMBER: builtins.int
LARGEST_FIELD_NUMBER: builtins.int
STATISTICS_VIZ_FIELD_NUMBER: builtins.int
DOT_SPARSITY_FIELD_NUMBER: builtins.int
name: builtins.str
opcode: builtins.str
parameter_number: builtins.int
@@ -420,16 +424,15 @@ class HloInstructionProto(google.protobuf.message.Message):
TODO(b/189822916): Remove this field when all clients are migrated to the
status-returning API.
"""
async_group_id: builtins.int
"""Represents a unique identifier for an async group which consists of an
async start, async done, and zero or more async update operations.
Negative async_group_id is equivalent to no async group id.
"""
async_execution_thread: builtins.str
"""Represents a unique execution thread name for one or more async groups.
Each HLO module may contain a main thread and one or more parallel threads.
Empty async_execution_thread is equivalent to main thread.
"""
k: builtins.int
"""Represents the K value for top-k."""
largest: builtins.bool
"""Represents the largest flag for top-k."""
@property
def shape(self) -> tensorflow.compiler.xla.xla_data_pb2.ShapeProto: ...
@property
@@ -536,6 +539,16 @@ class HloInstructionProto(google.protobuf.message.Message):
def frontend_attributes(self) -> tensorflow.compiler.xla.xla_data_pb2.FrontendAttributes:
"""Frontend attributes to pass to the XLA backend."""
@property
def statistics_viz(self) -> tensorflow.compiler.xla.xla_data_pb2.StatisticsViz:
"""Represents the information for tracking propagation of values within HLO
graph.
"""
@property
def dot_sparsity(self) -> tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor:
"""Sparsity descriptor for dot operation."""
def __init__(
self,
*,
@@ -605,11 +618,14 @@ class HloInstructionProto(google.protobuf.message.Message):
cross_program_prefetch_index: builtins.int | None = ...,
padding_type: tensorflow.compiler.xla.xla_data_pb2.PaddingType.ValueType | None = ...,
custom_call_api_version: global___CustomCallApiVersion.ValueType | None = ...,
async_group_id: builtins.int | None = ...,
async_execution_thread: builtins.str | None = ...,
k: builtins.int | None = ...,
largest: builtins.bool | None = ...,
statistics_viz: tensorflow.compiler.xla.xla_data_pb2.StatisticsViz | None = ...,
dot_sparsity: tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "async_group_id", b"async_group_id", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "output_operand_aliasing", b"output_operand_aliasing", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ...
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "k", b"k", "largest", b"largest", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "output_operand_aliasing", b"output_operand_aliasing", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index"]) -> typing.Literal["cross_program_prefetch_index"] | None: ...
global___HloInstructionProto = HloInstructionProto
@@ -736,7 +752,7 @@ class HloInputOutputAliasProto(google.protobuf.message.Message):
parameter_shape_index={1, 2},
}
This entry indicates that the first paremter's {1, 2} element is
This entry indicates that the first parameter's {1, 2} element is
aliased with the {1} element of the root instruction.
"""
@@ -781,72 +797,54 @@ class HloInputOutputAliasProto(google.protobuf.message.Message):
global___HloInputOutputAliasProto = HloInputOutputAliasProto
@typing.final
class DynamicParameterBindingProto(google.protobuf.message.Message):
class HloBufferDonorProto(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@typing.final
class Binding(google.protobuf.message.Message):
"""A list of bindings which indicates that the `target_param_dim_num` in
the subshape `target_param_index` of parameter `target_param_num`
is a dynamic dimension and its real dynamic size is represented
by `dynamic_param_index` in parameter `dynamic_param_num`.
class BufferDonorEntryProto(google.protobuf.message.Message):
"""The following proto describes an input (described by parameter number and a
ShapeIndex of the parameter) that can donate its butter to any output
tensor. It is similar to HloInputOutputAliasProto, but without a paired
output. For example:
As an example, imagine we have a program:
ENTRY main {
a = f32[] parameter(0)
b = f32[10] parameter(1)
ROOT root = (f32[], f32[10]) tuple(%a, %b)
entry = {
parameter_number=0,
parameter_shape_index={1, 2},
}
Let's say 'b' (param index 1) is a dynamic shape whose input has
an upperbound of 10 and real size is determined at runtime.'a'
represents the real size of b's first dimension.
In this case, the fields are set in the following way:
dynamic_param_num = 1
dynamic_param_index = {}
target_param_num = 0
target_param_index = {}
target_param_dim_num = 0
This entry indicates that the first parameter's {1, 2} element can donate
its buffer.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
DYNAMIC_PARAM_NUM_FIELD_NUMBER: builtins.int
DYNAMIC_PARAM_INDEX_FIELD_NUMBER: builtins.int
TARGET_PARAM_NUM_FIELD_NUMBER: builtins.int
TARGET_PARAM_INDEX_FIELD_NUMBER: builtins.int
TARGET_PARAM_DIM_NUM_FIELD_NUMBER: builtins.int
dynamic_param_num: builtins.int
target_param_num: builtins.int
target_param_dim_num: builtins.int
PARAMETER_NUMBER_FIELD_NUMBER: builtins.int
PARAMETER_SHAPE_INDEX_FIELD_NUMBER: builtins.int
parameter_number: builtins.int
"""Number of the parameter in entry computation."""
@property
def dynamic_param_index(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
@property
def target_param_index(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
def parameter_shape_index(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""ShapeIndex of the parameter instruction."""
def __init__(
self,
*,
dynamic_param_num: builtins.int | None = ...,
dynamic_param_index: collections.abc.Iterable[builtins.int] | None = ...,
target_param_num: builtins.int | None = ...,
target_param_index: collections.abc.Iterable[builtins.int] | None = ...,
target_param_dim_num: builtins.int | None = ...,
parameter_number: builtins.int | None = ...,
parameter_shape_index: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["dynamic_param_index", b"dynamic_param_index", "dynamic_param_num", b"dynamic_param_num", "target_param_dim_num", b"target_param_dim_num", "target_param_index", b"target_param_index", "target_param_num", b"target_param_num"]) -> None: ...
def ClearField(self, field_name: typing.Literal["parameter_number", b"parameter_number", "parameter_shape_index", b"parameter_shape_index"]) -> None: ...
ENTRIES_FIELD_NUMBER: builtins.int
@property
def entries(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___DynamicParameterBindingProto.Binding]: ...
def entries(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___HloBufferDonorProto.BufferDonorEntryProto]: ...
def __init__(
self,
*,
entries: collections.abc.Iterable[global___DynamicParameterBindingProto.Binding] | None = ...,
entries: collections.abc.Iterable[global___HloBufferDonorProto.BufferDonorEntryProto] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["entries", b"entries"]) -> None: ...
global___DynamicParameterBindingProto = DynamicParameterBindingProto
global___HloBufferDonorProto = HloBufferDonorProto
@typing.final
class CrossProgramPrefetch(google.protobuf.message.Message):
@@ -870,6 +868,101 @@ class CrossProgramPrefetch(google.protobuf.message.Message):
global___CrossProgramPrefetch = CrossProgramPrefetch
@typing.final
class StackFrameIndexProto(google.protobuf.message.Message):
"""Serialization of stack frames index representations.
Stack frames index presented in four flat arrays:
1. File names array.
2. Function names array.
3. File location array.
4. Frame array.
All reference ids in sub-protos are 1-based positions of the
entity in the flat array.
Ids are 1-based to keep 0 value as representation of non-set property.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@typing.final
class FileLocation(google.protobuf.message.Message):
"""Serialization of file position."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
FILE_NAME_ID_FIELD_NUMBER: builtins.int
FUNCTION_NAME_ID_FIELD_NUMBER: builtins.int
LINE_FIELD_NUMBER: builtins.int
COLUMN_FIELD_NUMBER: builtins.int
file_name_id: builtins.int
"""1-based position of file name."""
function_name_id: builtins.int
"""1-based position of function name."""
line: builtins.int
"""Line number."""
column: builtins.int
"""Column number."""
def __init__(
self,
*,
file_name_id: builtins.int | None = ...,
function_name_id: builtins.int | None = ...,
line: builtins.int | None = ...,
column: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["column", b"column", "file_name_id", b"file_name_id", "function_name_id", b"function_name_id", "line", b"line"]) -> None: ...
@typing.final
class StackFrame(google.protobuf.message.Message):
"""Serialization of frame."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
FILE_LOCATION_ID_FIELD_NUMBER: builtins.int
PARENT_FRAME_ID_FIELD_NUMBER: builtins.int
file_location_id: builtins.int
"""1-based position of file location."""
parent_frame_id: builtins.int
"""1-based position of the parent frame."""
def __init__(
self,
*,
file_location_id: builtins.int | None = ...,
parent_frame_id: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["file_location_id", b"file_location_id", "parent_frame_id", b"parent_frame_id"]) -> None: ...
FILE_NAMES_FIELD_NUMBER: builtins.int
FUNCTION_NAMES_FIELD_NUMBER: builtins.int
FILE_LOCATIONS_FIELD_NUMBER: builtins.int
STACK_FRAMES_FIELD_NUMBER: builtins.int
@property
def file_names(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""Flat index array of file names."""
@property
def function_names(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""Flat index array of function names."""
@property
def file_locations(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___StackFrameIndexProto.FileLocation]:
"""Flat index array of file locations."""
@property
def stack_frames(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___StackFrameIndexProto.StackFrame]:
"""Flat index array of frames."""
def __init__(
self,
*,
file_names: collections.abc.Iterable[builtins.str] | None = ...,
function_names: collections.abc.Iterable[builtins.str] | None = ...,
file_locations: collections.abc.Iterable[global___StackFrameIndexProto.FileLocation] | None = ...,
stack_frames: collections.abc.Iterable[global___StackFrameIndexProto.StackFrame] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["file_locations", b"file_locations", "file_names", b"file_names", "function_names", b"function_names", "stack_frames", b"stack_frames"]) -> None: ...
global___StackFrameIndexProto = StackFrameIndexProto
@typing.final
class HloModuleProto(google.protobuf.message.Message):
"""Serialization of HloModule."""
@@ -907,6 +1000,7 @@ class HloModuleProto(google.protobuf.message.Message):
RELATIVE_SPEEDUP_FIELD_NUMBER: builtins.int
PROFILE_SOURCE_FIELD_NUMBER: builtins.int
COMPILATION_EVENT_FIELD_NUMBER: builtins.int
FINGERPRINT_FIELD_NUMBER: builtins.int
profile_type: global___HloModuleProto.ProfileType.ValueType
"""The optimization profiles that this module contains."""
relative_speedup: builtins.float
@@ -915,6 +1009,8 @@ class HloModuleProto(google.protobuf.message.Message):
"""The source of the optimization profile that this module contains."""
compilation_event: tensorflow.compiler.xla.xla_data_pb2.CompilationEvent.ValueType
"""The compilation event that triggered the use of the profile."""
fingerprint: builtins.str
"""The fingerprint of the unoptimized module this profile was applied to."""
def __init__(
self,
*,
@@ -922,8 +1018,9 @@ class HloModuleProto(google.protobuf.message.Message):
relative_speedup: builtins.float | None = ...,
profile_source: tensorflow.compiler.xla.xla_data_pb2.ProfileSource.ValueType | None = ...,
compilation_event: tensorflow.compiler.xla.xla_data_pb2.CompilationEvent.ValueType | None = ...,
fingerprint: builtins.str | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["compilation_event", b"compilation_event", "profile_source", b"profile_source", "profile_type", b"profile_type", "relative_speedup", b"relative_speedup"]) -> None: ...
def ClearField(self, field_name: typing.Literal["compilation_event", b"compilation_event", "fingerprint", b"fingerprint", "profile_source", b"profile_source", "profile_type", b"profile_type", "relative_speedup", b"relative_speedup"]) -> None: ...
NAME_FIELD_NUMBER: builtins.int
ENTRY_COMPUTATION_NAME_FIELD_NUMBER: builtins.int
@@ -933,7 +1030,7 @@ class HloModuleProto(google.protobuf.message.Message):
ID_FIELD_NUMBER: builtins.int
SCHEDULE_FIELD_NUMBER: builtins.int
INPUT_OUTPUT_ALIAS_FIELD_NUMBER: builtins.int
DYNAMIC_PARAMETER_BINDING_FIELD_NUMBER: builtins.int
BUFFER_DONOR_FIELD_NUMBER: builtins.int
CROSS_PROGRAM_PREFETCHES_FIELD_NUMBER: builtins.int
IS_DYNAMIC_FIELD_NUMBER: builtins.int
SPMD_OUTPUT_SHARDING_FIELD_NUMBER: builtins.int
@@ -941,6 +1038,8 @@ class HloModuleProto(google.protobuf.message.Message):
USE_AUTO_SPMD_PARTITIONING_FIELD_NUMBER: builtins.int
PROFILE_INFO_FIELD_NUMBER: builtins.int
DEVICE_ASSIGNMENT_FIELD_NUMBER: builtins.int
STACK_FRAME_INDEX_FIELD_NUMBER: builtins.int
FRONTEND_ATTRIBUTES_FIELD_NUMBER: builtins.int
name: builtins.str
entry_computation_name: builtins.str
entry_computation_id: builtins.int
@@ -969,7 +1068,9 @@ class HloModuleProto(google.protobuf.message.Message):
"""Describes alias information between inputs and outputs."""
@property
def dynamic_parameter_binding(self) -> global___DynamicParameterBindingProto: ...
def buffer_donor(self) -> global___HloBufferDonorProto:
"""Describes the information of input buffer donors."""
@property
def cross_program_prefetches(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___CrossProgramPrefetch]: ...
@property
@@ -984,6 +1085,14 @@ class HloModuleProto(google.protobuf.message.Message):
def device_assignment(self) -> tensorflow.compiler.xla.xla_data_pb2.DeviceAssignmentProto:
"""DeviceAssignment object information."""
@property
def stack_frame_index(self) -> global___StackFrameIndexProto:
"""Stack frames index."""
@property
def frontend_attributes(self) -> tensorflow.compiler.xla.xla_data_pb2.FrontendAttributes:
"""Frontend attributes to pass to the XLA backend."""
def __init__(
self,
*,
@@ -995,7 +1104,7 @@ class HloModuleProto(google.protobuf.message.Message):
id: builtins.int | None = ...,
schedule: global___HloScheduleProto | None = ...,
input_output_alias: global___HloInputOutputAliasProto | None = ...,
dynamic_parameter_binding: global___DynamicParameterBindingProto | None = ...,
buffer_donor: global___HloBufferDonorProto | None = ...,
cross_program_prefetches: collections.abc.Iterable[global___CrossProgramPrefetch] | None = ...,
is_dynamic: builtins.bool | None = ...,
spmd_output_sharding: tensorflow.compiler.xla.xla_data_pb2.OpSharding | None = ...,
@@ -1003,9 +1112,11 @@ class HloModuleProto(google.protobuf.message.Message):
use_auto_spmd_partitioning: builtins.bool | None = ...,
profile_info: collections.abc.Iterable[global___HloModuleProto.ProfileInfo] | None = ...,
device_assignment: tensorflow.compiler.xla.xla_data_pb2.DeviceAssignmentProto | None = ...,
stack_frame_index: global___StackFrameIndexProto | None = ...,
frontend_attributes: tensorflow.compiler.xla.xla_data_pb2.FrontendAttributes | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["device_assignment", b"device_assignment", "dynamic_parameter_binding", b"dynamic_parameter_binding", "host_program_shape", b"host_program_shape", "input_output_alias", b"input_output_alias", "schedule", b"schedule", "spmd_output_sharding", b"spmd_output_sharding"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["computations", b"computations", "cross_program_prefetches", b"cross_program_prefetches", "device_assignment", b"device_assignment", "dynamic_parameter_binding", b"dynamic_parameter_binding", "entry_computation_id", b"entry_computation_id", "entry_computation_name", b"entry_computation_name", "host_program_shape", b"host_program_shape", "id", b"id", "input_output_alias", b"input_output_alias", "is_dynamic", b"is_dynamic", "name", b"name", "profile_info", b"profile_info", "schedule", b"schedule", "spmd_output_sharding", b"spmd_output_sharding", "spmd_parameters_shardings", b"spmd_parameters_shardings", "use_auto_spmd_partitioning", b"use_auto_spmd_partitioning"]) -> None: ...
def HasField(self, field_name: typing.Literal["buffer_donor", b"buffer_donor", "device_assignment", b"device_assignment", "frontend_attributes", b"frontend_attributes", "host_program_shape", b"host_program_shape", "input_output_alias", b"input_output_alias", "schedule", b"schedule", "spmd_output_sharding", b"spmd_output_sharding", "stack_frame_index", b"stack_frame_index"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["buffer_donor", b"buffer_donor", "computations", b"computations", "cross_program_prefetches", b"cross_program_prefetches", "device_assignment", b"device_assignment", "entry_computation_id", b"entry_computation_id", "entry_computation_name", b"entry_computation_name", "frontend_attributes", b"frontend_attributes", "host_program_shape", b"host_program_shape", "id", b"id", "input_output_alias", b"input_output_alias", "is_dynamic", b"is_dynamic", "name", b"name", "profile_info", b"profile_info", "schedule", b"schedule", "spmd_output_sharding", b"spmd_output_sharding", "spmd_parameters_shardings", b"spmd_parameters_shardings", "stack_frame_index", b"stack_frame_index", "use_auto_spmd_partitioning", b"use_auto_spmd_partitioning"]) -> None: ...
global___HloModuleProto = HloModuleProto
@@ -1435,6 +1546,7 @@ class HloPassMetadata(google.protobuf.message.Message):
MODULE_GROUP_MODULE_IDS_FIELD_NUMBER: builtins.int
START_TIMESTAMP_USEC_FIELD_NUMBER: builtins.int
END_TIMESTAMP_USEC_FIELD_NUMBER: builtins.int
CUSTOM_METADATA_FIELD_NUMBER: builtins.int
pass_id: builtins.int
"""For a given module, pass_id uniquely identifies a run of an HLO pass on
that module. Note that a pass_id may not always refer to the same pass
@@ -1470,6 +1582,10 @@ class HloPassMetadata(google.protobuf.message.Message):
set as the ids of all the modules in the module group. Empty otherwise.
"""
@property
def custom_metadata(self) -> google.protobuf.any_pb2.Any:
"""Custom metadata for the pass."""
def __init__(
self,
*,
@@ -1482,92 +1598,13 @@ class HloPassMetadata(google.protobuf.message.Message):
module_group_module_ids: collections.abc.Iterable[builtins.int] | None = ...,
start_timestamp_usec: builtins.int | None = ...,
end_timestamp_usec: builtins.int | None = ...,
custom_metadata: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["dump_filenames", b"dump_filenames", "end_timestamp_usec", b"end_timestamp_usec", "module_changed", b"module_changed", "module_group_module_ids", b"module_group_module_ids", "module_id", b"module_id", "pass_id", b"pass_id", "pass_name", b"pass_name", "pipeline_name", b"pipeline_name", "start_timestamp_usec", b"start_timestamp_usec"]) -> None: ...
def HasField(self, field_name: typing.Literal["custom_metadata", b"custom_metadata"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["custom_metadata", b"custom_metadata", "dump_filenames", b"dump_filenames", "end_timestamp_usec", b"end_timestamp_usec", "module_changed", b"module_changed", "module_group_module_ids", b"module_group_module_ids", "module_id", b"module_id", "pass_id", b"pass_id", "pass_name", b"pass_name", "pipeline_name", b"pipeline_name", "start_timestamp_usec", b"start_timestamp_usec"]) -> None: ...
global___HloPassMetadata = HloPassMetadata
@typing.final
class EntryFunctionAttributes(google.protobuf.message.Message):
"""Encodes attributes for an entry function."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@typing.final
class ShapeIndex(google.protobuf.message.Message):
"""Acts as the underlying container for an xla::ShapeIndex."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
INDICES_FIELD_NUMBER: builtins.int
@property
def indices(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
def __init__(
self,
*,
indices: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["indices", b"indices"]) -> None: ...
@typing.final
class BufferParameterAttributes(google.protobuf.message.Message):
"""Encodes attributes for a single buffer parameter."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
LMHLO_PARAMS_FIELD_NUMBER: builtins.int
LMHLO_PARAMS_PRESENT_FIELD_NUMBER: builtins.int
LMHLO_PARAM_SHAPE_INDEX_FIELD_NUMBER: builtins.int
LMHLO_CONSTANT_NAME_FIELD_NUMBER: builtins.int
LMHLO_MUST_ALIAS_FIELD_NUMBER: builtins.int
LMHLO_OUTPUT_INDEX_FIELD_NUMBER: builtins.int
lmhlo_params: builtins.int
"""Represents an lmhlo.params function argument attribute."""
lmhlo_params_present: builtins.bool
"""TODO(hanbinyoon): Deprecate when optional fields are available in proto3
(Protocol Buffers v3.15.0).
"""
lmhlo_constant_name: builtins.str
"""Represents an lmhlo.constant_name function argument attribute."""
lmhlo_must_alias: builtins.bool
"""Represents an lmhlo.must_alias function argument attribute."""
@property
def lmhlo_param_shape_index(self) -> global___EntryFunctionAttributes.ShapeIndex:
"""Represents an lmhlo.param_shape_index function argument attribute."""
@property
def lmhlo_output_index(self) -> global___EntryFunctionAttributes.ShapeIndex:
"""Represents an lmhlo.params function argument attribute."""
def __init__(
self,
*,
lmhlo_params: builtins.int | None = ...,
lmhlo_params_present: builtins.bool | None = ...,
lmhlo_param_shape_index: global___EntryFunctionAttributes.ShapeIndex | None = ...,
lmhlo_constant_name: builtins.str | None = ...,
lmhlo_must_alias: builtins.bool | None = ...,
lmhlo_output_index: global___EntryFunctionAttributes.ShapeIndex | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["lmhlo_output_index", b"lmhlo_output_index", "lmhlo_param_shape_index", b"lmhlo_param_shape_index"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["lmhlo_constant_name", b"lmhlo_constant_name", "lmhlo_must_alias", b"lmhlo_must_alias", "lmhlo_output_index", b"lmhlo_output_index", "lmhlo_param_shape_index", b"lmhlo_param_shape_index", "lmhlo_params", b"lmhlo_params", "lmhlo_params_present", b"lmhlo_params_present"]) -> None: ...
BUFFERS_FIELD_NUMBER: builtins.int
RESULT_XLA_SHAPE_FIELD_NUMBER: builtins.int
result_xla_shape: builtins.str
"""xla::Shape in string format."""
@property
def buffers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___EntryFunctionAttributes.BufferParameterAttributes]: ...
def __init__(
self,
*,
buffers: collections.abc.Iterable[global___EntryFunctionAttributes.BufferParameterAttributes] | None = ...,
result_xla_shape: builtins.str | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["buffers", b"buffers", "result_xla_shape", b"result_xla_shape"]) -> None: ...
global___EntryFunctionAttributes = EntryFunctionAttributes
@typing.final
class XlaRuntimeExecutableProto(google.protobuf.message.Message):
"""Encodes the underlying Xla runtime executable compiled from the XLA module."""

View File

@@ -0,0 +1,149 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Copyright 2018 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
"""
import builtins
import collections.abc
import typing
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class HloProfilePrinterData(google.protobuf.message.Message):
"""Describes how to pretty-print a profile counter array gathered for a specific
HloModule.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@typing.final
class HloInstructionInfo(google.protobuf.message.Message):
"""Pretty-printer information about an HloInstruction."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
LONG_NAME_FIELD_NUMBER: builtins.int
SHORT_NAME_FIELD_NUMBER: builtins.int
CATEGORY_FIELD_NUMBER: builtins.int
FLOP_COUNT_FIELD_NUMBER: builtins.int
TRANSCENDENTAL_COUNT_FIELD_NUMBER: builtins.int
BYTES_ACCESSED_FIELD_NUMBER: builtins.int
OPTIMAL_SECONDS_FIELD_NUMBER: builtins.int
PROFILE_INDEX_FIELD_NUMBER: builtins.int
long_name: builtins.str
short_name: builtins.str
category: builtins.str
flop_count: builtins.float
"""Metrics computed by HloCostAnalysis."""
transcendental_count: builtins.float
bytes_accessed: builtins.int
optimal_seconds: builtins.float
profile_index: builtins.int
"""The index into the profile counters array for the HloInstruction
corresponding to this HloInstructionInfo.
"""
def __init__(
self,
*,
long_name: builtins.str | None = ...,
short_name: builtins.str | None = ...,
category: builtins.str | None = ...,
flop_count: builtins.float | None = ...,
transcendental_count: builtins.float | None = ...,
bytes_accessed: builtins.int | None = ...,
optimal_seconds: builtins.float | None = ...,
profile_index: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["bytes_accessed", b"bytes_accessed", "category", b"category", "flop_count", b"flop_count", "long_name", b"long_name", "optimal_seconds", b"optimal_seconds", "profile_index", b"profile_index", "short_name", b"short_name", "transcendental_count", b"transcendental_count"]) -> None: ...
@typing.final
class HloComputationInfo(google.protobuf.message.Message):
"""Pretty-printer information about an HloComputation."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NAME_FIELD_NUMBER: builtins.int
PROFILE_INDEX_FIELD_NUMBER: builtins.int
INSTRUCTION_INFOS_FIELD_NUMBER: builtins.int
name: builtins.str
profile_index: builtins.int
"""The index into the profile counters array for the HloComputation
corresponding to this HloComputationInfo.
"""
@property
def instruction_infos(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___HloProfilePrinterData.HloInstructionInfo]:
"""HloInstructionInfos for every HloInstruction in the HloComputation for
corresponding to this HloComputattionInfo.
"""
def __init__(
self,
*,
name: builtins.str | None = ...,
profile_index: builtins.int | None = ...,
instruction_infos: collections.abc.Iterable[global___HloProfilePrinterData.HloInstructionInfo] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["instruction_infos", b"instruction_infos", "name", b"name", "profile_index", b"profile_index"]) -> None: ...
@typing.final
class ExtraMetricsEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
value: builtins.int
def __init__(
self,
*,
key: builtins.str | None = ...,
value: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
COMPUTATION_INFOS_FIELD_NUMBER: builtins.int
PROFILE_COUNTERS_SIZE_FIELD_NUMBER: builtins.int
EXTRA_METRICS_FIELD_NUMBER: builtins.int
ENTRY_COMPUTATION_FIELD_NUMBER: builtins.int
profile_counters_size: builtins.int
"""The size of the profile counters array we will pretty-print."""
entry_computation: builtins.str
"""Name of the entry computation."""
@property
def computation_infos(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___HloProfilePrinterData.HloComputationInfo]:
"""HloComputationInfos for every HloComputation in the HloModule."""
@property
def extra_metrics(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.int]:
"""Maps extra metric name to the index into the profile counters array."""
def __init__(
self,
*,
computation_infos: collections.abc.Iterable[global___HloProfilePrinterData.HloComputationInfo] | None = ...,
profile_counters_size: builtins.int | None = ...,
extra_metrics: collections.abc.Mapping[builtins.str, builtins.int] | None = ...,
entry_computation: builtins.str | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["computation_infos", b"computation_infos", "entry_computation", b"entry_computation", "extra_metrics", b"extra_metrics", "profile_counters_size", b"profile_counters_size"]) -> None: ...
global___HloProfilePrinterData = HloProfilePrinterData

View File

@@ -4,11 +4,14 @@ isort:skip_file
"""
import builtins
import collections.abc
import sys
import typing
import google.protobuf.any_pb2
import google.protobuf.descriptor
import google.protobuf.duration_pb2
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import google.protobuf.timestamp_pb2
@@ -20,6 +23,44 @@ else:
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class PassMetrics(google.protobuf.message.Message):
"""Defines pass specific metrics."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
MODULE_ID_FIELD_NUMBER: builtins.int
PASS_NAME_FIELD_NUMBER: builtins.int
PASS_DURATION_FIELD_NUMBER: builtins.int
CUSTOM_METRICS_FIELD_NUMBER: builtins.int
module_id: builtins.int
"""Unique ID of the module on which the pass was run."""
pass_name: builtins.str
"""The name of the pass."""
@property
def pass_duration(self) -> google.protobuf.duration_pb2.Duration:
"""Duration of the pass."""
@property
def custom_metrics(self) -> google.protobuf.any_pb2.Any:
"""Custom pass metrics. This is kept opaque, via `google.protobuf.Any`, in
order to decouple pass agnostic compilation logs from possibly proprietary
compiler passes.
"""
def __init__(
self,
*,
module_id: builtins.int | None = ...,
pass_name: builtins.str | None = ...,
pass_duration: google.protobuf.duration_pb2.Duration | None = ...,
custom_metrics: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["custom_metrics", b"custom_metrics", "pass_duration", b"pass_duration"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["custom_metrics", b"custom_metrics", "module_id", b"module_id", "pass_duration", b"pass_duration", "pass_name", b"pass_name"]) -> None: ...
global___PassMetrics = PassMetrics
@typing.final
class CompilationLogEntry(google.protobuf.message.Message):
"""Defines XLA compilation metrics."""
@@ -51,10 +92,13 @@ class CompilationLogEntry(google.protobuf.message.Message):
STAGE_FIELD_NUMBER: builtins.int
DURATION_FIELD_NUMBER: builtins.int
TASK_INDEX_FIELD_NUMBER: builtins.int
PASS_METRICS_FIELD_NUMBER: builtins.int
stage: global___CompilationLogEntry.CompilationStage.ValueType
"""Compilation stage recorded by this log entry."""
task_index: builtins.int
"""Task index from which this log entry was recorded."""
"""Task index from which this log entry was recorded or
-1 if the task index could not be fetched.
"""
@property
def timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp:
"""Time when the event captured by this log entry occurred."""
@@ -63,6 +107,10 @@ class CompilationLogEntry(google.protobuf.message.Message):
def duration(self) -> google.protobuf.duration_pb2.Duration:
"""Duration of the given compilation stage."""
@property
def pass_metrics(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PassMetrics]:
"""Pass specific metrics."""
def __init__(
self,
*,
@@ -70,8 +118,9 @@ class CompilationLogEntry(google.protobuf.message.Message):
stage: global___CompilationLogEntry.CompilationStage.ValueType | None = ...,
duration: google.protobuf.duration_pb2.Duration | None = ...,
task_index: builtins.int | None = ...,
pass_metrics: collections.abc.Iterable[global___PassMetrics] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["duration", b"duration", "timestamp", b"timestamp"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["duration", b"duration", "stage", b"stage", "task_index", b"task_index", "timestamp", b"timestamp"]) -> None: ...
def ClearField(self, field_name: typing.Literal["duration", b"duration", "pass_metrics", b"pass_metrics", "stage", b"stage", "task_index", b"task_index", "timestamp", b"timestamp"]) -> None: ...
global___CompilationLogEntry = CompilationLogEntry

View File

@@ -0,0 +1,71 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Copyright 2022 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
"""
import builtins
import typing
import google.protobuf.descriptor
import google.protobuf.message
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class TestCompilationEnvironment1(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SOME_FLAG_FIELD_NUMBER: builtins.int
some_flag: builtins.int
def __init__(
self,
*,
some_flag: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["some_flag", b"some_flag"]) -> None: ...
global___TestCompilationEnvironment1 = TestCompilationEnvironment1
@typing.final
class TestCompilationEnvironment2(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SOME_OTHER_FLAG_FIELD_NUMBER: builtins.int
some_other_flag: builtins.int
def __init__(
self,
*,
some_other_flag: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["some_other_flag", b"some_other_flag"]) -> None: ...
global___TestCompilationEnvironment2 = TestCompilationEnvironment2
@typing.final
class TestCompilationEnvironment3(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
A_THIRD_FLAG_FIELD_NUMBER: builtins.int
a_third_flag: builtins.int
def __init__(
self,
*,
a_third_flag: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["a_third_flag", b"a_third_flag"]) -> None: ...
global___TestCompilationEnvironment3 = TestCompilationEnvironment3

View File

@@ -0,0 +1,137 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Copyright 2023 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
"""
import builtins
import collections.abc
import typing
import google.protobuf.descriptor
import google.protobuf.duration_pb2
import google.protobuf.internal.containers
import google.protobuf.message
import tensorflow.compiler.xla.service.hlo_pb2
import tensorflow.tsl.protobuf.status_pb2
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class CompilerPerfStats(google.protobuf.message.Message):
"""Statistics on how long various parts of compilation took.
Not all durations may be relevant for all producers of this message, in
which irrelevant fields should simply be skipped.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
INIT_DURATION_FIELD_NUMBER: builtins.int
HLO_VERIFICATION_DURATION_FIELD_NUMBER: builtins.int
COMPILATION_PROLOGUE_DURATION_FIELD_NUMBER: builtins.int
COMPILATION_DURATION_FIELD_NUMBER: builtins.int
TOTAL_DURATION_FIELD_NUMBER: builtins.int
@property
def init_duration(self) -> google.protobuf.duration_pb2.Duration:
"""How long did it take to initialize the compiler?"""
@property
def hlo_verification_duration(self) -> google.protobuf.duration_pb2.Duration:
"""How long did it take to verify the HLO?"""
@property
def compilation_prologue_duration(self) -> google.protobuf.duration_pb2.Duration:
"""How long did it take to prepare for compilation after verification?"""
@property
def compilation_duration(self) -> google.protobuf.duration_pb2.Duration:
"""How long did it take to compile?"""
@property
def total_duration(self) -> google.protobuf.duration_pb2.Duration:
"""How long did everything take?"""
def __init__(
self,
*,
init_duration: google.protobuf.duration_pb2.Duration | None = ...,
hlo_verification_duration: google.protobuf.duration_pb2.Duration | None = ...,
compilation_prologue_duration: google.protobuf.duration_pb2.Duration | None = ...,
compilation_duration: google.protobuf.duration_pb2.Duration | None = ...,
total_duration: google.protobuf.duration_pb2.Duration | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["compilation_duration", b"compilation_duration", "compilation_prologue_duration", b"compilation_prologue_duration", "hlo_verification_duration", b"hlo_verification_duration", "init_duration", b"init_duration", "total_duration", b"total_duration"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["compilation_duration", b"compilation_duration", "compilation_prologue_duration", b"compilation_prologue_duration", "hlo_verification_duration", b"hlo_verification_duration", "init_duration", b"init_duration", "total_duration", b"total_duration"]) -> None: ...
global___CompilerPerfStats = CompilerPerfStats
@typing.final
class CompilationResult(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@typing.final
class CountersEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
value: builtins.int
def __init__(
self,
*,
key: builtins.str | None = ...,
value: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
HLO_MODULE_FIELD_NUMBER: builtins.int
PERF_STATS_FIELD_NUMBER: builtins.int
STATUS_FIELD_NUMBER: builtins.int
COUNTERS_FIELD_NUMBER: builtins.int
@property
def hlo_module(self) -> tensorflow.compiler.xla.service.hlo_pb2.HloModuleProto:
"""The compiled HLO. Only set when compilation succeeds."""
@property
def perf_stats(self) -> global___CompilerPerfStats:
"""Always set when compilation succeeds. May or may not be set when
compilation fails.
"""
@property
def status(self) -> tensorflow.tsl.protobuf.status_pb2.StatusProto:
"""Always set when compilation fails; never set when compilation succeeds."""
@property
def counters(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.int]:
"""Collects counters collected during compilation. Not every producer may
include counter support at all or any particular counter.
"""
def __init__(
self,
*,
hlo_module: tensorflow.compiler.xla.service.hlo_pb2.HloModuleProto | None = ...,
perf_stats: global___CompilerPerfStats | None = ...,
status: tensorflow.tsl.protobuf.status_pb2.StatusProto | None = ...,
counters: collections.abc.Mapping[builtins.str, builtins.int] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["hlo_module", b"hlo_module", "perf_stats", b"perf_stats", "status", b"status"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["counters", b"counters", "hlo_module", b"hlo_module", "perf_stats", b"perf_stats", "status", b"status"]) -> None: ...
global___CompilationResult = CompilationResult

View File

@@ -1,7 +1,7 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Copyright 2017 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -44,13 +44,15 @@ class _PrimitiveTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._
"""Invalid primitive type to serve as default."""
PRED: _PrimitiveType.ValueType # 1
"""Predicates are two-state booleans."""
S8: _PrimitiveType.ValueType # 2
S4: _PrimitiveType.ValueType # 21
"""Signed integral values of fixed width."""
S8: _PrimitiveType.ValueType # 2
S16: _PrimitiveType.ValueType # 3
S32: _PrimitiveType.ValueType # 4
S64: _PrimitiveType.ValueType # 5
U8: _PrimitiveType.ValueType # 6
U4: _PrimitiveType.ValueType # 22
"""Unsigned integral values of fixed width."""
U8: _PrimitiveType.ValueType # 6
U16: _PrimitiveType.ValueType # 7
U32: _PrimitiveType.ValueType # 8
U64: _PrimitiveType.ValueType # 9
@@ -78,11 +80,35 @@ class _PrimitiveTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._
supported. NaN is represented when the exponent and mantissa bits are all
1s. All other values are finite.
F8E4M3B11FNUZ has 4 exponent bits and 3 mantissa bits and a bias of 11. The
"FNUZ" means only Finite and NaN values are supported; zero is unsigned.
Unlike IEEE types, infinities are not supported. NaN is represented when
the exponent and mantissa bits are all 0s with a sign bit of 1. All other
values are finite.
Support for these dtypes is under development. They do not yet work
properly in most cases.
TODO(b/259609697): Fully support FP8.
"""
F8E4M3FN: _PrimitiveType.ValueType # 20
F8E4M3B11FNUZ: _PrimitiveType.ValueType # 23
F8E5M2FNUZ: _PrimitiveType.ValueType # 24
"""FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915
F8E5M2FNUZ has 5 exponent bits and 2 mantissa bits.
F8E4M3FNUZ has 4 exponent bits and 3 mantissa bits.
The "FNUZ" means only Finite and NaN values are supported; zero is
unsigned. Unlike IEEE types, infinities are not supported. NaN is
represented when the exponent and mantissa bits are all 0s with a sign bit
of 1. All other values are finite.
These differences mean there's an additional exponent value available. To
keep the same dynamic range as an IEEE-like FP8 type, the exponent is
biased one more than would be expected given the number of exponent bits
(8 for Float8E4M3FNUZ and 16 for Float8E5M2FNUZ).
"""
F8E4M3FNUZ: _PrimitiveType.ValueType # 25
C64: _PrimitiveType.ValueType # 15
"""Complex values of fixed width.
Paired F32 (real, imag), as in std::complex<float>.
@@ -123,13 +149,15 @@ PRIMITIVE_TYPE_INVALID: PrimitiveType.ValueType # 0
"""Invalid primitive type to serve as default."""
PRED: PrimitiveType.ValueType # 1
"""Predicates are two-state booleans."""
S8: PrimitiveType.ValueType # 2
S4: PrimitiveType.ValueType # 21
"""Signed integral values of fixed width."""
S8: PrimitiveType.ValueType # 2
S16: PrimitiveType.ValueType # 3
S32: PrimitiveType.ValueType # 4
S64: PrimitiveType.ValueType # 5
U8: PrimitiveType.ValueType # 6
U4: PrimitiveType.ValueType # 22
"""Unsigned integral values of fixed width."""
U8: PrimitiveType.ValueType # 6
U16: PrimitiveType.ValueType # 7
U32: PrimitiveType.ValueType # 8
U64: PrimitiveType.ValueType # 9
@@ -157,11 +185,35 @@ Finite and NaN values are supported. Unlike IEEE types, infinities are not
supported. NaN is represented when the exponent and mantissa bits are all
1s. All other values are finite.
F8E4M3B11FNUZ has 4 exponent bits and 3 mantissa bits and a bias of 11. The
"FNUZ" means only Finite and NaN values are supported; zero is unsigned.
Unlike IEEE types, infinities are not supported. NaN is represented when
the exponent and mantissa bits are all 0s with a sign bit of 1. All other
values are finite.
Support for these dtypes is under development. They do not yet work
properly in most cases.
TODO(b/259609697): Fully support FP8.
"""
F8E4M3FN: PrimitiveType.ValueType # 20
F8E4M3B11FNUZ: PrimitiveType.ValueType # 23
F8E5M2FNUZ: PrimitiveType.ValueType # 24
"""FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915
F8E5M2FNUZ has 5 exponent bits and 2 mantissa bits.
F8E4M3FNUZ has 4 exponent bits and 3 mantissa bits.
The "FNUZ" means only Finite and NaN values are supported; zero is
unsigned. Unlike IEEE types, infinities are not supported. NaN is
represented when the exponent and mantissa bits are all 0s with a sign bit
of 1. All other values are finite.
These differences mean there's an additional exponent value available. To
keep the same dynamic range as an IEEE-like FP8 type, the exponent is
biased one more than would be expected given the number of exponent bits
(8 for Float8E4M3FNUZ and 16 for Float8E5M2FNUZ).
"""
F8E4M3FNUZ: PrimitiveType.ValueType # 25
C64: PrimitiveType.ValueType # 15
"""Complex values of fixed width.
Paired F32 (real, imag), as in std::complex<float>.
@@ -205,6 +257,11 @@ class _DimLevelTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._E
"""The corresponding dimension contains a single coordinate, no sibling
elements for each parent.
"""
DIM_LOOSE_COMPRESSED: _DimLevelType.ValueType # 3
"""The corresponding dimension is Compressed, but with potential trailing
zeros, thus an extra upper bound (high) is used to exclude those zeros.
E.g., indices = [1, 2, 0, 0, 3, 4, 0, 0], position = [(0, 2), (4, 6)].
"""
class DimLevelType(_DimLevelType, metaclass=_DimLevelTypeEnumTypeWrapper):
"""A DimLevelType indicates the encoding method for a dimension in an array.
@@ -222,6 +279,11 @@ DIM_SINGLETON: DimLevelType.ValueType # 2
"""The corresponding dimension contains a single coordinate, no sibling
elements for each parent.
"""
DIM_LOOSE_COMPRESSED: DimLevelType.ValueType # 3
"""The corresponding dimension is Compressed, but with potential trailing
zeros, thus an extra upper bound (high) is used to exclude those zeros.
E.g., indices = [1, 2, 0, 0, 3, 4, 0, 0], position = [(0, 2), (4, 6)].
"""
global___DimLevelType = DimLevelType
class _ProfileType:
@@ -328,6 +390,23 @@ IRFFT: FftType.ValueType # 3
"""Inverse real FFT; fft_length / 2 + 1 complex in,"""
global___FftType = FftType
class _SparsityType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _SparsityTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_SparsityType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
SPARSITY_INVALID: _SparsityType.ValueType # 0
SPARSITY_STRUCTURED_N_M: _SparsityType.ValueType # 1
"""Structured N:M sparsity."""
class SparsityType(_SparsityType, metaclass=_SparsityTypeEnumTypeWrapper): ...
SPARSITY_INVALID: SparsityType.ValueType # 0
SPARSITY_STRUCTURED_N_M: SparsityType.ValueType # 1
"""Structured N:M sparsity."""
global___SparsityType = SparsityType
class _RandomDistribution:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
@@ -474,11 +553,26 @@ class LayoutProto(google.protobuf.message.Message):
DIM_ORDERED_FIELD_NUMBER: builtins.int
MINOR_TO_MAJOR_FIELD_NUMBER: builtins.int
TILES_FIELD_NUMBER: builtins.int
TAIL_PADDING_ALIGNMENT_IN_ELEMENTS_FIELD_NUMBER: builtins.int
ELEMENT_SIZE_IN_BITS_FIELD_NUMBER: builtins.int
MEMORY_SPACE_FIELD_NUMBER: builtins.int
INDEX_PRIMITIVE_TYPE_FIELD_NUMBER: builtins.int
POINTER_PRIMITIVE_TYPE_FIELD_NUMBER: builtins.int
PHYSICAL_SHAPE_FIELD_NUMBER: builtins.int
DYNAMIC_SHAPE_METADATA_PREFIX_BYTES_FIELD_NUMBER: builtins.int
tail_padding_alignment_in_elements: builtins.int
"""The shape is padded at the end to multiple of, in terms of number of
elements. This is useful when tiling does not bring the shape to certain
desired granules. Tiling effectively pads/reshapes/transposes the shape
to another shape. This field pads the total number of elements of that
new shape to a multiple of certain number of elements. This is useful such
as we want a layout which does not tile the data but still requires it to
be padded to certain number of elements.
"""
element_size_in_bits: builtins.int
"""(Optional) Bit size of each element. When unspecified or being 0, default
to ShapeUtil::ByteSizeOfPrimitiveType.
"""
memory_space: builtins.int
"""Memory space where this array resides. The integer field is interpreted in
a backend-specific manner.
@@ -546,6 +640,8 @@ class LayoutProto(google.protobuf.message.Message):
dim_ordered: collections.abc.Iterable[builtins.bool] | None = ...,
minor_to_major: collections.abc.Iterable[builtins.int] | None = ...,
tiles: collections.abc.Iterable[global___TileProto] | None = ...,
tail_padding_alignment_in_elements: builtins.int | None = ...,
element_size_in_bits: builtins.int | None = ...,
memory_space: builtins.int | None = ...,
index_primitive_type: global___PrimitiveType.ValueType | None = ...,
pointer_primitive_type: global___PrimitiveType.ValueType | None = ...,
@@ -553,7 +649,7 @@ class LayoutProto(google.protobuf.message.Message):
dynamic_shape_metadata_prefix_bytes: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["physical_shape", b"physical_shape"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["dim_level_types", b"dim_level_types", "dim_ordered", b"dim_ordered", "dim_unique", b"dim_unique", "dynamic_shape_metadata_prefix_bytes", b"dynamic_shape_metadata_prefix_bytes", "index_primitive_type", b"index_primitive_type", "memory_space", b"memory_space", "minor_to_major", b"minor_to_major", "physical_shape", b"physical_shape", "pointer_primitive_type", b"pointer_primitive_type", "tiles", b"tiles"]) -> None: ...
def ClearField(self, field_name: typing.Literal["dim_level_types", b"dim_level_types", "dim_ordered", b"dim_ordered", "dim_unique", b"dim_unique", "dynamic_shape_metadata_prefix_bytes", b"dynamic_shape_metadata_prefix_bytes", "element_size_in_bits", b"element_size_in_bits", "index_primitive_type", b"index_primitive_type", "memory_space", b"memory_space", "minor_to_major", b"minor_to_major", "physical_shape", b"physical_shape", "pointer_primitive_type", b"pointer_primitive_type", "tail_padding_alignment_in_elements", b"tail_padding_alignment_in_elements", "tiles", b"tiles"]) -> None: ...
global___LayoutProto = LayoutProto
@@ -724,6 +820,9 @@ class OpMetadata(google.protobuf.message.Message):
SIZE_OF_GENERATED_CODE_IN_BYTES_FIELD_NUMBER: builtins.int
SIZE_OF_MEMORY_WORKING_SET_IN_BYTES_FIELD_NUMBER: builtins.int
PROFILE_INFO_FIELD_NUMBER: builtins.int
DEDUPLICATED_NAME_FIELD_NUMBER: builtins.int
PRESERVE_LAYOUT_FIELD_NUMBER: builtins.int
STACK_FRAME_ID_FIELD_NUMBER: builtins.int
op_type: builtins.str
"""The framework op name that generated this XLA op.
@@ -762,6 +861,20 @@ class OpMetadata(google.protobuf.message.Message):
"""The size of the working set, i.e., the amount of memory, used by the
instruction in a compiler-managed fast device memory.
"""
deduplicated_name: builtins.str
"""Deduplicated HLO name for this op. In some cases, we can have multiple
instructions (e.g. fusions) that are considered duplicates. We want to
group them together under the same name so that we can group them together
during analysis (e.g. HLO Op Profile tool in Xprof).
E.g. If we have fusion.1, fusion.2, and fusion.3 marked as duplicates,
fusion.2 and fusion.3 will have deduplicated_name = fusion.1
"""
preserve_layout: builtins.bool
"""Whether to preserve the layout of the HLO op."""
stack_frame_id: builtins.int
"""1-based position of the frame in frames flat array.
Ids are 1-based to keep 0 value as representation of non-set property.
"""
@property
def profile_type(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[global___ProfileType.ValueType]:
"""Deprecated, use [ProfileInfo][profile_type] instead."""
@@ -783,9 +896,12 @@ class OpMetadata(google.protobuf.message.Message):
size_of_generated_code_in_bytes: builtins.int | None = ...,
size_of_memory_working_set_in_bytes: builtins.int | None = ...,
profile_info: global___OpMetadata.ProfileInfo | None = ...,
deduplicated_name: builtins.str | None = ...,
preserve_layout: builtins.bool | None = ...,
stack_frame_id: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["profile_info", b"profile_info"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["creation_pass_id", b"creation_pass_id", "logical_creation_pass_id", b"logical_creation_pass_id", "op_name", b"op_name", "op_type", b"op_type", "profile_info", b"profile_info", "profile_type", b"profile_type", "size_of_generated_code_in_bytes", b"size_of_generated_code_in_bytes", "size_of_memory_working_set_in_bytes", b"size_of_memory_working_set_in_bytes", "source_file", b"source_file", "source_line", b"source_line"]) -> None: ...
def ClearField(self, field_name: typing.Literal["creation_pass_id", b"creation_pass_id", "deduplicated_name", b"deduplicated_name", "logical_creation_pass_id", b"logical_creation_pass_id", "op_name", b"op_name", "op_type", b"op_type", "preserve_layout", b"preserve_layout", "profile_info", b"profile_info", "profile_type", b"profile_type", "size_of_generated_code_in_bytes", b"size_of_generated_code_in_bytes", "size_of_memory_working_set_in_bytes", b"size_of_memory_working_set_in_bytes", "source_file", b"source_file", "source_line", b"source_line", "stack_frame_id", b"stack_frame_id"]) -> None: ...
global___OpMetadata = OpMetadata
@@ -1023,6 +1139,8 @@ class LiteralProto(google.protobuf.message.Message):
SHAPE_FIELD_NUMBER: builtins.int
PREDS_FIELD_NUMBER: builtins.int
S4S_FIELD_NUMBER: builtins.int
U4S_FIELD_NUMBER: builtins.int
S8S_FIELD_NUMBER: builtins.int
U8S_FIELD_NUMBER: builtins.int
S32S_FIELD_NUMBER: builtins.int
@@ -1040,7 +1158,12 @@ class LiteralProto(google.protobuf.message.Message):
S16S_FIELD_NUMBER: builtins.int
F8E5M2S_FIELD_NUMBER: builtins.int
F8E4M3FNS_FIELD_NUMBER: builtins.int
F8E4M3B11FNUZS_FIELD_NUMBER: builtins.int
F8E5M2FNUZS_FIELD_NUMBER: builtins.int
F8E4M3FNUZS_FIELD_NUMBER: builtins.int
SPARSE_INDICES_FIELD_NUMBER: builtins.int
s4s: builtins.bytes
u4s: builtins.bytes
s8s: builtins.bytes
u8s: builtins.bytes
f16s: builtins.bytes
@@ -1050,6 +1173,9 @@ class LiteralProto(google.protobuf.message.Message):
s16s: builtins.bytes
f8e5m2s: builtins.bytes
f8e4m3fns: builtins.bytes
f8e4m3b11fnuzs: builtins.bytes
f8e5m2fnuzs: builtins.bytes
f8e4m3fnuzs: builtins.bytes
@property
def shape(self) -> global___ShapeProto: ...
@property
@@ -1078,13 +1204,15 @@ class LiteralProto(google.protobuf.message.Message):
def tuple_literals(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___LiteralProto]: ...
@property
def sparse_indices(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""Next = 21"""
"""Next = 26"""
def __init__(
self,
*,
shape: global___ShapeProto | None = ...,
preds: collections.abc.Iterable[builtins.bool] | None = ...,
s4s: builtins.bytes | None = ...,
u4s: builtins.bytes | None = ...,
s8s: builtins.bytes | None = ...,
u8s: builtins.bytes | None = ...,
s32s: collections.abc.Iterable[builtins.int] | None = ...,
@@ -1102,10 +1230,13 @@ class LiteralProto(google.protobuf.message.Message):
s16s: builtins.bytes | None = ...,
f8e5m2s: builtins.bytes | None = ...,
f8e4m3fns: builtins.bytes | None = ...,
f8e4m3b11fnuzs: builtins.bytes | None = ...,
f8e5m2fnuzs: builtins.bytes | None = ...,
f8e4m3fnuzs: builtins.bytes | None = ...,
sparse_indices: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["shape", b"shape"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["bf16s", b"bf16s", "c128s", b"c128s", "c64s", b"c64s", "f16s", b"f16s", "f32s", b"f32s", "f64s", b"f64s", "f8e4m3fns", b"f8e4m3fns", "f8e5m2s", b"f8e5m2s", "preds", b"preds", "s16s", b"s16s", "s32s", b"s32s", "s64s", b"s64s", "s8s", b"s8s", "shape", b"shape", "sparse_indices", b"sparse_indices", "tuple_literals", b"tuple_literals", "u16s", b"u16s", "u32s", b"u32s", "u64s", b"u64s", "u8s", b"u8s"]) -> None: ...
def ClearField(self, field_name: typing.Literal["bf16s", b"bf16s", "c128s", b"c128s", "c64s", b"c64s", "f16s", b"f16s", "f32s", b"f32s", "f64s", b"f64s", "f8e4m3b11fnuzs", b"f8e4m3b11fnuzs", "f8e4m3fns", b"f8e4m3fns", "f8e4m3fnuzs", b"f8e4m3fnuzs", "f8e5m2fnuzs", b"f8e5m2fnuzs", "f8e5m2s", b"f8e5m2s", "preds", b"preds", "s16s", b"s16s", "s32s", b"s32s", "s4s", b"s4s", "s64s", b"s64s", "s8s", b"s8s", "shape", b"shape", "sparse_indices", b"sparse_indices", "tuple_literals", b"tuple_literals", "u16s", b"u16s", "u32s", b"u32s", "u4s", b"u4s", "u64s", b"u64s", "u8s", b"u8s"]) -> None: ...
global___LiteralProto = LiteralProto
@@ -1392,6 +1523,44 @@ class DotDimensionNumbers(google.protobuf.message.Message):
global___DotDimensionNumbers = DotDimensionNumbers
@typing.final
class SparsityDescriptor(google.protobuf.message.Message):
"""Contains sparsity metadata for a sparse dot operation.
The only supported type atm is structured 2:4 sparsity, which is natively
supported on NVidia GPUs.
Restrictions:
- only one operand of the dot operation may be sparse;
- only the contracting dimension may be sparse.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TYPE_FIELD_NUMBER: builtins.int
INDEX_FIELD_NUMBER: builtins.int
DIMENSION_FIELD_NUMBER: builtins.int
N_FIELD_NUMBER: builtins.int
M_FIELD_NUMBER: builtins.int
type: global___SparsityType.ValueType
index: builtins.int
"""Sparse operand index (0 or 1)."""
dimension: builtins.int
"""Sparse dimension number."""
n: builtins.int
"""Structured N:M sparsity (N < M)."""
m: builtins.int
def __init__(
self,
*,
type: global___SparsityType.ValueType | None = ...,
index: builtins.int | None = ...,
dimension: builtins.int | None = ...,
n: builtins.int | None = ...,
m: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["dimension", b"dimension", "index", b"index", "m", b"m", "n", b"n", "type", b"type"]) -> None: ...
global___SparsityDescriptor = SparsityDescriptor
@typing.final
class TriangularSolveOptions(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -1462,6 +1631,23 @@ class CholeskyOptions(google.protobuf.message.Message):
global___CholeskyOptions = CholeskyOptions
@typing.final
class SortOptions(google.protobuf.message.Message):
"""Attributes of the sort custom call (cub::DeviceRadixSort)."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
DESCENDING_FIELD_NUMBER: builtins.int
descending: builtins.bool
def __init__(
self,
*,
descending: builtins.bool | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["descending", b"descending"]) -> None: ...
global___SortOptions = SortOptions
@typing.final
class FrontendAttributes(google.protobuf.message.Message):
"""Generic map of attributes used to pass hints / configuration options from
@@ -1498,6 +1684,54 @@ class FrontendAttributes(google.protobuf.message.Message):
global___FrontendAttributes = FrontendAttributes
@typing.final
class Statistic(google.protobuf.message.Message):
"""Represents a single statistic to track."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
STAT_NAME_FIELD_NUMBER: builtins.int
STAT_VAL_FIELD_NUMBER: builtins.int
stat_name: builtins.str
"""Must be a single word consisting of any alphanumeric characters"""
stat_val: builtins.float
"""Must be within a range of [0, 100], in order for the graph dumper to
properly render the statistic onto the graph.
"""
def __init__(
self,
*,
stat_name: builtins.str | None = ...,
stat_val: builtins.float | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["stat_name", b"stat_name", "stat_val", b"stat_val"]) -> None: ...
global___Statistic = Statistic
@typing.final
class StatisticsViz(google.protobuf.message.Message):
"""Represents the information needed to visualize propagation statistics when
rendering an HLO graph. This includes an array of statistics as well as the
index of the statistic to render.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
STAT_INDEX_TO_VISUALIZE_FIELD_NUMBER: builtins.int
STATISTICS_FIELD_NUMBER: builtins.int
stat_index_to_visualize: builtins.int
@property
def statistics(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Statistic]: ...
def __init__(
self,
*,
stat_index_to_visualize: builtins.int | None = ...,
statistics: collections.abc.Iterable[global___Statistic] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["stat_index_to_visualize", b"stat_index_to_visualize", "statistics", b"statistics"]) -> None: ...
global___StatisticsViz = StatisticsViz
@typing.final
class OpSharding(google.protobuf.message.Message):
"""LINT.IfChange"""
@@ -1524,6 +1758,10 @@ class OpSharding(google.protobuf.message.Message):
"""This op is manually sharded: the shapes are already partitioned and the
partitioner should not change this op.
"""
UNKNOWN: OpSharding._Type.ValueType # 5
"""This sharding is a placeholder sharding with lowest precedence, it can be
overwriten by any other shardings.
"""
class Type(_Type, metaclass=_TypeEnumTypeWrapper): ...
REPLICATED: OpSharding.Type.ValueType # 0
@@ -1540,6 +1778,43 @@ class OpSharding(google.protobuf.message.Message):
"""This op is manually sharded: the shapes are already partitioned and the
partitioner should not change this op.
"""
UNKNOWN: OpSharding.Type.ValueType # 5
"""This sharding is a placeholder sharding with lowest precedence, it can be
overwriten by any other shardings.
"""
class _ShardGroupType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _ShardGroupTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[OpSharding._ShardGroupType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
AS: OpSharding._ShardGroupType.ValueType # 0
"""This op will be sharded exactly the same as the other op. (hard
restriction)
"""
LIKE: OpSharding._ShardGroupType.ValueType # 1
"""This op will try to allow sharding propagation within the same group even
there is no data dependencies among them, but there is no guarantee that
the final shardings within the same group will be exactly the same. (soft
restriction)
"""
class ShardGroupType(_ShardGroupType, metaclass=_ShardGroupTypeEnumTypeWrapper):
"""Used to decide whether this op is to be sharded like some other ops, or to
which other ops will be sharded like.
"""
AS: OpSharding.ShardGroupType.ValueType # 0
"""This op will be sharded exactly the same as the other op. (hard
restriction)
"""
LIKE: OpSharding.ShardGroupType.ValueType # 1
"""This op will try to allow sharding propagation within the same group even
there is no data dependencies among them, but there is no guarantee that
the final shardings within the same group will be exactly the same. (soft
restriction)
"""
TYPE_FIELD_NUMBER: builtins.int
TILE_SHAPE_FIELD_NUMBER: builtins.int
@@ -1549,12 +1824,22 @@ class OpSharding(google.protobuf.message.Message):
REPLICATE_ON_LAST_TILE_DIM_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
LAST_TILE_DIMS_FIELD_NUMBER: builtins.int
IOTA_RESHAPE_DIMS_FIELD_NUMBER: builtins.int
IOTA_TRANSPOSE_PERM_FIELD_NUMBER: builtins.int
IS_SHARD_GROUP_FIELD_NUMBER: builtins.int
SHARD_GROUP_ID_FIELD_NUMBER: builtins.int
SHARD_GROUP_TYPE_FIELD_NUMBER: builtins.int
type: global___OpSharding.Type.ValueType
replicate_on_last_tile_dim: builtins.bool
"""Only used for OTHER type. If true, data is sharded according to other
dimensions of tile_assignment(), but replicated across devices along the
last dimension. (Experimental)
"""
is_shard_group: builtins.bool
"""This field decides whether this op is in a shard group."""
shard_group_id: builtins.int
"""This field is used to store the unique id of the shard group."""
shard_group_type: global___OpSharding.ShardGroupType.ValueType
@property
def tile_shape(self) -> global___ShapeProto:
"""The shape of the sharded tile."""
@@ -1570,6 +1855,7 @@ class OpSharding(google.protobuf.message.Message):
def tile_assignment_devices(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""Flattened list of device IDs. The order of flattening is the same as used
by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
Only one of tile_assignment_devices and iota_dimensions shall be non-empty.
"""
@property
@@ -1600,6 +1886,19 @@ class OpSharding(google.protobuf.message.Message):
unreduced sharding type respectively.
"""
@property
def iota_reshape_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""Dimensions used to reshape the 1D iota array of device IDs.
Only one of tile_assignment_devices and iota_reshape_dims shall be
non-empty.
"""
@property
def iota_transpose_perm(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""Dimension permutations to transposed the iota array reshaped to
iota_reshape_dims. This must have the same size as iota_reshape_dims.
"""
def __init__(
self,
*,
@@ -1611,9 +1910,14 @@ class OpSharding(google.protobuf.message.Message):
replicate_on_last_tile_dim: builtins.bool | None = ...,
metadata: collections.abc.Iterable[global___OpMetadata] | None = ...,
last_tile_dims: collections.abc.Iterable[global___OpSharding.Type.ValueType] | None = ...,
iota_reshape_dims: collections.abc.Iterable[builtins.int] | None = ...,
iota_transpose_perm: collections.abc.Iterable[builtins.int] | None = ...,
is_shard_group: builtins.bool | None = ...,
shard_group_id: builtins.int | None = ...,
shard_group_type: global___OpSharding.ShardGroupType.ValueType | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["tile_shape", b"tile_shape"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["last_tile_dims", b"last_tile_dims", "metadata", b"metadata", "replicate_on_last_tile_dim", b"replicate_on_last_tile_dim", "tile_assignment_devices", b"tile_assignment_devices", "tile_assignment_dimensions", b"tile_assignment_dimensions", "tile_shape", b"tile_shape", "tuple_shardings", b"tuple_shardings", "type", b"type"]) -> None: ...
def ClearField(self, field_name: typing.Literal["iota_reshape_dims", b"iota_reshape_dims", "iota_transpose_perm", b"iota_transpose_perm", "is_shard_group", b"is_shard_group", "last_tile_dims", b"last_tile_dims", "metadata", b"metadata", "replicate_on_last_tile_dim", b"replicate_on_last_tile_dim", "shard_group_id", b"shard_group_id", "shard_group_type", b"shard_group_type", "tile_assignment_devices", b"tile_assignment_devices", "tile_assignment_dimensions", b"tile_assignment_dimensions", "tile_shape", b"tile_shape", "tuple_shardings", b"tuple_shardings", "type", b"type"]) -> None: ...
global___OpSharding = OpSharding

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,99 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import typing
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import tensorflow.core.framework.full_type_pb2
import tensorflow.core.framework.tensor_shape_pb2
import tensorflow.core.framework.types_pb2
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class CppShapeInferenceResult(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@typing.final
class HandleShapeAndType(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SHAPE_FIELD_NUMBER: builtins.int
DTYPE_FIELD_NUMBER: builtins.int
TYPE_FIELD_NUMBER: builtins.int
dtype: tensorflow.core.framework.types_pb2.DataType.ValueType
@property
def shape(self) -> tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto: ...
@property
def type(self) -> tensorflow.core.framework.full_type_pb2.FullTypeDef: ...
def __init__(
self,
*,
shape: tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto | None = ...,
dtype: tensorflow.core.framework.types_pb2.DataType.ValueType | None = ...,
type: tensorflow.core.framework.full_type_pb2.FullTypeDef | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["shape", b"shape", "type", b"type"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["dtype", b"dtype", "shape", b"shape", "type", b"type"]) -> None: ...
@typing.final
class HandleData(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
IS_SET_FIELD_NUMBER: builtins.int
SHAPE_AND_TYPE_FIELD_NUMBER: builtins.int
is_set: builtins.bool
@property
def shape_and_type(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___CppShapeInferenceResult.HandleShapeAndType]:
"""Only valid if <is_set>."""
def __init__(
self,
*,
is_set: builtins.bool | None = ...,
shape_and_type: collections.abc.Iterable[global___CppShapeInferenceResult.HandleShapeAndType] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["is_set", b"is_set", "shape_and_type", b"shape_and_type"]) -> None: ...
SHAPE_FIELD_NUMBER: builtins.int
HANDLE_DATA_FIELD_NUMBER: builtins.int
@property
def shape(self) -> tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto: ...
@property
def handle_data(self) -> global___CppShapeInferenceResult.HandleData: ...
def __init__(
self,
*,
shape: tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto | None = ...,
handle_data: global___CppShapeInferenceResult.HandleData | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["handle_data", b"handle_data", "shape", b"shape"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["handle_data", b"handle_data", "shape", b"shape"]) -> None: ...
global___CppShapeInferenceResult = CppShapeInferenceResult
@typing.final
class CppShapeInferenceInputsNeeded(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
INPUT_TENSORS_NEEDED_FIELD_NUMBER: builtins.int
INPUT_TENSORS_AS_SHAPES_NEEDED_FIELD_NUMBER: builtins.int
@property
def input_tensors_needed(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
@property
def input_tensors_as_shapes_needed(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
def __init__(
self,
*,
input_tensors_needed: collections.abc.Iterable[builtins.int] | None = ...,
input_tensors_as_shapes_needed: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["input_tensors_as_shapes_needed", b"input_tensors_as_shapes_needed", "input_tensors_needed", b"input_tensors_needed"]) -> None: ...
global___CppShapeInferenceInputsNeeded = CppShapeInferenceInputsNeeded

View File

@@ -4,10 +4,12 @@ isort:skip_file
"""
import builtins
import collections.abc
import sys
import typing
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import tensorflow.core.framework.model_pb2
@@ -198,7 +200,7 @@ global___DistributeOptions = DistributeOptions
@typing.final
class OptimizationOptions(google.protobuf.message.Message):
"""next: 20"""
"""next: 22"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -213,6 +215,7 @@ class OptimizationOptions(google.protobuf.message.Message):
SHUFFLE_AND_REPEAT_FUSION_FIELD_NUMBER: builtins.int
FILTER_PARALLELIZATION_FIELD_NUMBER: builtins.int
INJECT_PREFETCH_FIELD_NUMBER: builtins.int
SEQ_INTERLEAVE_PREFETCH_FIELD_NUMBER: builtins.int
apply_default_optimizations: builtins.bool
filter_fusion: builtins.bool
map_and_batch_fusion: builtins.bool
@@ -224,6 +227,7 @@ class OptimizationOptions(google.protobuf.message.Message):
shuffle_and_repeat_fusion: builtins.bool
filter_parallelization: builtins.bool
inject_prefetch: builtins.bool
seq_interleave_prefetch: builtins.bool
def __init__(
self,
*,
@@ -238,9 +242,10 @@ class OptimizationOptions(google.protobuf.message.Message):
shuffle_and_repeat_fusion: builtins.bool | None = ...,
filter_parallelization: builtins.bool | None = ...,
inject_prefetch: builtins.bool | None = ...,
seq_interleave_prefetch: builtins.bool | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["apply_default_optimizations", b"apply_default_optimizations", "filter_fusion", b"filter_fusion", "filter_parallelization", b"filter_parallelization", "inject_prefetch", b"inject_prefetch", "map_and_batch_fusion", b"map_and_batch_fusion", "map_and_filter_fusion", b"map_and_filter_fusion", "map_fusion", b"map_fusion", "map_parallelization", b"map_parallelization", "noop_elimination", b"noop_elimination", "optional_apply_default_optimizations", b"optional_apply_default_optimizations", "optional_filter_fusion", b"optional_filter_fusion", "optional_filter_parallelization", b"optional_filter_parallelization", "optional_inject_prefetch", b"optional_inject_prefetch", "optional_map_and_batch_fusion", b"optional_map_and_batch_fusion", "optional_map_and_filter_fusion", b"optional_map_and_filter_fusion", "optional_map_fusion", b"optional_map_fusion", "optional_map_parallelization", b"optional_map_parallelization", "optional_noop_elimination", b"optional_noop_elimination", "optional_parallel_batch", b"optional_parallel_batch", "optional_shuffle_and_repeat_fusion", b"optional_shuffle_and_repeat_fusion", "parallel_batch", b"parallel_batch", "shuffle_and_repeat_fusion", b"shuffle_and_repeat_fusion"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["apply_default_optimizations", b"apply_default_optimizations", "filter_fusion", b"filter_fusion", "filter_parallelization", b"filter_parallelization", "inject_prefetch", b"inject_prefetch", "map_and_batch_fusion", b"map_and_batch_fusion", "map_and_filter_fusion", b"map_and_filter_fusion", "map_fusion", b"map_fusion", "map_parallelization", b"map_parallelization", "noop_elimination", b"noop_elimination", "optional_apply_default_optimizations", b"optional_apply_default_optimizations", "optional_filter_fusion", b"optional_filter_fusion", "optional_filter_parallelization", b"optional_filter_parallelization", "optional_inject_prefetch", b"optional_inject_prefetch", "optional_map_and_batch_fusion", b"optional_map_and_batch_fusion", "optional_map_and_filter_fusion", b"optional_map_and_filter_fusion", "optional_map_fusion", b"optional_map_fusion", "optional_map_parallelization", b"optional_map_parallelization", "optional_noop_elimination", b"optional_noop_elimination", "optional_parallel_batch", b"optional_parallel_batch", "optional_shuffle_and_repeat_fusion", b"optional_shuffle_and_repeat_fusion", "parallel_batch", b"parallel_batch", "shuffle_and_repeat_fusion", b"shuffle_and_repeat_fusion"]) -> None: ...
def HasField(self, field_name: typing.Literal["apply_default_optimizations", b"apply_default_optimizations", "filter_fusion", b"filter_fusion", "filter_parallelization", b"filter_parallelization", "inject_prefetch", b"inject_prefetch", "map_and_batch_fusion", b"map_and_batch_fusion", "map_and_filter_fusion", b"map_and_filter_fusion", "map_fusion", b"map_fusion", "map_parallelization", b"map_parallelization", "noop_elimination", b"noop_elimination", "optional_apply_default_optimizations", b"optional_apply_default_optimizations", "optional_filter_fusion", b"optional_filter_fusion", "optional_filter_parallelization", b"optional_filter_parallelization", "optional_inject_prefetch", b"optional_inject_prefetch", "optional_map_and_batch_fusion", b"optional_map_and_batch_fusion", "optional_map_and_filter_fusion", b"optional_map_and_filter_fusion", "optional_map_fusion", b"optional_map_fusion", "optional_map_parallelization", b"optional_map_parallelization", "optional_noop_elimination", b"optional_noop_elimination", "optional_parallel_batch", b"optional_parallel_batch", "optional_seq_interleave_prefetch", b"optional_seq_interleave_prefetch", "optional_shuffle_and_repeat_fusion", b"optional_shuffle_and_repeat_fusion", "parallel_batch", b"parallel_batch", "seq_interleave_prefetch", b"seq_interleave_prefetch", "shuffle_and_repeat_fusion", b"shuffle_and_repeat_fusion"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["apply_default_optimizations", b"apply_default_optimizations", "filter_fusion", b"filter_fusion", "filter_parallelization", b"filter_parallelization", "inject_prefetch", b"inject_prefetch", "map_and_batch_fusion", b"map_and_batch_fusion", "map_and_filter_fusion", b"map_and_filter_fusion", "map_fusion", b"map_fusion", "map_parallelization", b"map_parallelization", "noop_elimination", b"noop_elimination", "optional_apply_default_optimizations", b"optional_apply_default_optimizations", "optional_filter_fusion", b"optional_filter_fusion", "optional_filter_parallelization", b"optional_filter_parallelization", "optional_inject_prefetch", b"optional_inject_prefetch", "optional_map_and_batch_fusion", b"optional_map_and_batch_fusion", "optional_map_and_filter_fusion", b"optional_map_and_filter_fusion", "optional_map_fusion", b"optional_map_fusion", "optional_map_parallelization", b"optional_map_parallelization", "optional_noop_elimination", b"optional_noop_elimination", "optional_parallel_batch", b"optional_parallel_batch", "optional_seq_interleave_prefetch", b"optional_seq_interleave_prefetch", "optional_shuffle_and_repeat_fusion", b"optional_shuffle_and_repeat_fusion", "parallel_batch", b"parallel_batch", "seq_interleave_prefetch", b"seq_interleave_prefetch", "shuffle_and_repeat_fusion", b"shuffle_and_repeat_fusion"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_apply_default_optimizations", b"optional_apply_default_optimizations"]) -> typing.Literal["apply_default_optimizations"] | None: ...
@typing.overload
@@ -262,6 +267,8 @@ class OptimizationOptions(google.protobuf.message.Message):
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_parallel_batch", b"optional_parallel_batch"]) -> typing.Literal["parallel_batch"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_seq_interleave_prefetch", b"optional_seq_interleave_prefetch"]) -> typing.Literal["seq_interleave_prefetch"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_shuffle_and_repeat_fusion", b"optional_shuffle_and_repeat_fusion"]) -> typing.Literal["shuffle_and_repeat_fusion"] | None: ...
global___OptimizationOptions = OptimizationOptions
@@ -296,11 +303,13 @@ class Options(google.protobuf.message.Message):
"""Message stored with Dataset objects to control how datasets are processed and
optimized.
next: 9
next: 12
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
DATASET_NAME_FIELD_NUMBER: builtins.int
FRAMEWORK_TYPE_FIELD_NUMBER: builtins.int
DETERMINISTIC_FIELD_NUMBER: builtins.int
AUTOTUNE_OPTIONS_FIELD_NUMBER: builtins.int
DISTRIBUTE_OPTIONS_FIELD_NUMBER: builtins.int
@@ -309,13 +318,20 @@ class Options(google.protobuf.message.Message):
THREADING_OPTIONS_FIELD_NUMBER: builtins.int
EXTERNAL_STATE_POLICY_FIELD_NUMBER: builtins.int
SYMBOLIC_CHECKPOINT_FIELD_NUMBER: builtins.int
WARM_START_FIELD_NUMBER: builtins.int
dataset_name: builtins.str
deterministic: builtins.bool
slack: builtins.bool
external_state_policy: global___ExternalStatePolicy.ValueType
symbolic_checkpoint: builtins.bool
warm_start: builtins.bool
@property
def framework_type(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""List of frameworks used to generate this dataset."""
@property
def autotune_options(self) -> global___AutotuneOptions:
"""The distribution strategy options associated with the dataset."""
"""The autotune options associated with the dataset."""
@property
def distribute_options(self) -> global___DistributeOptions:
@@ -332,6 +348,8 @@ class Options(google.protobuf.message.Message):
def __init__(
self,
*,
dataset_name: builtins.str | None = ...,
framework_type: collections.abc.Iterable[builtins.str] | None = ...,
deterministic: builtins.bool | None = ...,
autotune_options: global___AutotuneOptions | None = ...,
distribute_options: global___DistributeOptions | None = ...,
@@ -340,9 +358,12 @@ class Options(google.protobuf.message.Message):
threading_options: global___ThreadingOptions | None = ...,
external_state_policy: global___ExternalStatePolicy.ValueType | None = ...,
symbolic_checkpoint: builtins.bool | None = ...,
warm_start: builtins.bool | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options"]) -> None: ...
def HasField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "dataset_name", b"dataset_name", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_dataset_name", b"optional_dataset_name", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "optional_warm_start", b"optional_warm_start", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options", "warm_start", b"warm_start"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "dataset_name", b"dataset_name", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "framework_type", b"framework_type", "optimization_options", b"optimization_options", "optional_dataset_name", b"optional_dataset_name", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "optional_warm_start", b"optional_warm_start", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options", "warm_start", b"warm_start"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_dataset_name", b"optional_dataset_name"]) -> typing.Literal["dataset_name"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_deterministic", b"optional_deterministic"]) -> typing.Literal["deterministic"] | None: ...
@typing.overload
@@ -351,5 +372,7 @@ class Options(google.protobuf.message.Message):
def WhichOneof(self, oneof_group: typing.Literal["optional_slack", b"optional_slack"]) -> typing.Literal["slack"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_symbolic_checkpoint", b"optional_symbolic_checkpoint"]) -> typing.Literal["symbolic_checkpoint"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_warm_start", b"optional_warm_start"]) -> typing.Literal["warm_start"] | None: ...
global___Options = Options

View File

@@ -49,6 +49,7 @@ class GraphDebugInfo(google.protobuf.message.Message):
func: builtins.str | None = ...,
code: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["code", b"code", "col", b"col", "file_index", b"file_index", "func", b"func", "line", b"line"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["code", b"code", "col", b"col", "file_index", b"file_index", "func", b"func", "line", b"line"]) -> None: ...
@typing.final
@@ -58,16 +59,56 @@ class GraphDebugInfo(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
FILE_LINE_COLS_FIELD_NUMBER: builtins.int
FRAME_ID_FIELD_NUMBER: builtins.int
@property
def file_line_cols(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___GraphDebugInfo.FileLineCol]:
"""Each line in the stack trace."""
"""Deprecated."""
@property
def frame_id(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
def __init__(
self,
*,
file_line_cols: collections.abc.Iterable[global___GraphDebugInfo.FileLineCol] | None = ...,
frame_id: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["file_line_cols", b"file_line_cols"]) -> None: ...
def ClearField(self, field_name: typing.Literal["file_line_cols", b"file_line_cols", "frame_id", b"frame_id"]) -> None: ...
@typing.final
class FramesByIdEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.int
@property
def value(self) -> global___GraphDebugInfo.FileLineCol: ...
def __init__(
self,
*,
key: builtins.int | None = ...,
value: global___GraphDebugInfo.FileLineCol | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
@typing.final
class TracesByIdEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.int
@property
def value(self) -> global___GraphDebugInfo.StackTrace: ...
def __init__(
self,
*,
key: builtins.int | None = ...,
value: global___GraphDebugInfo.StackTrace | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
@typing.final
class TracesEntry(google.protobuf.message.Message):
@@ -84,24 +125,58 @@ class GraphDebugInfo(google.protobuf.message.Message):
key: builtins.str | None = ...,
value: global___GraphDebugInfo.StackTrace | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ...
def HasField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
@typing.final
class NameToTraceIdEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
value: builtins.int
def __init__(
self,
*,
key: builtins.str | None = ...,
value: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
FILES_FIELD_NUMBER: builtins.int
FRAMES_BY_ID_FIELD_NUMBER: builtins.int
TRACES_BY_ID_FIELD_NUMBER: builtins.int
TRACES_FIELD_NUMBER: builtins.int
NAME_TO_TRACE_ID_FIELD_NUMBER: builtins.int
@property
def files(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""This stores all the source code file names and can be indexed by the
`file_index`.
"""
@property
def frames_by_id(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, global___GraphDebugInfo.FileLineCol]:
"""Stack traces and frames are uniqueified during construction. These maps
index from the unique id for a frame/trace to the value.
"""
@property
def traces_by_id(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, global___GraphDebugInfo.StackTrace]: ...
@property
def traces(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___GraphDebugInfo.StackTrace]:
"""This maps a node name to a stack trace in the source code.
"""Deprecated."""
@property
def name_to_trace_id(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.int]:
"""This maps a node name to a trace id contained in `traces_by_id`.
The map key is a mangling of the containing function and op name with
syntax:
op.name '@' func_name
For ops in the top-level graph, the func_name is the empty string.
For ops in the top-level graph, the func_name is the empty string and hence
the `@` may be ommitted.
Note that op names are restricted to a small number of characters which
exclude '@', making it impossible to collide keys of this form. Function
names accept a much wider set of characters.
@@ -113,8 +188,11 @@ class GraphDebugInfo(google.protobuf.message.Message):
self,
*,
files: collections.abc.Iterable[builtins.str] | None = ...,
frames_by_id: collections.abc.Mapping[builtins.int, global___GraphDebugInfo.FileLineCol] | None = ...,
traces_by_id: collections.abc.Mapping[builtins.int, global___GraphDebugInfo.StackTrace] | None = ...,
traces: collections.abc.Mapping[builtins.str, global___GraphDebugInfo.StackTrace] | None = ...,
name_to_trace_id: collections.abc.Mapping[builtins.str, builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["files", b"files", "traces", b"traces"]) -> None: ...
def ClearField(self, field_name: typing.Literal["files", b"files", "frames_by_id", b"frames_by_id", "name_to_trace_id", b"name_to_trace_id", "traces", b"traces", "traces_by_id", b"traces_by_id"]) -> None: ...
global___GraphDebugInfo = GraphDebugInfo

View File

@@ -11,6 +11,7 @@ import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import tensorflow.core.framework.function_pb2
import tensorflow.core.framework.graph_debug_info_pb2
import tensorflow.core.framework.node_def_pb2
import tensorflow.core.framework.versions_pb2
@@ -26,6 +27,7 @@ class GraphDef(google.protobuf.message.Message):
VERSIONS_FIELD_NUMBER: builtins.int
VERSION_FIELD_NUMBER: builtins.int
LIBRARY_FIELD_NUMBER: builtins.int
DEBUG_INFO_FIELD_NUMBER: builtins.int
version: builtins.int
"""Deprecated single version field; use versions above instead. Since all
GraphDef changes before "versions" was introduced were forward
@@ -70,6 +72,10 @@ class GraphDef(google.protobuf.message.Message):
function are ready.
"""
@property
def debug_info(self) -> tensorflow.core.framework.graph_debug_info_pb2.GraphDebugInfo:
"""Stack traces for the nodes in this graph."""
def __init__(
self,
*,
@@ -77,8 +83,9 @@ class GraphDef(google.protobuf.message.Message):
versions: tensorflow.core.framework.versions_pb2.VersionDef | None = ...,
version: builtins.int | None = ...,
library: tensorflow.core.framework.function_pb2.FunctionDefLibrary | None = ...,
debug_info: tensorflow.core.framework.graph_debug_info_pb2.GraphDebugInfo | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["library", b"library", "versions", b"versions"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["library", b"library", "node", b"node", "version", b"version", "versions", b"versions"]) -> None: ...
def HasField(self, field_name: typing.Literal["debug_info", b"debug_info", "library", b"library", "versions", b"versions"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["debug_info", b"debug_info", "library", b"library", "node", b"node", "version", b"version", "versions", b"versions"]) -> None: ...
global___GraphDef = GraphDef

View File

@@ -251,10 +251,14 @@ class ModelProto(google.protobuf.message.Message):
) -> None: ...
def ClearField(self, field_name: typing.Literal["algorithm", b"algorithm", "cpu_budget", b"cpu_budget", "model_input_time", b"model_input_time", "ram_budget", b"ram_budget"]) -> None: ...
DATASET_NAME_FIELD_NUMBER: builtins.int
NODES_FIELD_NUMBER: builtins.int
OUTPUT_FIELD_NUMBER: builtins.int
ID_COUNTER_FIELD_NUMBER: builtins.int
OPTIMIZATION_PARAMS_FIELD_NUMBER: builtins.int
GAP_TIMES_FIELD_NUMBER: builtins.int
dataset_name: builtins.str
"""User-defined name for the dataset. Empty if no name was set."""
output: builtins.int
"""ID of the output node of this model."""
id_counter: builtins.int
@@ -265,15 +269,19 @@ class ModelProto(google.protobuf.message.Message):
@property
def optimization_params(self) -> global___ModelProto.OptimizationParams: ...
@property
def gap_times(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
def __init__(
self,
*,
dataset_name: builtins.str | None = ...,
nodes: collections.abc.Mapping[builtins.int, global___ModelProto.Node] | None = ...,
output: builtins.int | None = ...,
id_counter: builtins.int | None = ...,
optimization_params: global___ModelProto.OptimizationParams | None = ...,
gap_times: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["optimization_params", b"optimization_params"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["id_counter", b"id_counter", "nodes", b"nodes", "optimization_params", b"optimization_params", "output", b"output"]) -> None: ...
def ClearField(self, field_name: typing.Literal["dataset_name", b"dataset_name", "gap_times", b"gap_times", "id_counter", b"id_counter", "nodes", b"nodes", "optimization_params", b"optimization_params", "output", b"output"]) -> None: ...
global___ModelProto = ModelProto

View File

@@ -5,14 +5,21 @@ isort:skip_file
import builtins
import collections.abc
import sys
import typing
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import tensorflow.core.framework.graph_pb2
import tensorflow.core.framework.types_pb2
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
@@ -25,6 +32,30 @@ class OptimizedFunctionGraph(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _OptimizationSource:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _OptimizationSourceEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[OptimizedFunctionGraph._OptimizationSource.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
SOURCE_UNSPECIFIED: OptimizedFunctionGraph._OptimizationSource.ValueType # 0
AOT: OptimizedFunctionGraph._OptimizationSource.ValueType # 1
JIT: OptimizedFunctionGraph._OptimizationSource.ValueType # 2
class OptimizationSource(_OptimizationSource, metaclass=_OptimizationSourceEnumTypeWrapper):
"""Enum for distinguishing the origin where the proto is created.
AOT: proto is created in ahead-of-time environment, which can be different
from the environment where the graph is actually executed.
JIT: proto is created in just-in-time execution, which has the same
environment as the one the graph is actually executed.
"""
SOURCE_UNSPECIFIED: OptimizedFunctionGraph.OptimizationSource.ValueType # 0
AOT: OptimizedFunctionGraph.OptimizationSource.ValueType # 1
JIT: OptimizedFunctionGraph.OptimizationSource.ValueType # 2
@typing.final
class NodeNameToControlRetEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -46,12 +77,20 @@ class OptimizedFunctionGraph(google.protobuf.message.Message):
NODE_NAME_TO_CONTROL_RET_FIELD_NUMBER: builtins.int
RET_TYPES_FIELD_NUMBER: builtins.int
NUM_RETURN_NODES_FIELD_NUMBER: builtins.int
SOURCE_FIELD_NUMBER: builtins.int
OPTIMIZATION_TIME_USECS_FIELD_NUMBER: builtins.int
name: builtins.str
"""Function name. It can be a human-readable SignatureDef's method name, or a
FunctionDef name.
"""
num_return_nodes: builtins.int
"""Number of return nodes. This is an output of graph preprocessing."""
source: global___OptimizedFunctionGraph.OptimizationSource.ValueType
"""Indicates the source environment where this proto is generated."""
optimization_time_usecs: builtins.int
"""Time (in microseconds) spent on running the graph optimization passes for
this function.
"""
@property
def function_graph(self) -> tensorflow.core.framework.graph_pb2.GraphDef:
"""Optimized function graph."""
@@ -76,8 +115,14 @@ class OptimizedFunctionGraph(google.protobuf.message.Message):
node_name_to_control_ret: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
ret_types: collections.abc.Iterable[tensorflow.core.framework.types_pb2.DataType.ValueType] | None = ...,
num_return_nodes: builtins.int | None = ...,
source: global___OptimizedFunctionGraph.OptimizationSource.ValueType | None = ...,
optimization_time_usecs: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["function_graph", b"function_graph"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["function_graph", b"function_graph", "name", b"name", "node_name_to_control_ret", b"node_name_to_control_ret", "num_return_nodes", b"num_return_nodes", "ret_types", b"ret_types"]) -> None: ...
def HasField(self, field_name: typing.Literal["_optimization_time_usecs", b"_optimization_time_usecs", "_source", b"_source", "function_graph", b"function_graph", "optimization_time_usecs", b"optimization_time_usecs", "source", b"source"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_optimization_time_usecs", b"_optimization_time_usecs", "_source", b"_source", "function_graph", b"function_graph", "name", b"name", "node_name_to_control_ret", b"node_name_to_control_ret", "num_return_nodes", b"num_return_nodes", "optimization_time_usecs", b"optimization_time_usecs", "ret_types", b"ret_types", "source", b"source"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_optimization_time_usecs", b"_optimization_time_usecs"]) -> typing.Literal["optimization_time_usecs"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ...
global___OptimizedFunctionGraph = OptimizedFunctionGraph

View File

@@ -67,11 +67,18 @@ class _DataTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumT
"""5 exponent bits, 2 mantissa bits."""
DT_FLOAT8_E4M3FN: _DataType.ValueType # 25
"""4 exponent bits, 3 mantissa bits, finite-only, with"""
DT_FLOAT_REF: _DataType.ValueType # 101
DT_INT4: _DataType.ValueType # 29
"""2 NaNs (0bS1111111).
Do not use! These are only for parameters. Every enum above
should have a corresponding value below (verified by types_test).
TODO - b/299182407: Leaving room for remaining float8 types.
DT_FLOAT8_E4M3FNUZ = 26;
DT_FLOAT8_E4M3B11FNUZ = 27;
DT_FLOAT8_E5M2FNUZ = 28;
"""
DT_UINT4: _DataType.ValueType # 30
DT_FLOAT_REF: _DataType.ValueType # 101
"""Do not use! These are only for TF1's obsolete reference Variables.
Every enum above should have a corresponding value below (verified by
types_test).
"""
DT_DOUBLE_REF: _DataType.ValueType # 102
DT_INT32_REF: _DataType.ValueType # 103
@@ -97,6 +104,13 @@ class _DataTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumT
DT_UINT64_REF: _DataType.ValueType # 123
DT_FLOAT8_E5M2_REF: _DataType.ValueType # 124
DT_FLOAT8_E4M3FN_REF: _DataType.ValueType # 125
DT_INT4_REF: _DataType.ValueType # 129
"""TODO - b/299182407: Leaving room for remaining float8 types.
DT_FLOAT8_E4M3FNUZ_REF = 126;
DT_FLOAT8_E4M3B11FNUZ_REF = 127;
DT_FLOAT8_E5M2FNUZ_REF = 128;
"""
DT_UINT4_REF: _DataType.ValueType # 130
class DataType(_DataType, metaclass=_DataTypeEnumTypeWrapper):
"""(== suppress_warning documentation-presence ==)
@@ -146,11 +160,18 @@ DT_FLOAT8_E5M2: DataType.ValueType # 24
"""5 exponent bits, 2 mantissa bits."""
DT_FLOAT8_E4M3FN: DataType.ValueType # 25
"""4 exponent bits, 3 mantissa bits, finite-only, with"""
DT_FLOAT_REF: DataType.ValueType # 101
DT_INT4: DataType.ValueType # 29
"""2 NaNs (0bS1111111).
Do not use! These are only for parameters. Every enum above
should have a corresponding value below (verified by types_test).
TODO - b/299182407: Leaving room for remaining float8 types.
DT_FLOAT8_E4M3FNUZ = 26;
DT_FLOAT8_E4M3B11FNUZ = 27;
DT_FLOAT8_E5M2FNUZ = 28;
"""
DT_UINT4: DataType.ValueType # 30
DT_FLOAT_REF: DataType.ValueType # 101
"""Do not use! These are only for TF1's obsolete reference Variables.
Every enum above should have a corresponding value below (verified by
types_test).
"""
DT_DOUBLE_REF: DataType.ValueType # 102
DT_INT32_REF: DataType.ValueType # 103
@@ -176,6 +197,13 @@ DT_UINT32_REF: DataType.ValueType # 122
DT_UINT64_REF: DataType.ValueType # 123
DT_FLOAT8_E5M2_REF: DataType.ValueType # 124
DT_FLOAT8_E4M3FN_REF: DataType.ValueType # 125
DT_INT4_REF: DataType.ValueType # 129
"""TODO - b/299182407: Leaving room for remaining float8 types.
DT_FLOAT8_E4M3FNUZ_REF = 126;
DT_FLOAT8_E4M3B11FNUZ_REF = 127;
DT_FLOAT8_E5M2FNUZ_REF = 128;
"""
DT_UINT4_REF: DataType.ValueType # 130
global___DataType = DataType
@typing.final

View File

@@ -102,6 +102,9 @@ class JobDef(google.protobuf.message.Message):
If the `name` field contains "worker", and the `tasks` map contains a
mapping from 7 to "example.org:2222", then the device prefix
"/job:worker/task:7" will be assigned to "example.org:2222".
If a job has multiple replicas, host-ports will be comma-delimited, with
one entry for each replica.
"""
def __init__(

View File

@@ -52,8 +52,8 @@ class GPUOptions(google.protobuf.message.Message):
"""Per "virtual" device memory limit, in MB. The number of elements in
the list is the number of virtual devices to create on the
corresponding visible GPU (see "virtual_devices" below).
If empty, it will create single virtual device taking all available
memory from the device.
If empty and `num_virtual_devices_per_gpu` is not set, it will create
single virtual device taking all available memory from the device.
For the concept of "visible" and "virtual" GPU, see the comments for
"visible_device_list" above for more information.
@@ -91,6 +91,7 @@ class GPUOptions(google.protobuf.message.Message):
def ClearField(self, field_name: typing.Literal["device_ordinal", b"device_ordinal", "memory_limit_mb", b"memory_limit_mb", "priority", b"priority"]) -> None: ...
VIRTUAL_DEVICES_FIELD_NUMBER: builtins.int
NUM_VIRTUAL_DEVICES_PER_GPU_FIELD_NUMBER: builtins.int
USE_UNIFIED_MEMORY_FIELD_NUMBER: builtins.int
NUM_DEV_TO_DEV_COPY_STREAMS_FIELD_NUMBER: builtins.int
COLLECTIVE_RING_ORDER_FIELD_NUMBER: builtins.int
@@ -103,6 +104,13 @@ class GPUOptions(google.protobuf.message.Message):
DISALLOW_RETRY_ON_ALLOCATION_FAILURE_FIELD_NUMBER: builtins.int
GPU_HOST_MEM_LIMIT_IN_MB_FIELD_NUMBER: builtins.int
GPU_HOST_MEM_DISALLOW_GROWTH_FIELD_NUMBER: builtins.int
GPU_SYSTEM_MEMORY_SIZE_IN_MB_FIELD_NUMBER: builtins.int
num_virtual_devices_per_gpu: builtins.int
"""The number of virtual devices to create on each visible GPU. The
available memory will be split equally among all virtual devices. If the
field `memory_limit_mb` in `VirtualDevices` is not empty, this field will
be ignored.
"""
use_unified_memory: builtins.bool
"""If true, uses CUDA unified memory for memory allocations. If
per_process_gpu_memory_fraction option is greater than 1.0, then unified
@@ -185,6 +193,13 @@ class GPUOptions(google.protobuf.message.Message):
gpu_host_mem_limit_in_mb, because the default GPU host memory limit is
quite high.
"""
gpu_system_memory_size_in_mb: builtins.int
"""Memory limit for gpu system. This can also be set by
TF_DEVICE_MIN_SYS_MEMORY_IN_MB, which takes precedence over
gpu_system_memory_size_in_mb. With this, user can configure the gpu
system memory size for better resource estimation of multi-tenancy(one
gpu with multiple model) use case.
"""
@property
def virtual_devices(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___GPUOptions.Experimental.VirtualDevices]:
"""The multi virtual device settings. If empty (not set), it will create
@@ -231,6 +246,7 @@ class GPUOptions(google.protobuf.message.Message):
self,
*,
virtual_devices: collections.abc.Iterable[global___GPUOptions.Experimental.VirtualDevices] | None = ...,
num_virtual_devices_per_gpu: builtins.int | None = ...,
use_unified_memory: builtins.bool | None = ...,
num_dev_to_dev_copy_streams: builtins.int | None = ...,
collective_ring_order: builtins.str | None = ...,
@@ -243,8 +259,9 @@ class GPUOptions(google.protobuf.message.Message):
disallow_retry_on_allocation_failure: builtins.bool | None = ...,
gpu_host_mem_limit_in_mb: builtins.float | None = ...,
gpu_host_mem_disallow_growth: builtins.bool | None = ...,
gpu_system_memory_size_in_mb: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["collective_ring_order", b"collective_ring_order", "disallow_retry_on_allocation_failure", b"disallow_retry_on_allocation_failure", "gpu_host_mem_disallow_growth", b"gpu_host_mem_disallow_growth", "gpu_host_mem_limit_in_mb", b"gpu_host_mem_limit_in_mb", "internal_fragmentation_fraction", b"internal_fragmentation_fraction", "kernel_tracker_max_bytes", b"kernel_tracker_max_bytes", "kernel_tracker_max_interval", b"kernel_tracker_max_interval", "kernel_tracker_max_pending", b"kernel_tracker_max_pending", "num_dev_to_dev_copy_streams", b"num_dev_to_dev_copy_streams", "timestamped_allocator", b"timestamped_allocator", "use_cuda_malloc_async", b"use_cuda_malloc_async", "use_unified_memory", b"use_unified_memory", "virtual_devices", b"virtual_devices"]) -> None: ...
def ClearField(self, field_name: typing.Literal["collective_ring_order", b"collective_ring_order", "disallow_retry_on_allocation_failure", b"disallow_retry_on_allocation_failure", "gpu_host_mem_disallow_growth", b"gpu_host_mem_disallow_growth", "gpu_host_mem_limit_in_mb", b"gpu_host_mem_limit_in_mb", "gpu_system_memory_size_in_mb", b"gpu_system_memory_size_in_mb", "internal_fragmentation_fraction", b"internal_fragmentation_fraction", "kernel_tracker_max_bytes", b"kernel_tracker_max_bytes", "kernel_tracker_max_interval", b"kernel_tracker_max_interval", "kernel_tracker_max_pending", b"kernel_tracker_max_pending", "num_dev_to_dev_copy_streams", b"num_dev_to_dev_copy_streams", "num_virtual_devices_per_gpu", b"num_virtual_devices_per_gpu", "timestamped_allocator", b"timestamped_allocator", "use_cuda_malloc_async", b"use_cuda_malloc_async", "use_unified_memory", b"use_unified_memory", "virtual_devices", b"virtual_devices"]) -> None: ...
PER_PROCESS_GPU_MEMORY_FRACTION_FIELD_NUMBER: builtins.int
ALLOW_GROWTH_FIELD_NUMBER: builtins.int
@@ -695,10 +712,16 @@ class ConfigProto(google.protobuf.message.Message):
DISABLE_OUTPUT_PARTITION_GRAPHS_FIELD_NUMBER: builtins.int
XLA_FUSION_AUTOTUNER_THRESH_FIELD_NUMBER: builtins.int
USE_TFRT_FIELD_NUMBER: builtins.int
ENABLE_MULTI_HOST_FIELD_NUMBER: builtins.int
BACKEND_SERVER_PORT_FIELD_NUMBER: builtins.int
TARGET_TPU_FIELD_NUMBER: builtins.int
TARGET_GPU_FIELD_NUMBER: builtins.int
STREAM_MERGE_THRESHOLD_FIELD_NUMBER: builtins.int
DISABLE_FUNCTIONAL_OPS_LOWERING_FIELD_NUMBER: builtins.int
XLA_PREFER_SINGLE_GRAPH_CLUSTER_FIELD_NUMBER: builtins.int
COORDINATION_CONFIG_FIELD_NUMBER: builtins.int
DISABLE_OPTIMIZE_FOR_STATIC_GRAPH_FIELD_NUMBER: builtins.int
DISABLE_EAGER_EXECUTOR_STREAMING_ENQUEUE_FIELD_NUMBER: builtins.int
collective_group_leader: builtins.str
"""Task name for group resolution."""
executor_type: builtins.str
@@ -801,6 +824,23 @@ class ConfigProto(google.protobuf.message.Message):
"""
use_tfrt: builtins.bool
"""Whether runtime execution uses TFRT."""
enable_multi_host: builtins.bool
"""If true, use Pathways with TFRT API for multi host support."""
backend_server_port: builtins.int
"""Port for the Pathways server. Ignored if enable_multi_host=false."""
target_tpu: builtins.bool
"""If true, TFRT will use TPU specific compiler passes and perform TPU
specific initialization.
"""
target_gpu: builtins.bool
"""If true, TFRT will use GPU specific compiler passes and perform GPU
specific initialization.
"""
stream_merge_threshold: builtins.int
"""The threshold to merge small streams in TFRT. The stream with cost
smaller than the threshold will be merged. Setting it to value 1
disables all merges.
"""
disable_functional_ops_lowering: builtins.bool
"""Whether functional control flow op lowering should be disabled. This is
useful when executing within a portable runtime where control flow op
@@ -821,6 +861,11 @@ class ConfigProto(google.protobuf.message.Message):
This option is meant to replace `optimize_for_static_graph` and it
aims to negate its value.
"""
disable_eager_executor_streaming_enqueue: builtins.bool
"""Whether eager remote execution will stream all the function calls or
allow them to happen in parallel. When true, streaming execution is
disabled, and parallel execution is allowed.
"""
@property
def session_metadata(self) -> global___SessionMetadata:
"""Metadata about the session.
@@ -856,13 +901,19 @@ class ConfigProto(google.protobuf.message.Message):
disable_output_partition_graphs: builtins.bool | None = ...,
xla_fusion_autotuner_thresh: builtins.int | None = ...,
use_tfrt: builtins.bool | None = ...,
enable_multi_host: builtins.bool | None = ...,
backend_server_port: builtins.int | None = ...,
target_tpu: builtins.bool | None = ...,
target_gpu: builtins.bool | None = ...,
stream_merge_threshold: builtins.int | None = ...,
disable_functional_ops_lowering: builtins.bool | None = ...,
xla_prefer_single_graph_cluster: builtins.bool | None = ...,
coordination_config: tensorflow.tsl.protobuf.coordination_config_pb2.CoordinationServiceConfig | None = ...,
disable_optimize_for_static_graph: builtins.bool | None = ...,
disable_eager_executor_streaming_enqueue: builtins.bool | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["coordination_config", b"coordination_config", "session_metadata", b"session_metadata"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["collective_deterministic_sequential_execution", b"collective_deterministic_sequential_execution", "collective_group_leader", b"collective_group_leader", "collective_nccl", b"collective_nccl", "coordination_config", b"coordination_config", "disable_functional_ops_lowering", b"disable_functional_ops_lowering", "disable_optimize_for_static_graph", b"disable_optimize_for_static_graph", "disable_output_partition_graphs", b"disable_output_partition_graphs", "disable_thread_spinning", b"disable_thread_spinning", "enable_mlir_bridge", b"enable_mlir_bridge", "enable_mlir_graph_optimization", b"enable_mlir_graph_optimization", "executor_type", b"executor_type", "mlir_bridge_rollout", b"mlir_bridge_rollout", "optimize_for_static_graph", b"optimize_for_static_graph", "recv_buf_max_chunk", b"recv_buf_max_chunk", "session_metadata", b"session_metadata", "share_cluster_devices_in_session", b"share_cluster_devices_in_session", "share_session_state_in_clusterspec_propagation", b"share_session_state_in_clusterspec_propagation", "use_numa_affinity", b"use_numa_affinity", "use_tfrt", b"use_tfrt", "xla_fusion_autotuner_thresh", b"xla_fusion_autotuner_thresh", "xla_prefer_single_graph_cluster", b"xla_prefer_single_graph_cluster"]) -> None: ...
def ClearField(self, field_name: typing.Literal["backend_server_port", b"backend_server_port", "collective_deterministic_sequential_execution", b"collective_deterministic_sequential_execution", "collective_group_leader", b"collective_group_leader", "collective_nccl", b"collective_nccl", "coordination_config", b"coordination_config", "disable_eager_executor_streaming_enqueue", b"disable_eager_executor_streaming_enqueue", "disable_functional_ops_lowering", b"disable_functional_ops_lowering", "disable_optimize_for_static_graph", b"disable_optimize_for_static_graph", "disable_output_partition_graphs", b"disable_output_partition_graphs", "disable_thread_spinning", b"disable_thread_spinning", "enable_mlir_bridge", b"enable_mlir_bridge", "enable_mlir_graph_optimization", b"enable_mlir_graph_optimization", "enable_multi_host", b"enable_multi_host", "executor_type", b"executor_type", "mlir_bridge_rollout", b"mlir_bridge_rollout", "optimize_for_static_graph", b"optimize_for_static_graph", "recv_buf_max_chunk", b"recv_buf_max_chunk", "session_metadata", b"session_metadata", "share_cluster_devices_in_session", b"share_cluster_devices_in_session", "share_session_state_in_clusterspec_propagation", b"share_session_state_in_clusterspec_propagation", "stream_merge_threshold", b"stream_merge_threshold", "target_gpu", b"target_gpu", "target_tpu", b"target_tpu", "use_numa_affinity", b"use_numa_affinity", "use_tfrt", b"use_tfrt", "xla_fusion_autotuner_thresh", b"xla_fusion_autotuner_thresh", "xla_prefer_single_graph_cluster", b"xla_prefer_single_graph_cluster"]) -> None: ...
DEVICE_COUNT_FIELD_NUMBER: builtins.int
INTRA_OP_PARALLELISM_THREADS_FIELD_NUMBER: builtins.int

View File

@@ -12,8 +12,8 @@ import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import tensorflow.core.framework.graph_debug_info_pb2
import tensorflow.core.framework.tensor_pb2
import tensorflow.core.protobuf.graph_debug_info_pb2
if sys.version_info >= (3, 10):
import typing as typing_extensions
@@ -288,7 +288,7 @@ class StackFrameWithId(google.protobuf.message.Message):
id: builtins.str
"""A unique ID for the stack frame: A UUID-like string."""
@property
def file_line_col(self) -> tensorflow.core.protobuf.graph_debug_info_pb2.GraphDebugInfo.FileLineCol:
def file_line_col(self) -> tensorflow.core.framework.graph_debug_info_pb2.GraphDebugInfo.FileLineCol:
"""Stack frame, i.e., a frame of a stack trace, containing information
regarding the file name, line number, function name, code content
of the line, and column number (if available).
@@ -298,7 +298,7 @@ class StackFrameWithId(google.protobuf.message.Message):
self,
*,
id: builtins.str | None = ...,
file_line_col: tensorflow.core.protobuf.graph_debug_info_pb2.GraphDebugInfo.FileLineCol | None = ...,
file_line_col: tensorflow.core.framework.graph_debug_info_pb2.GraphDebugInfo.FileLineCol | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["file_line_col", b"file_line_col"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["file_line_col", b"file_line_col", "id", b"id"]) -> None: ...

View File

@@ -40,7 +40,9 @@ class FingerprintDef(google.protobuf.message.Message):
"""Hash of the checkpoint."""
@property
def version(self) -> tensorflow.core.framework.versions_pb2.VersionDef:
"""Version specification of the fingerprint."""
"""Version specification of the fingerprint.
TODO(b/290068219): add USM version when GA
"""
def __init__(
self,

View File

@@ -13,6 +13,7 @@ import google.protobuf.internal.containers
import google.protobuf.message
import tensorflow.core.framework.graph_pb2
import tensorflow.core.framework.op_def_pb2
import tensorflow.core.framework.tensor_pb2
import tensorflow.core.framework.tensor_shape_pb2
import tensorflow.core.framework.types_pb2
import tensorflow.core.protobuf.saved_object_graph_pb2
@@ -621,9 +622,28 @@ class SignatureDef(google.protobuf.message.Message):
def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
@typing.final
class DefaultsEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
@property
def value(self) -> tensorflow.core.framework.tensor_pb2.TensorProto: ...
def __init__(
self,
*,
key: builtins.str | None = ...,
value: tensorflow.core.framework.tensor_pb2.TensorProto | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...
INPUTS_FIELD_NUMBER: builtins.int
OUTPUTS_FIELD_NUMBER: builtins.int
METHOD_NAME_FIELD_NUMBER: builtins.int
DEFAULTS_FIELD_NUMBER: builtins.int
method_name: builtins.str
"""Extensible method_name information enabling third-party users to mark a
SignatureDef as supporting a particular method. This enables producers and
@@ -642,14 +662,19 @@ class SignatureDef(google.protobuf.message.Message):
def outputs(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___TensorInfo]:
"""Named output parameters."""
@property
def defaults(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, tensorflow.core.framework.tensor_pb2.TensorProto]:
"""Named input to corresponding default values if any."""
def __init__(
self,
*,
inputs: collections.abc.Mapping[builtins.str, global___TensorInfo] | None = ...,
outputs: collections.abc.Mapping[builtins.str, global___TensorInfo] | None = ...,
method_name: builtins.str | None = ...,
defaults: collections.abc.Mapping[builtins.str, tensorflow.core.framework.tensor_pb2.TensorProto] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["inputs", b"inputs", "method_name", b"method_name", "outputs", b"outputs"]) -> None: ...
def ClearField(self, field_name: typing.Literal["defaults", b"defaults", "inputs", b"inputs", "method_name", b"method_name", "outputs", b"outputs"]) -> None: ...
global___SignatureDef = SignatureDef

View File

@@ -266,6 +266,7 @@ class RewriterConfig(google.protobuf.message.Message):
AUTO_MIXED_PRECISION_ONEDNN_BFLOAT16_FIELD_NUMBER: builtins.int
AUTO_MIXED_PRECISION_CPU_FIELD_NUMBER: builtins.int
DISABLE_META_OPTIMIZER_FIELD_NUMBER: builtins.int
DISABLE_TFG_OPTIMIZER_FIELD_NUMBER: builtins.int
USE_PLUGIN_OPTIMIZERS_FIELD_NUMBER: builtins.int
EXPERIMENTAL_CONDITIONAL_CODE_MOTION_FIELD_NUMBER: builtins.int
META_OPTIMIZER_ITERATIONS_FIELD_NUMBER: builtins.int
@@ -359,6 +360,8 @@ class RewriterConfig(google.protobuf.message.Message):
"""
disable_meta_optimizer: builtins.bool
"""Disable the entire meta optimizer (off by default)."""
disable_tfg_optimizer: builtins.bool
"""Disable the TFG optimizer (off by default)."""
use_plugin_optimizers: global___RewriterConfig.Toggle.ValueType
"""Optimizers registered by plugin (default is ON)"""
experimental_conditional_code_motion: global___RewriterConfig.Toggle.ValueType
@@ -471,6 +474,7 @@ class RewriterConfig(google.protobuf.message.Message):
auto_mixed_precision_onednn_bfloat16: global___RewriterConfig.Toggle.ValueType | None = ...,
auto_mixed_precision_cpu: global___RewriterConfig.Toggle.ValueType | None = ...,
disable_meta_optimizer: builtins.bool | None = ...,
disable_tfg_optimizer: builtins.bool | None = ...,
use_plugin_optimizers: global___RewriterConfig.Toggle.ValueType | None = ...,
experimental_conditional_code_motion: global___RewriterConfig.Toggle.ValueType | None = ...,
meta_optimizer_iterations: global___RewriterConfig.NumIterationsType.ValueType | None = ...,
@@ -489,6 +493,6 @@ class RewriterConfig(google.protobuf.message.Message):
post_optimization_verifier_config: tensorflow.core.protobuf.verifier_config_pb2.VerifierConfig | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["auto_parallel", b"auto_parallel", "inter_optimizer_verifier_config", b"inter_optimizer_verifier_config", "post_optimization_verifier_config", b"post_optimization_verifier_config", "scoped_allocator_opts", b"scoped_allocator_opts"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["arithmetic_optimization", b"arithmetic_optimization", "auto_mixed_precision", b"auto_mixed_precision", "auto_mixed_precision_cpu", b"auto_mixed_precision_cpu", "auto_mixed_precision_mkl", b"auto_mixed_precision_mkl", "auto_mixed_precision_onednn_bfloat16", b"auto_mixed_precision_onednn_bfloat16", "auto_parallel", b"auto_parallel", "common_subgraph_elimination", b"common_subgraph_elimination", "constant_folding", b"constant_folding", "cpu_layout_conversion", b"cpu_layout_conversion", "custom_optimizers", b"custom_optimizers", "debug_stripper", b"debug_stripper", "dependency_optimization", b"dependency_optimization", "disable_meta_optimizer", b"disable_meta_optimizer", "disable_model_pruning", b"disable_model_pruning", "experimental_conditional_code_motion", b"experimental_conditional_code_motion", "experimental_disable_compressed_tensor_optimization", b"experimental_disable_compressed_tensor_optimization", "experimental_disable_folding_quantization_emulation", b"experimental_disable_folding_quantization_emulation", "fail_on_optimizer_errors", b"fail_on_optimizer_errors", "function_optimization", b"function_optimization", "implementation_selector", b"implementation_selector", "inter_optimizer_verifier_config", b"inter_optimizer_verifier_config", "layout_optimizer", b"layout_optimizer", "loop_optimization", b"loop_optimization", "memory_optimization", b"memory_optimization", "memory_optimizer_target_node_name_scope", b"memory_optimizer_target_node_name_scope", "meta_optimizer_iterations", b"meta_optimizer_iterations", "meta_optimizer_timeout_ms", b"meta_optimizer_timeout_ms", "min_graph_nodes", b"min_graph_nodes", "optimizers", b"optimizers", "pin_to_host_optimization", b"pin_to_host_optimization", "post_optimization_verifier_config", b"post_optimization_verifier_config", "remapping", b"remapping", "scoped_allocator_optimization", b"scoped_allocator_optimization", "scoped_allocator_opts", b"scoped_allocator_opts", "shape_optimization", b"shape_optimization", "use_plugin_optimizers", b"use_plugin_optimizers"]) -> None: ...
def ClearField(self, field_name: typing.Literal["arithmetic_optimization", b"arithmetic_optimization", "auto_mixed_precision", b"auto_mixed_precision", "auto_mixed_precision_cpu", b"auto_mixed_precision_cpu", "auto_mixed_precision_mkl", b"auto_mixed_precision_mkl", "auto_mixed_precision_onednn_bfloat16", b"auto_mixed_precision_onednn_bfloat16", "auto_parallel", b"auto_parallel", "common_subgraph_elimination", b"common_subgraph_elimination", "constant_folding", b"constant_folding", "cpu_layout_conversion", b"cpu_layout_conversion", "custom_optimizers", b"custom_optimizers", "debug_stripper", b"debug_stripper", "dependency_optimization", b"dependency_optimization", "disable_meta_optimizer", b"disable_meta_optimizer", "disable_model_pruning", b"disable_model_pruning", "disable_tfg_optimizer", b"disable_tfg_optimizer", "experimental_conditional_code_motion", b"experimental_conditional_code_motion", "experimental_disable_compressed_tensor_optimization", b"experimental_disable_compressed_tensor_optimization", "experimental_disable_folding_quantization_emulation", b"experimental_disable_folding_quantization_emulation", "fail_on_optimizer_errors", b"fail_on_optimizer_errors", "function_optimization", b"function_optimization", "implementation_selector", b"implementation_selector", "inter_optimizer_verifier_config", b"inter_optimizer_verifier_config", "layout_optimizer", b"layout_optimizer", "loop_optimization", b"loop_optimization", "memory_optimization", b"memory_optimization", "memory_optimizer_target_node_name_scope", b"memory_optimizer_target_node_name_scope", "meta_optimizer_iterations", b"meta_optimizer_iterations", "meta_optimizer_timeout_ms", b"meta_optimizer_timeout_ms", "min_graph_nodes", b"min_graph_nodes", "optimizers", b"optimizers", "pin_to_host_optimization", b"pin_to_host_optimization", "post_optimization_verifier_config", b"post_optimization_verifier_config", "remapping", b"remapping", "scoped_allocator_optimization", b"scoped_allocator_optimization", "scoped_allocator_opts", b"scoped_allocator_opts", "shape_optimization", b"shape_optimization", "use_plugin_optimizers", b"use_plugin_optimizers"]) -> None: ...
global___RewriterConfig = RewriterConfig

View File

@@ -17,7 +17,7 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class DispatcherConfig(google.protobuf.message.Message):
"""Configuration for a tf.data service DispatchServer.
Next id: 11
Next id: 13
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -30,8 +30,10 @@ class DispatcherConfig(google.protobuf.message.Message):
DEPLOYMENT_MODE_FIELD_NUMBER: builtins.int
JOB_GC_CHECK_INTERVAL_MS_FIELD_NUMBER: builtins.int
JOB_GC_TIMEOUT_MS_FIELD_NUMBER: builtins.int
GC_DYNAMIC_SHARDING_JOBS_FIELD_NUMBER: builtins.int
CLIENT_TIMEOUT_MS_FIELD_NUMBER: builtins.int
WORKER_TIMEOUT_MS_FIELD_NUMBER: builtins.int
WORKER_MAX_CONCURRENT_SNAPSHOTS_FIELD_NUMBER: builtins.int
port: builtins.int
"""The port for the dispatcher to bind to. A value of 0 indicates that the
dispatcher may bind to any available port.
@@ -59,7 +61,15 @@ class DispatcherConfig(google.protobuf.message.Message):
"""How long a job needs to be unused before it becomes a candidate for garbage
collection. A value of -1 indicates that jobs should never be garbage
collected. A value of 0 indicates that the decision should be left up to
the runtime.
the runtime. Note: This does not apply to dynamic sharding unless users
explicitly opt-in by enabling `gc_dynamic_sharding_jobs` below.
"""
gc_dynamic_sharding_jobs: builtins.bool
"""Whether dynamically sharded jobs should be eligible for garbage collection.
These jobs are not garbage collected by default, since if a job is garbage
collected and then re-created, it will revisit all data from the start. If
revisiting data is acceptible and you want automatic reclamation of
iterator memory, set `gc_dynamic_sharding_jobs` to `true`.
"""
client_timeout_ms: builtins.int
"""How long to wait before garbage-collecting a client that hasn't
@@ -70,6 +80,12 @@ class DispatcherConfig(google.protobuf.message.Message):
"""How long to wait for a worker to heartbeat before considering it missing.
A value of 0 indicates that the timeout should be left to the runtime.
"""
worker_max_concurrent_snapshots: builtins.int
"""The maximum number of snapshots that a worker can concurrently process at a
given point in time. This is a tradeoff between worker resource usage and
snapshot wall time. A value of 0 indicates that the decision should be left
up to the runtime.
"""
@property
def worker_addresses(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""(Optional.) If the job uses auto-sharding, it needs to specify a fixed list
@@ -89,17 +105,19 @@ class DispatcherConfig(google.protobuf.message.Message):
deployment_mode: tensorflow.core.protobuf.data_service_pb2.DeploymentMode.ValueType | None = ...,
job_gc_check_interval_ms: builtins.int | None = ...,
job_gc_timeout_ms: builtins.int | None = ...,
gc_dynamic_sharding_jobs: builtins.bool | None = ...,
client_timeout_ms: builtins.int | None = ...,
worker_timeout_ms: builtins.int | None = ...,
worker_max_concurrent_snapshots: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["client_timeout_ms", b"client_timeout_ms", "deployment_mode", b"deployment_mode", "fault_tolerant_mode", b"fault_tolerant_mode", "job_gc_check_interval_ms", b"job_gc_check_interval_ms", "job_gc_timeout_ms", b"job_gc_timeout_ms", "port", b"port", "protocol", b"protocol", "work_dir", b"work_dir", "worker_addresses", b"worker_addresses", "worker_timeout_ms", b"worker_timeout_ms"]) -> None: ...
def ClearField(self, field_name: typing.Literal["client_timeout_ms", b"client_timeout_ms", "deployment_mode", b"deployment_mode", "fault_tolerant_mode", b"fault_tolerant_mode", "gc_dynamic_sharding_jobs", b"gc_dynamic_sharding_jobs", "job_gc_check_interval_ms", b"job_gc_check_interval_ms", "job_gc_timeout_ms", b"job_gc_timeout_ms", "port", b"port", "protocol", b"protocol", "work_dir", b"work_dir", "worker_addresses", b"worker_addresses", "worker_max_concurrent_snapshots", b"worker_max_concurrent_snapshots", "worker_timeout_ms", b"worker_timeout_ms"]) -> None: ...
global___DispatcherConfig = DispatcherConfig
@typing.final
class WorkerConfig(google.protobuf.message.Message):
"""Configuration for a tf.data service WorkerServer.
Next id: 12
Next id: 13
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -114,6 +132,7 @@ class WorkerConfig(google.protobuf.message.Message):
DATA_TRANSFER_PROTOCOL_FIELD_NUMBER: builtins.int
DATA_TRANSFER_ADDRESS_FIELD_NUMBER: builtins.int
CROSS_TRAINER_CACHE_SIZE_BYTES_FIELD_NUMBER: builtins.int
SNAPSHOT_MAX_CHUNK_SIZE_BYTES_FIELD_NUMBER: builtins.int
SHUTDOWN_QUIET_PERIOD_MS_FIELD_NUMBER: builtins.int
port: builtins.int
"""The port for the worker to bind to. A value of 0 indicates that the
@@ -148,6 +167,10 @@ class WorkerConfig(google.protobuf.message.Message):
"""Maximum size of the cross-trainer cache in bytes. If enabled, make sure
your training job provides sufficient memory resources.
"""
snapshot_max_chunk_size_bytes: builtins.int
"""The maximum size of a distributed snapshot chunk file. A value of 0
indicates that the decision should be left up to the runtime.
"""
shutdown_quiet_period_ms: builtins.int
"""When shutting down a worker, how long to wait for the gRPC server to
process the final requests. This is used to achieve clean shutdown in unit
@@ -174,8 +197,9 @@ class WorkerConfig(google.protobuf.message.Message):
data_transfer_protocol: builtins.str | None = ...,
data_transfer_address: builtins.str | None = ...,
cross_trainer_cache_size_bytes: builtins.int | None = ...,
snapshot_max_chunk_size_bytes: builtins.int | None = ...,
shutdown_quiet_period_ms: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["cross_trainer_cache_size_bytes", b"cross_trainer_cache_size_bytes", "data_transfer_address", b"data_transfer_address", "data_transfer_protocol", b"data_transfer_protocol", "dispatcher_address", b"dispatcher_address", "dispatcher_timeout_ms", b"dispatcher_timeout_ms", "heartbeat_interval_ms", b"heartbeat_interval_ms", "port", b"port", "protocol", b"protocol", "shutdown_quiet_period_ms", b"shutdown_quiet_period_ms", "worker_address", b"worker_address", "worker_tags", b"worker_tags"]) -> None: ...
def ClearField(self, field_name: typing.Literal["cross_trainer_cache_size_bytes", b"cross_trainer_cache_size_bytes", "data_transfer_address", b"data_transfer_address", "data_transfer_protocol", b"data_transfer_protocol", "dispatcher_address", b"dispatcher_address", "dispatcher_timeout_ms", b"dispatcher_timeout_ms", "heartbeat_interval_ms", b"heartbeat_interval_ms", "port", b"port", "protocol", b"protocol", "shutdown_quiet_period_ms", b"shutdown_quiet_period_ms", "snapshot_max_chunk_size_bytes", b"snapshot_max_chunk_size_bytes", "worker_address", b"worker_address", "worker_tags", b"worker_tags"]) -> None: ...
global___WorkerConfig = WorkerConfig

View File

@@ -0,0 +1,13 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Add a dummy package name. Having no package like
core/lib/core/error_codes.proto, or having tensorflow like
tsl/protobuf/error_codes.proto, results in name collision errors in generated
code for some users that use JS through J2CL.
"""
import google.protobuf.descriptor
from tensorflow.tsl.protobuf.status_pb2 import StatusProto as StatusProto
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

View File

@@ -67,6 +67,8 @@ class StructuredValue(google.protobuf.message.Message):
TUPLE_VALUE_FIELD_NUMBER: builtins.int
DICT_VALUE_FIELD_NUMBER: builtins.int
NAMED_TUPLE_VALUE_FIELD_NUMBER: builtins.int
TENSOR_VALUE_FIELD_NUMBER: builtins.int
NUMPY_VALUE_FIELD_NUMBER: builtins.int
float64_value: builtins.float
"""Represents a double-precision floating-point value (a Python `float`)."""
int64_value: builtins.int
@@ -121,6 +123,14 @@ class StructuredValue(google.protobuf.message.Message):
def named_tuple_value(self) -> global___NamedTupleValue:
"""Represents Python's namedtuple."""
@property
def tensor_value(self) -> tensorflow.core.framework.tensor_pb2.TensorProto:
"""Represents a value for tf.Tensor."""
@property
def numpy_value(self) -> tensorflow.core.framework.tensor_pb2.TensorProto:
"""Represents a value for np.ndarray."""
def __init__(
self,
*,
@@ -138,10 +148,12 @@ class StructuredValue(google.protobuf.message.Message):
tuple_value: global___TupleValue | None = ...,
dict_value: global___DictValue | None = ...,
named_tuple_value: global___NamedTupleValue | None = ...,
tensor_value: tensorflow.core.framework.tensor_pb2.TensorProto | None = ...,
numpy_value: tensorflow.core.framework.tensor_pb2.TensorProto | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["bool_value", b"bool_value", "bounded_tensor_spec_value", b"bounded_tensor_spec_value", "dict_value", b"dict_value", "float64_value", b"float64_value", "int64_value", b"int64_value", "kind", b"kind", "list_value", b"list_value", "named_tuple_value", b"named_tuple_value", "none_value", b"none_value", "string_value", b"string_value", "tensor_dtype_value", b"tensor_dtype_value", "tensor_shape_value", b"tensor_shape_value", "tensor_spec_value", b"tensor_spec_value", "tuple_value", b"tuple_value", "type_spec_value", b"type_spec_value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["bool_value", b"bool_value", "bounded_tensor_spec_value", b"bounded_tensor_spec_value", "dict_value", b"dict_value", "float64_value", b"float64_value", "int64_value", b"int64_value", "kind", b"kind", "list_value", b"list_value", "named_tuple_value", b"named_tuple_value", "none_value", b"none_value", "string_value", b"string_value", "tensor_dtype_value", b"tensor_dtype_value", "tensor_shape_value", b"tensor_shape_value", "tensor_spec_value", b"tensor_spec_value", "tuple_value", b"tuple_value", "type_spec_value", b"type_spec_value"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["kind", b"kind"]) -> typing.Literal["none_value", "float64_value", "int64_value", "string_value", "bool_value", "tensor_shape_value", "tensor_dtype_value", "tensor_spec_value", "type_spec_value", "bounded_tensor_spec_value", "list_value", "tuple_value", "dict_value", "named_tuple_value"] | None: ...
def HasField(self, field_name: typing.Literal["bool_value", b"bool_value", "bounded_tensor_spec_value", b"bounded_tensor_spec_value", "dict_value", b"dict_value", "float64_value", b"float64_value", "int64_value", b"int64_value", "kind", b"kind", "list_value", b"list_value", "named_tuple_value", b"named_tuple_value", "none_value", b"none_value", "numpy_value", b"numpy_value", "string_value", b"string_value", "tensor_dtype_value", b"tensor_dtype_value", "tensor_shape_value", b"tensor_shape_value", "tensor_spec_value", b"tensor_spec_value", "tensor_value", b"tensor_value", "tuple_value", b"tuple_value", "type_spec_value", b"type_spec_value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["bool_value", b"bool_value", "bounded_tensor_spec_value", b"bounded_tensor_spec_value", "dict_value", b"dict_value", "float64_value", b"float64_value", "int64_value", b"int64_value", "kind", b"kind", "list_value", b"list_value", "named_tuple_value", b"named_tuple_value", "none_value", b"none_value", "numpy_value", b"numpy_value", "string_value", b"string_value", "tensor_dtype_value", b"tensor_dtype_value", "tensor_shape_value", b"tensor_shape_value", "tensor_spec_value", b"tensor_spec_value", "tensor_value", b"tensor_value", "tuple_value", b"tuple_value", "type_spec_value", b"type_spec_value"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["kind", b"kind"]) -> typing.Literal["none_value", "float64_value", "int64_value", "string_value", "bool_value", "tensor_shape_value", "tensor_dtype_value", "tensor_spec_value", "type_spec_value", "bounded_tensor_spec_value", "list_value", "tuple_value", "dict_value", "named_tuple_value", "tensor_value", "numpy_value"] | None: ...
global___StructuredValue = StructuredValue

View File

@@ -36,6 +36,7 @@ class ServerDef(google.protobuf.message.Message):
CLUSTER_FIELD_NUMBER: builtins.int
JOB_NAME_FIELD_NUMBER: builtins.int
REPLICA_FIELD_NUMBER: builtins.int
TASK_INDEX_FIELD_NUMBER: builtins.int
DEFAULT_SESSION_CONFIG_FIELD_NUMBER: builtins.int
PROTOCOL_FIELD_NUMBER: builtins.int
@@ -47,6 +48,8 @@ class ServerDef(google.protobuf.message.Message):
NOTE(mrry): The `cluster` field must contain a `JobDef` with a `name` field
that matches this name.
"""
replica: builtins.int
"""Replica this server manages."""
task_index: builtins.int
"""The task index of this server in its job.
@@ -79,6 +82,7 @@ class ServerDef(google.protobuf.message.Message):
*,
cluster: tensorflow.core.protobuf.cluster_pb2.ClusterDef | None = ...,
job_name: builtins.str | None = ...,
replica: builtins.int | None = ...,
task_index: builtins.int | None = ...,
default_session_config: tensorflow.core.protobuf.config_pb2.ConfigProto | None = ...,
protocol: builtins.str | None = ...,
@@ -86,6 +90,6 @@ class ServerDef(google.protobuf.message.Message):
cluster_device_filters: tensorflow.core.protobuf.device_filters_pb2.ClusterDeviceFilters | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["cluster", b"cluster", "cluster_device_filters", b"cluster_device_filters", "default_session_config", b"default_session_config"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["cluster", b"cluster", "cluster_device_filters", b"cluster_device_filters", "default_session_config", b"default_session_config", "job_name", b"job_name", "port", b"port", "protocol", b"protocol", "task_index", b"task_index"]) -> None: ...
def ClearField(self, field_name: typing.Literal["cluster", b"cluster", "cluster_device_filters", b"cluster_device_filters", "default_session_config", b"default_session_config", "job_name", b"job_name", "port", b"port", "protocol", b"protocol", "replica", b"replica", "task_index", b"task_index"]) -> None: ...
global___ServerDef = ServerDef

View File

@@ -392,6 +392,34 @@ class MomentumParameters(google.protobuf.message.Message):
global___MomentumParameters = MomentumParameters
@typing.final
class LionParameters(google.protobuf.message.Message):
"""https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Lion
momenta(new) = beta2 * momenta(old) + (1 - beta2) * grad
momenta_t = beta1 * momenta(old) + (1 - beta1) * grad
var(new) = var(old) - lr * sign(momenta_t)
Algorithm described in https://arxiv.org/abs/2302.06675.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
BETA1_FIELD_NUMBER: builtins.int
BETA2_FIELD_NUMBER: builtins.int
USE_NON_LAZY_LION_FIELD_NUMBER: builtins.int
beta1: builtins.float
beta2: builtins.float
use_non_lazy_lion: builtins.bool
def __init__(
self,
*,
beta1: builtins.float | None = ...,
beta2: builtins.float | None = ...,
use_non_lazy_lion: builtins.bool | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["beta1", b"beta1", "beta2", b"beta2", "use_non_lazy_lion", b"use_non_lazy_lion"]) -> None: ...
global___LionParameters = LionParameters
@typing.final
class RmsPropParameters(google.protobuf.message.Message):
"""https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop
@@ -898,6 +926,7 @@ class OptimizationParameters(google.protobuf.message.Message):
FTRL_FIELD_NUMBER: builtins.int
ADAM_FIELD_NUMBER: builtins.int
MOMENTUM_FIELD_NUMBER: builtins.int
LION_FIELD_NUMBER: builtins.int
RMS_PROP_FIELD_NUMBER: builtins.int
CENTERED_RMS_PROP_FIELD_NUMBER: builtins.int
MDL_ADAGRAD_LIGHT_FIELD_NUMBER: builtins.int
@@ -975,6 +1004,8 @@ class OptimizationParameters(google.protobuf.message.Message):
@property
def momentum(self) -> global___MomentumParameters: ...
@property
def lion(self) -> global___LionParameters: ...
@property
def rms_prop(self) -> global___RmsPropParameters: ...
@property
def centered_rms_prop(self) -> global___CenteredRmsPropParameters: ...
@@ -1013,6 +1044,7 @@ class OptimizationParameters(google.protobuf.message.Message):
ftrl: global___FtrlParameters | None = ...,
adam: global___AdamParameters | None = ...,
momentum: global___MomentumParameters | None = ...,
lion: global___LionParameters | None = ...,
rms_prop: global___RmsPropParameters | None = ...,
centered_rms_prop: global___CenteredRmsPropParameters | None = ...,
mdl_adagrad_light: global___MdlAdagradLightParameters | None = ...,
@@ -1024,9 +1056,9 @@ class OptimizationParameters(google.protobuf.message.Message):
user_defined_program: global___UserDefinedProgramParameters | None = ...,
assign: global___AssignParameters | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["adadelta", b"adadelta", "adagrad", b"adagrad", "adagrad_momentum", b"adagrad_momentum", "adam", b"adam", "assign", b"assign", "bounded_adagrad", b"bounded_adagrad", "centered_rms_prop", b"centered_rms_prop", "clipping_limits", b"clipping_limits", "frequency_estimator", b"frequency_estimator", "ftrl", b"ftrl", "gradient_clipping_limits", b"gradient_clipping_limits", "hot_id_replication_configuration", b"hot_id_replication_configuration", "learning_rate", b"learning_rate", "mdl_adagrad_light", b"mdl_adagrad_light", "momentum", b"momentum", "online_yogi", b"online_yogi", "parameters", b"parameters", "proximal_adagrad", b"proximal_adagrad", "proximal_yogi", b"proximal_yogi", "rms_prop", b"rms_prop", "simulated_quantization", b"simulated_quantization", "stochastic_gradient_descent", b"stochastic_gradient_descent", "user_defined_program", b"user_defined_program"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["adadelta", b"adadelta", "adagrad", b"adagrad", "adagrad_momentum", b"adagrad_momentum", "adam", b"adam", "assign", b"assign", "bounded_adagrad", b"bounded_adagrad", "centered_rms_prop", b"centered_rms_prop", "clipping_limits", b"clipping_limits", "frequency_estimator", b"frequency_estimator", "ftrl", b"ftrl", "gradient_accumulation_status", b"gradient_accumulation_status", "gradient_clipping_limits", b"gradient_clipping_limits", "hot_id_replication_configuration", b"hot_id_replication_configuration", "learning_rate", b"learning_rate", "low_dimensional_packing_status", b"low_dimensional_packing_status", "mdl_adagrad_light", b"mdl_adagrad_light", "momentum", b"momentum", "multiply_weight_decay_factor_by_learning_rate", b"multiply_weight_decay_factor_by_learning_rate", "online_yogi", b"online_yogi", "parameters", b"parameters", "proximal_adagrad", b"proximal_adagrad", "proximal_yogi", b"proximal_yogi", "rms_prop", b"rms_prop", "simulated_quantization", b"simulated_quantization", "stochastic_gradient_descent", b"stochastic_gradient_descent", "user_defined_program", b"user_defined_program", "weight_decay_factor", b"weight_decay_factor"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["parameters", b"parameters"]) -> typing.Literal["adagrad", "adagrad_momentum", "bounded_adagrad", "stochastic_gradient_descent", "ftrl", "adam", "momentum", "rms_prop", "centered_rms_prop", "mdl_adagrad_light", "adadelta", "proximal_adagrad", "online_yogi", "proximal_yogi", "frequency_estimator", "user_defined_program", "assign"] | None: ...
def HasField(self, field_name: typing.Literal["adadelta", b"adadelta", "adagrad", b"adagrad", "adagrad_momentum", b"adagrad_momentum", "adam", b"adam", "assign", b"assign", "bounded_adagrad", b"bounded_adagrad", "centered_rms_prop", b"centered_rms_prop", "clipping_limits", b"clipping_limits", "frequency_estimator", b"frequency_estimator", "ftrl", b"ftrl", "gradient_clipping_limits", b"gradient_clipping_limits", "hot_id_replication_configuration", b"hot_id_replication_configuration", "learning_rate", b"learning_rate", "lion", b"lion", "mdl_adagrad_light", b"mdl_adagrad_light", "momentum", b"momentum", "online_yogi", b"online_yogi", "parameters", b"parameters", "proximal_adagrad", b"proximal_adagrad", "proximal_yogi", b"proximal_yogi", "rms_prop", b"rms_prop", "simulated_quantization", b"simulated_quantization", "stochastic_gradient_descent", b"stochastic_gradient_descent", "user_defined_program", b"user_defined_program"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["adadelta", b"adadelta", "adagrad", b"adagrad", "adagrad_momentum", b"adagrad_momentum", "adam", b"adam", "assign", b"assign", "bounded_adagrad", b"bounded_adagrad", "centered_rms_prop", b"centered_rms_prop", "clipping_limits", b"clipping_limits", "frequency_estimator", b"frequency_estimator", "ftrl", b"ftrl", "gradient_accumulation_status", b"gradient_accumulation_status", "gradient_clipping_limits", b"gradient_clipping_limits", "hot_id_replication_configuration", b"hot_id_replication_configuration", "learning_rate", b"learning_rate", "lion", b"lion", "low_dimensional_packing_status", b"low_dimensional_packing_status", "mdl_adagrad_light", b"mdl_adagrad_light", "momentum", b"momentum", "multiply_weight_decay_factor_by_learning_rate", b"multiply_weight_decay_factor_by_learning_rate", "online_yogi", b"online_yogi", "parameters", b"parameters", "proximal_adagrad", b"proximal_adagrad", "proximal_yogi", b"proximal_yogi", "rms_prop", b"rms_prop", "simulated_quantization", b"simulated_quantization", "stochastic_gradient_descent", b"stochastic_gradient_descent", "user_defined_program", b"user_defined_program", "weight_decay_factor", b"weight_decay_factor"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["parameters", b"parameters"]) -> typing.Literal["adagrad", "adagrad_momentum", "bounded_adagrad", "stochastic_gradient_descent", "ftrl", "adam", "momentum", "lion", "rms_prop", "centered_rms_prop", "mdl_adagrad_light", "adadelta", "proximal_adagrad", "online_yogi", "proximal_yogi", "frequency_estimator", "user_defined_program", "assign"] | None: ...
global___OptimizationParameters = OptimizationParameters

View File

@@ -62,13 +62,17 @@ class TPUHardwareFeature(google.protobuf.message.Message):
"""
EMBEDDING_FEATURE_FIELD_NUMBER: builtins.int
NUM_EMBEDDING_DEVICES_PER_CHIP_FIELD_NUMBER: builtins.int
embedding_feature: global___TPUHardwareFeature.EmbeddingFeature.ValueType
num_embedding_devices_per_chip: builtins.int
"""Number of embedding accelerator devices per chip."""
def __init__(
self,
*,
embedding_feature: global___TPUHardwareFeature.EmbeddingFeature.ValueType | None = ...,
num_embedding_devices_per_chip: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["embedding_feature", b"embedding_feature"]) -> None: ...
def ClearField(self, field_name: typing.Literal["embedding_feature", b"embedding_feature", "num_embedding_devices_per_chip", b"num_embedding_devices_per_chip"]) -> None: ...
global___TPUHardwareFeature = TPUHardwareFeature

View File

@@ -0,0 +1,5 @@
from _typeshed import Incomplete
from .experimental.coordinator import RemoteValue as RemoteValue
def __getattr__(name: str) -> Incomplete: ...

View File

@@ -96,7 +96,9 @@ class RaggedFeature(NamedTuple):
dtype: DTypeLike
value_key: str | None = ...
partitions: tuple[RowSplits | RowLengths | RowStarts | RowLimits | ValueRowIds | UniformRowLength, ...] = ... # type: ignore[name-defined]
partitions: tuple[ # type: ignore[name-defined]
RowSplits | RowLengths | RowStarts | RowLimits | ValueRowIds | UniformRowLength, ...
] = ...
row_splits_dtype: DTypeLike = ...
validate: bool = ...

View File

@@ -11,9 +11,7 @@ _Activation: TypeAlias = str | None | Callable[[Tensor], Tensor] | dict[str, Any
# Ints are not allowed.
_ActivationInput: TypeAlias = Tensor | FloatDataSequence | FloatArray | np.number[Any] | float
def deserialize(
name: str, custom_objects: dict[str, Callable[..., Any]] | None = None, use_legacy_format: bool = False
) -> Callable[..., Any]: ...
def deserialize(config: dict[str, Any], custom_objects: dict[str, Callable[..., Any]] | None = None) -> Callable[..., Any]: ...
def elu(x: _ActivationInput, alpha: FloatTensorCompatible | FloatDataSequence = 1.0) -> Tensor: ...
def exponential(x: _ActivationInput) -> Tensor: ...
def gelu(x: _ActivationInput, approximate: bool = False) -> Tensor: ...
@@ -23,12 +21,12 @@ def linear(x: _ActivationInput) -> Tensor: ...
def mish(x: _ActivationInput) -> Tensor: ...
def relu(
x: _ActivationInput,
alpha: FloatTensorCompatible = 0.0,
negative_slope: FloatTensorCompatible = 0.0,
max_value: FloatTensorCompatible | FloatDataSequence | None = None,
threshold: FloatTensorCompatible | FloatDataSequence = 0.0,
) -> Tensor: ...
def selu(x: _ActivationInput) -> Tensor: ...
def serialize(activation: Callable[..., Any], use_legacy_format: bool = False) -> str | dict[str, Any]: ...
def serialize(activation: Callable[..., Any]) -> str | dict[str, Any]: ...
def sigmoid(x: _ActivationInput) -> Tensor: ...
def softmax(x: Tensor, axis: Integer = -1) -> Tensor: ...
def softplus(x: _ActivationInput) -> Tensor: ...

View File

@@ -1,4 +1,4 @@
from collections.abc import Callable, Iterable, Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal
from typing_extensions import TypeAlias
@@ -6,37 +6,33 @@ import tensorflow as tf
from requests.api import _HeadersMapping
from tensorflow.keras import Model
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.saved_model import SaveOptions
from tensorflow.train import CheckpointOptions
_Logs: TypeAlias = Mapping[str, Any] | None | Any
class Callback:
model: Model[Any, Any]
params: dict[str, Any]
def set_model(self, model: Model[Any, Any]) -> None: ...
def set_params(self, params: dict[str, Any]) -> None: ...
def set_model(self, model: Model[Any, Any]) -> None: ...
@property
def model(self) -> Model[Any, Any]: ...
def on_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
def on_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
def on_epoch_begin(self, epoch: int, logs: _Logs = None) -> None: ...
def on_epoch_end(self, epoch: int, logs: _Logs = None) -> None: ...
def on_predict_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
def on_predict_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
def on_predict_begin(self, logs: _Logs = None) -> None: ...
def on_predict_end(self, logs: _Logs = None) -> None: ...
def on_test_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
def on_test_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
def on_test_begin(self, logs: _Logs = None) -> None: ...
def on_test_end(self, logs: _Logs = None) -> None: ...
def on_train_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
def on_train_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
def on_test_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
def on_test_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
def on_predict_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
def on_predict_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
def on_train_begin(self, logs: _Logs = None) -> None: ...
def on_train_end(self, logs: _Logs = None) -> None: ...
def on_test_begin(self, logs: _Logs = None) -> None: ...
def on_test_end(self, logs: _Logs = None) -> None: ...
def on_predict_begin(self, logs: _Logs = None) -> None: ...
def on_predict_end(self, logs: _Logs = None) -> None: ...
# A CallbackList has exact same api as a callback, but does not actually subclass it.
class CallbackList:
model: Model[Any, Any]
params: dict[str, Any]
class CallbackList(Callback):
def __init__(
self,
callbacks: Sequence[Callback] | None = None,
@@ -46,32 +42,28 @@ class CallbackList:
model: Model[Any, Any] | None = None,
**params: Any,
) -> None: ...
def set_model(self, model: Model[Any, Any]) -> None: ...
def append(self, callback: Callback) -> None: ...
def set_params(self, params: dict[str, Any]) -> None: ...
def on_batch_begin(self, batch: int, logs: _Logs | None = None) -> None: ...
def on_batch_end(self, batch: int, logs: _Logs | None = None) -> None: ...
def on_epoch_begin(self, epoch: int, logs: _Logs | None = None) -> None: ...
def on_epoch_end(self, epoch: int, logs: _Logs | None = None) -> None: ...
def on_predict_batch_begin(self, batch: int, logs: _Logs | None = None) -> None: ...
def on_predict_batch_end(self, batch: int, logs: _Logs | None = None) -> None: ...
def on_predict_begin(self, logs: _Logs | None = None) -> None: ...
def on_predict_end(self, logs: _Logs | None = None) -> None: ...
def on_test_batch_begin(self, batch: int, logs: _Logs | None = None) -> None: ...
def on_test_batch_end(self, batch: int, logs: _Logs | None = None) -> None: ...
def on_test_begin(self, logs: _Logs | None = None) -> None: ...
def on_test_end(self, logs: _Logs | None = None) -> None: ...
def on_train_batch_begin(self, batch: int, logs: _Logs | None = None) -> None: ...
def on_train_batch_end(self, batch: int, logs: _Logs | None = None) -> None: ...
def on_train_begin(self, logs: _Logs | None = None) -> None: ...
def on_train_end(self, logs: _Logs | None = None) -> None: ...
def set_model(self, model: Model[Any, Any]) -> None: ...
def on_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
def on_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
def on_epoch_begin(self, epoch: int, logs: _Logs = None) -> None: ...
def on_epoch_end(self, epoch: int, logs: _Logs = None) -> None: ...
def on_train_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
def on_train_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
def on_test_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
def on_test_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
def on_predict_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
def on_predict_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
def on_train_begin(self, logs: _Logs = None) -> None: ...
def on_train_end(self, logs: _Logs = None) -> None: ...
def on_test_begin(self, logs: _Logs = None) -> None: ...
def on_test_end(self, logs: _Logs = None) -> None: ...
def on_predict_begin(self, logs: _Logs = None) -> None: ...
def on_predict_end(self, logs: _Logs = None) -> None: ...
class BackupAndRestore(Callback):
def __init__(
self, backup_dir: str, save_freq: str = "epoch", delete_checkpoint: bool = True, save_before_preemption: bool = False
) -> None: ...
class BaseLogger(Callback):
def __init__(self, stateful_metrics: Iterable[str] | None = None) -> None: ...
def __init__(self, backup_dir: str, save_freq: str = "epoch", delete_checkpoint: bool = True) -> None: ...
class CSVLogger(Callback):
def __init__(self, filename: str, separator: str = ",", append: bool = False) -> None: ...
@@ -98,10 +90,10 @@ class LambdaCallback(Callback):
self,
on_epoch_begin: Callable[[int, _Logs], object] | None = None,
on_epoch_end: Callable[[int, _Logs], object] | None = None,
on_batch_begin: Callable[[int, _Logs], object] | None = None,
on_batch_end: Callable[[int, _Logs], object] | None = None,
on_train_begin: Callable[[_Logs], object] | None = None,
on_train_end: Callable[[_Logs], object] | None = None,
on_train_batch_begin: Callable[[int, _Logs], object] | None = None,
on_train_batch_end: Callable[[int, _Logs], object] | None = None,
**kwargs: Any,
) -> None: ...
@@ -115,7 +107,6 @@ class LearningRateScheduler(Callback):
class ModelCheckpoint(Callback):
monitor_op: Any
filepath: str
_options: CheckpointOptions | SaveOptions | None
def __init__(
self,
filepath: str,
@@ -125,16 +116,13 @@ class ModelCheckpoint(Callback):
save_weights_only: bool = False,
mode: Literal["auto", "min", "max"] = "auto",
save_freq: str | int = "epoch",
options: CheckpointOptions | SaveOptions | None = None,
initial_value_threshold: float | None = None,
) -> None: ...
def _save_model(self, epoch: int, batch: int | None, logs: _Logs) -> None: ...
class ProgbarLogger(Callback):
use_steps: bool
def __init__(
self, count_mode: Literal["steps", "samples"] = "samples", stateful_metrics: Iterable[str] | None = None
) -> None: ...
def __init__(self) -> None: ...
class ReduceLROnPlateau(Callback):
def __init__(
@@ -146,7 +134,7 @@ class ReduceLROnPlateau(Callback):
mode: Literal["auto", "min", "max"] = "auto",
min_delta: float = 1e-4,
cooldown: int = 0,
min_lr: float = 0,
min_lr: float = 0.0,
**kwargs,
) -> None: ...
def in_cooldown(self) -> bool: ...
@@ -176,7 +164,6 @@ class TensorBoard(Callback):
profile_batch: int | tuple[int, int] = 0,
embeddings_freq: int = 0,
embeddings_metadata: dict[str, None] | None = None,
**kwargs: Any,
) -> None: ...
def _write_keras_model_train_graph(self) -> None: ...
def _stop_trace(self, batch: int | None = None) -> None: ...

View File

@@ -13,7 +13,7 @@ class Initializer:
def from_config(cls, config: dict[str, Any]) -> Self: ...
class Constant(Initializer):
def __init__(self, value: TensorCompatible = 0) -> None: ...
def __init__(self, value: TensorCompatible = 0.0) -> None: ...
class GlorotNormal(Initializer):
def __init__(self, seed: int | None = None) -> None: ...

View File

@@ -1,17 +1,15 @@
from _typeshed import Incomplete
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Any, Generic, Literal, TypeVar, overload
from collections.abc import Callable, Iterable, Sequence
from typing import Any, Generic, Literal, TypeVar, overload, type_check_only
from typing_extensions import Self, TypeAlias
import tensorflow as tf
from tensorflow import Tensor, Variable, VariableAggregation, VariableSynchronization
from tensorflow._aliases import AnyArray, DTypeLike, TensorCompatible, TensorLike
from tensorflow import Tensor, Variable
from tensorflow._aliases import AnyArray, DataSequence, DTypeLike, Float, TensorCompatible, TensorLike
from tensorflow.keras.activations import _Activation
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import _Initializer
from tensorflow.keras.layers.preprocessing import IntegerLookup as IntegerLookup, StringLookup as StringLookup
from tensorflow.keras.regularizers import Regularizer, _Regularizer
from tensorflow.python.feature_column.feature_column_v2 import DenseColumn, SequenceDenseColumn
_InputT = TypeVar("_InputT", contravariant=True)
_OutputT = TypeVar("_OutputT", covariant=True)
@@ -51,7 +49,16 @@ class Layer(tf.Module, Generic[_InputT, _OutputT]):
@trainable.setter
def trainable(self, value: bool) -> None: ...
def __init__(
self, trainable: bool = True, name: str | None = None, dtype: DTypeLike | None = None, dynamic: bool = False
self,
*,
activity_regularizer: _Regularizer = None,
trainable: bool = True,
dtype: DTypeLike | None = None,
autocast: bool = True,
name: str | None = None,
# **kwargs
input_dim: int | None = None,
input_shape: Any = None,
) -> None: ...
# *args/**kwargs are allowed, but have obscure footguns and tensorflow documentation discourages their usage.
@@ -69,18 +76,17 @@ class Layer(tf.Module, Generic[_InputT, _OutputT]):
def compute_output_shape(self, input_shape: Any, /) -> Any: ...
def add_weight(
self,
name: str | None = None,
shape: Iterable[int | None] | None = None,
dtype: DTypeLike | None = None,
initializer: _Initializer | None = None,
dtype: DTypeLike | None = None,
trainable: bool = True,
autocast: bool = True,
regularizer: _Regularizer = None,
trainable: bool | None = None,
constraint: _Constraint = None,
use_resource: bool | None = None,
synchronization: VariableSynchronization = ...,
aggregation: VariableAggregation = ...,
aggregation: Literal["mean", "sum", "only_first_replica"] = "mean",
name: str | None = None,
) -> tf.Variable: ...
def add_loss(self, losses: tf.Tensor | Sequence[tf.Tensor] | Callable[[], tf.Tensor]) -> None: ...
def add_loss(self, loss: tf.Tensor | Sequence[tf.Tensor] | Callable[[], tf.Tensor]) -> None: ...
def count_params(self) -> int: ...
@property
def trainable_variables(self) -> list[Variable]: ...
@@ -112,6 +118,89 @@ _LayerDtype: TypeAlias = DTypeLike | dict[str, Any] | Any
_Constraint: TypeAlias = str | dict[str, Any] | Constraint | None
# IndexLookup is not exported by Keras
@type_check_only
class _IndexLookup(Layer[tf.Tensor, tf.Tensor]):
def __init__(
self,
max_tokens: int | None,
num_oov_indices: int,
mask_token: str | None,
oov_token: str,
vocabulary_dtype: Literal["int64", "string"],
vocabulary: str | None | TensorCompatible = None,
idf_weights: TensorCompatible | None = None,
invert: bool = False,
output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int",
sparse: bool = False,
pad_to_max_tokens: bool = False,
name: str | None = None,
*,
# **kwargs
vocabulary_size: int | None = None,
has_input_vocabulary: bool = ...,
trainable: bool | None = None,
dtype: _LayerDtype | None = None,
# **kwargs passed to Layer
activity_regularizer: _Regularizer = None,
autocast: bool = True,
) -> None: ...
def compute_output_signature(self, input_spec) -> tf.TensorSpec: ...
def get_vocabulary(self, include_special_tokens: bool = True) -> list[Incomplete]: ...
def vocabulary_size(self) -> int: ...
class StringLookup(_IndexLookup):
def __init__(
self,
max_tokens: int | None = None,
num_oov_indices: int = 1,
mask_token: str | None = None,
oov_token: str = "[UNK]",
vocabulary: str | None | TensorCompatible = None,
idf_weights: TensorCompatible | None = None,
invert: bool = False,
output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int",
pad_to_max_tokens: bool = False,
sparse: bool = False,
encoding: str = "utf-8",
name: str | None = None,
*,
# **kwargs passed to IndexLookup
vocabulary_size: int | None = None,
has_input_vocabulary: bool = ...,
trainable: bool | None = None,
dtype: _LayerDtype | None = None,
activity_regularizer: _Regularizer = None,
autocast: bool = True,
) -> None: ...
def adapt(self, data: tf.data.Dataset[TensorLike] | AnyArray | DataSequence, steps: Float | None = None) -> None: ...
class IntegerLookup(_IndexLookup):
def __init__(
self,
max_tokens: int | None = None,
num_oov_indices: int = 1,
mask_token: int | None = None,
oov_token: int = -1,
vocabulary: str | None | TensorCompatible = None,
vocabulary_dtype: Literal["int64"] = "int64",
idf_weights: TensorCompatible | None = None,
invert: bool = False,
output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int",
sparse: bool = False,
pad_to_max_tokens: bool = False,
name: str | None = None,
*,
# **kwargs passed to IndexLookup
vocabulary_size: int | None = None,
has_input_vocabulary: bool = ...,
trainable: bool | None = None,
dtype: _LayerDtype | None = None,
activity_regularizer: _Regularizer = None,
autocast: bool = True,
) -> None: ...
def adapt(self, data: tf.data.Dataset[TensorLike] | AnyArray | DataSequence, steps: Float | None = None) -> None: ...
# Layer's compute_output_shape commonly have instance as first argument name instead of self.
# This is an artifact of actual implementation commonly uses a decorator to define it.
# Layer.build has same weirdness sometimes. For both marked as positional only.
@@ -128,9 +217,12 @@ class Dense(Layer[tf.Tensor, tf.Tensor]):
activity_regularizer: _Regularizer = None,
kernel_constraint: _Constraint = None,
bias_constraint: _Constraint = None,
lora_rank: int | None = None,
*,
# **kwargs passed to Layer
trainable: bool = True,
dtype: _LayerDtype | None = None,
dynamic: bool = False,
autocast: bool = True,
name: str | None = None,
) -> None: ...
@@ -150,9 +242,13 @@ class BatchNormalization(Layer[tf.Tensor, tf.Tensor]):
gamma_regularizer: _Regularizer = None,
beta_constraint: _Constraint = None,
gamma_constraint: _Constraint = None,
synchronized: bool = False,
*,
# **kwargs passed to Layer
activity_regularizer: _Regularizer = None,
trainable: bool = True,
dtype: _LayerDtype | None = None,
dynamic: bool = False,
autocast: bool = True,
name: str | None = None,
) -> None: ...
@@ -162,9 +258,12 @@ class ReLU(Layer[tf.Tensor, tf.Tensor]):
max_value: float | None = None,
negative_slope: float | None = 0.0,
threshold: float | None = 0.0,
*,
# **kwargs passed to Layer
activity_regularizer: _Regularizer = None,
trainable: bool = True,
dtype: _LayerDtype | None = None,
dynamic: bool = False,
autocast: bool = True,
name: str | None = None,
) -> None: ...
@@ -174,9 +273,12 @@ class Dropout(Layer[tf.Tensor, tf.Tensor]):
rate: float,
noise_shape: TensorCompatible | Sequence[int | None] | None = None,
seed: int | None = None,
*,
# **kwargs passed to Layer
activity_regularizer: _Regularizer = None,
trainable: bool = True,
dtype: _LayerDtype | None = None,
dynamic: bool = False,
autocast: bool = True,
name: str | None = None,
) -> None: ...
@@ -189,10 +291,14 @@ class Embedding(Layer[tf.Tensor, tf.Tensor]):
embeddings_regularizer: _Regularizer = None,
embeddings_constraint: _Constraint = None,
mask_zero: bool = False,
lora_rank: int | None = None,
*,
input_length: int | None = None,
# **kwargs passed to Layer
activity_regularizer: _Regularizer = None,
trainable: bool = True,
dtype: _LayerDtype | None = None,
dynamic: bool = False,
autocast: bool = True,
name: str | None = None,
) -> None: ...
@@ -215,15 +321,26 @@ class Conv2D(Layer[tf.Tensor, tf.Tensor]):
activity_regularizer: _Regularizer = None,
kernel_constraint: _Constraint = None,
bias_constraint: _Constraint = None,
*,
# **kwargs passed to Layer
trainable: bool = True,
dtype: _LayerDtype | None = None,
dynamic: bool = False,
autocast: bool = True,
name: str | None = None,
) -> None: ...
Convolution2D = Conv2D
class Identity(Layer[tf.Tensor, tf.Tensor]):
def __init__(
self, trainable: bool = True, dtype: _LayerDtype = None, dynamic: bool = False, name: str | None = None
self,
*,
# **kwargs passed to Layer
activity_regularizer: _Regularizer = None,
trainable: bool = True,
dtype: _LayerDtype | None = None,
autocast: bool = True,
name: str | None = None,
) -> None: ...
class LayerNormalization(Layer[tf.Tensor, tf.Tensor]):
@@ -233,25 +350,19 @@ class LayerNormalization(Layer[tf.Tensor, tf.Tensor]):
epsilon: float = 0.001,
center: bool = True,
scale: bool = True,
rms_scaling: bool = False,
beta_initializer: _Initializer = "zeros",
gamma_initializer: _Initializer = "ones",
beta_regularizer: _Regularizer = None,
gamma_regularizer: _Regularizer = None,
beta_constraint: _Constraint = None,
gamma_constraint: _Constraint = None,
*,
# **kwargs passed to Layer
activity_regularizer: _Regularizer = None,
trainable: bool = True,
dtype: _LayerDtype | None = None,
dynamic: bool = False,
name: str | None = None,
) -> None: ...
class DenseFeatures(Layer[Mapping[str, TensorLike], tf.Tensor]):
def __init__(
self,
feature_columns: Sequence[DenseColumn | SequenceDenseColumn],
trainable: bool = True,
dtype: _LayerDtype = None,
dynamic: bool = False,
autocast: bool = True,
name: str | None = None,
) -> None: ...
@@ -265,16 +376,18 @@ class MultiHeadAttention(Layer[Any, tf.Tensor]):
use_bias: bool = True,
output_shape: tuple[int, ...] | None = None,
attention_axes: tuple[int, ...] | None = None,
kernel_initialize: _Initializer = "glorot_uniform",
kernel_initializer: _Initializer = "glorot_uniform",
bias_initializer: _Initializer = "zeros",
kernel_regularizer: Regularizer | None = None,
bias_regularizer: _Regularizer | None = None,
activity_regularizer: _Regularizer | None = None,
kernel_constraint: _Constraint | None = None,
bias_constraint: _Constraint | None = None,
*,
# **kwargs passed to Layer
trainable: bool = True,
dtype: _LayerDtype | None = None,
dynamic: bool = False,
autocast: bool = True,
name: str | None = None,
) -> None: ...
# @override
@@ -317,9 +430,12 @@ class GaussianDropout(Layer[tf.Tensor, tf.Tensor]):
self,
rate: float,
seed: int | None = None,
*,
# **kwargs passed to Layer
activity_regularizer: _Regularizer = None,
trainable: bool = True,
dtype: _LayerDtype = None,
dynamic: bool = False,
dtype: _LayerDtype | None = None,
autocast: bool = True,
name: str | None = None,
) -> None: ...

View File

@@ -1,27 +0,0 @@
import abc
from typing import overload
import tensorflow as tf
from tensorflow._aliases import AnyArray, DataSequence, Float, Integer, TensorCompatible, TensorLike
from tensorflow.keras.layers import Layer
class PreprocessingLayer(Layer[TensorLike, TensorLike], metaclass=abc.ABCMeta):
@property
def is_adapted(self) -> bool: ...
@overload # type: ignore
def __call__(self, inputs: tf.Tensor, *, training: bool = False, mask: TensorCompatible | None = None) -> tf.Tensor: ...
@overload
def __call__(
self, inputs: tf.SparseTensor, *, training: bool = False, mask: TensorCompatible | None = None
) -> tf.SparseTensor: ...
@overload
def __call__(
self, inputs: tf.RaggedTensor, *, training: bool = False, mask: TensorCompatible | None = None
) -> tf.RaggedTensor: ...
def adapt(
self,
data: tf.data.Dataset[TensorLike] | AnyArray | DataSequence,
batch_size: Integer | None = None,
steps: Float | None = None,
) -> None: ...
def compile(self, run_eagerly: bool | None = None, steps_per_execution: Integer | None = None) -> None: ...

View File

@@ -1,36 +0,0 @@
from typing import Literal
from tensorflow._aliases import TensorCompatible
from tensorflow.keras.layers.preprocessing.index_lookup import _IndexLookup
class StringLookup(_IndexLookup):
def __init__(
self,
max_tokens: int | None = None,
num_oov_indices: int = 1,
mask_token: str | None = None,
oov_token: str = "[UNK]",
vocabulary: str | None | TensorCompatible = None,
idf_weights: TensorCompatible | None = None,
encoding: str = "utf-8",
invert: bool = False,
output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int",
sparse: bool = False,
pad_to_max_tokens: bool = False,
) -> None: ...
class IntegerLookup(_IndexLookup):
def __init__(
self,
max_tokens: int | None = None,
num_oov_indices: int = 1,
mask_token: int | None = None,
oov_token: int = -1,
vocabulary: str | None | TensorCompatible = None,
vocabulary_dtype: Literal["int64", "int32"] = "int64",
idf_weights: TensorCompatible | None = None,
invert: bool = False,
output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int",
sparse: bool = False,
pad_to_max_tokens: bool = False,
) -> None: ...

View File

@@ -1,9 +0,0 @@
from _typeshed import Incomplete
import tensorflow as tf
from tensorflow.keras.layers.experimental.preprocessing import PreprocessingLayer
class _IndexLookup(PreprocessingLayer):
def compute_output_signature(self, input_spec) -> tf.TensorSpec: ...
def get_vocabulary(self, include_special_tokens: bool = True) -> list[Incomplete]: ...
def vocabulary_size(self) -> int: ...

View File

@@ -14,7 +14,9 @@ from tensorflow.keras.metrics import (
class Loss(ABC):
reduction: _ReductionValues
name: str | None
def __init__(self, reduction: _ReductionValues = "auto", name: str | None = None) -> None: ...
def __init__(
self, name: str | None = None, reduction: _ReductionValues = "sum_over_batch_size", dtype: Incomplete | None = None
) -> None: ...
@abstractmethod
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
@classmethod
@@ -30,7 +32,7 @@ class BinaryCrossentropy(Loss):
from_logits: bool = False,
label_smoothing: float = 0.0,
axis: int = -1,
reduction: _ReductionValues = ...,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "binary_crossentropy",
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
@@ -44,7 +46,7 @@ class BinaryFocalCrossentropy(Loss):
from_logits: bool = False,
label_smoothing: float = 0.0,
axis: int = -1,
reduction: _ReductionValues = ...,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "binary_focal_crossentropy",
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
@@ -55,53 +57,61 @@ class CategoricalCrossentropy(Loss):
from_logits: bool = False,
label_smoothing: float = 0.0,
axis: int = -1,
reduction: _ReductionValues = ...,
reduction: _ReductionValues = "sum_over_batch_size",
name: str | None = "categorical_crossentropy",
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class CategoricalHinge(Loss):
def __init__(self, reduction: _ReductionValues = ..., name: str | None = "categorical_hinge") -> None: ...
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "categorical_hinge") -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class CosineSimilarity(Loss):
def __init__(self, axis: int = -1, reduction: _ReductionValues = ..., name: str | None = "cosine_similarity") -> None: ...
def __init__(
self, axis: int = -1, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "cosine_similarity"
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class Hinge(Loss):
def __init__(self, reduction: _ReductionValues = ..., name: str | None = "hinge") -> None: ...
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "hinge") -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class Huber(Loss):
def __init__(self, delta: float = 1.0, reduction: _ReductionValues = ..., name: str | None = "huber_loss") -> None: ...
def __init__(
self, delta: float = 1.0, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "huber_loss"
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class KLDivergence(Loss):
def __init__(self, reduction: _ReductionValues = ..., name: str | None = "kl_divergence") -> None: ...
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "kl_divergence") -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class LogCosh(Loss):
def __init__(self, reduction: _ReductionValues = ..., name: str | None = "log_cosh") -> None: ...
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "log_cosh") -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class MeanAbsoluteError(Loss):
def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_absolute_error") -> None: ...
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_absolute_error") -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class MeanAbsolutePercentageError(Loss):
def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_absolute_percentage_error") -> None: ...
def __init__(
self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_absolute_percentage_error"
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class MeanSquaredError(Loss):
def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_squared_error") -> None: ...
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_squared_error") -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class MeanSquaredLogarithmicError(Loss):
def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_squared_logarithmic_error") -> None: ...
def __init__(
self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_squared_logarithmic_error"
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class Poisson(Loss):
def __init__(self, reduction: _ReductionValues = ..., name: str | None = "poisson") -> None: ...
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "poisson") -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class SparseCategoricalCrossentropy(Loss):
@@ -109,13 +119,13 @@ class SparseCategoricalCrossentropy(Loss):
self,
from_logits: bool = False,
ignore_class: int | None = None,
reduction: _ReductionValues = ...,
reduction: _ReductionValues = "sum_over_batch_size",
name: str = "sparse_categorical_crossentropy",
) -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class SquaredHinge(Loss):
def __init__(self, reduction: _ReductionValues = ..., name: str | None = "squared_hinge") -> None: ...
def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "squared_hinge") -> None: ...
def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
class Reduction:
@@ -133,10 +143,8 @@ _ReductionValues: TypeAlias = Literal["auto", "none", "sum", "sum_over_batch_siz
def categorical_hinge(y_true: TensorCompatible, y_pred: TensorCompatible) -> Tensor: ...
def huber(y_true: TensorCompatible, y_pred: TensorCompatible, delta: float = 1.0) -> Tensor: ...
def log_cosh(y_true: TensorCompatible, y_pred: TensorCompatible) -> Tensor: ...
def deserialize(
name: str | dict[str, Any], custom_objects: dict[str, Any] | None = None, use_legacy_format: bool = False
) -> Loss: ...
def serialize(loss: KerasSerializable, use_legacy_format: bool = False) -> dict[str, Any]: ...
def deserialize(name: str | dict[str, Any], custom_objects: dict[str, Any] | None = None) -> Loss: ...
def serialize(loss: KerasSerializable) -> dict[str, Any]: ...
_FuncT = TypeVar("_FuncT", bound=Callable[..., Any])

View File

@@ -12,9 +12,8 @@ from tensorflow.keras.initializers import _Initializer
_Output: TypeAlias = Tensor | dict[str, Tensor]
class Metric(tf.keras.layers.Layer[tf.Tensor, tf.Tensor], metaclass=ABCMeta):
def __init__(self, name: str | None = None, dtype: DTypeLike | None = None) -> None: ...
def __init__(self, dtype: DTypeLike | None = None, name: str | None = None) -> None: ...
def __new__(cls, *args: Any, **kwargs: Any) -> Self: ...
def merge_state(self, metrics: Iterable[Self]) -> list[Operation]: ...
def reset_state(self) -> None: ...
@abstractmethod
def update_state(
@@ -110,7 +109,7 @@ class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
self, k: int = 5, name: str | None = "sparse_top_k_categorical_accuracy", dtype: DTypeLike | None = None
) -> None: ...
def serialize(metric: KerasSerializable, use_legacy_format: bool = False) -> dict[str, Any]: ...
def serialize(metric: KerasSerializable) -> dict[str, Any]: ...
def binary_crossentropy(
y_true: TensorCompatible, y_pred: TensorCompatible, from_logits: bool = False, label_smoothing: float = 0.0, axis: int = -1
) -> Tensor: ...

View File

@@ -2,11 +2,10 @@ from _typeshed import Incomplete
from collections.abc import Callable, Container, Iterator
from pathlib import Path
from typing import Any, Literal
from typing_extensions import Self, TypeAlias
from typing_extensions import Self, TypeAlias, deprecated
import numpy as np
import numpy.typing as npt
import tensorflow
import tensorflow as tf
from tensorflow import Variable
from tensorflow._aliases import ContainerGeneric, ShapeLike, TensorCompatible
@@ -16,18 +15,22 @@ from tensorflow.keras.optimizers import Optimizer
_Loss: TypeAlias = str | tf.keras.losses.Loss | Callable[[TensorCompatible, TensorCompatible], tf.Tensor]
_Metric: TypeAlias = str | tf.keras.metrics.Metric | Callable[[TensorCompatible, TensorCompatible], tf.Tensor] | None
class Model(Layer[_InputT, _OutputT], tf.Module):
# Missing keras.src.backend.tensorflow.trainer.TensorFlowTrainer as a base class, which is not exposed by tensorflow
class Model(Layer[_InputT, _OutputT]):
_train_counter: tf.Variable
_test_counter: tf.Variable
optimizer: Optimizer | None
loss: tf.keras.losses.Loss | dict[str, tf.keras.losses.Loss]
# This is actually TensorFlowTrainer.loss
@deprecated("Instead, use `model.compute_loss(x, y, y_pred, sample_weight)`.")
def loss(
self, y: TensorCompatible | None, y_pred: TensorCompatible | None, sample_weight: Incomplete | None = None
) -> tf.Tensor | None: ...
stop_training: bool
def __new__(cls, *args: Any, **kwargs: Any) -> Model[_InputT, _OutputT]: ...
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def __setattr__(self, name: str, value: Any) -> None: ...
def __reduce__(self): ...
def __deepcopy__(self, memo): ...
def build(self, input_shape: ShapeLike) -> None: ...
def __call__(self, inputs: _InputT, *, training: bool = False, mask: TensorCompatible | None = None) -> _OutputT: ...
def call(self, inputs: _InputT, training: bool | None = None, mask: TensorCompatible | None = None) -> _OutputT: ...
@@ -36,14 +39,13 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
self,
optimizer: Optimizer | str = "rmsprop",
loss: ContainerGeneric[_Loss] | None = None,
metrics: ContainerGeneric[_Metric] | None = None,
loss_weights: ContainerGeneric[float] | None = None,
metrics: ContainerGeneric[_Metric] | None = None,
weighted_metrics: ContainerGeneric[_Metric] | None = None,
run_eagerly: bool | None = None,
steps_per_execution: int | Literal["auto"] | None = None,
jit_compile: bool | None = None,
pss_evaluation_shards: int | Literal["auto"] = 0,
**kwargs: Any,
run_eagerly: bool = False,
steps_per_execution: int | Literal["auto"] = 1,
jit_compile: bool | Literal["auto"] = "auto",
auto_scale_loss: bool | None = True,
) -> None: ...
@property
def metrics(self) -> list[Incomplete]: ...
@@ -54,10 +56,6 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
@property
def run_eagerly(self) -> bool: ...
@property
def autotune_steps_per_execution(self) -> bool: ...
@property
def steps_per_execution(self) -> int | None: ... # Returns None for a non-compiled model.
@property
def jit_compile(self) -> bool: ...
@property
def distribute_reduction_method(self) -> Incomplete | Literal["auto"]: ...
@@ -70,7 +68,7 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
sample_weight: Incomplete | None = None,
) -> tf.Tensor | None: ...
def compute_metrics(
self, x: TensorCompatible, y: TensorCompatible, y_pred: TensorCompatible, sample_weight
self, x: TensorCompatible, y: TensorCompatible, y_pred: TensorCompatible, sample_weight: Incomplete | None = None
) -> dict[str, float]: ...
def get_metrics_result(self) -> dict[str, float]: ...
def make_train_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], dict[str, float]]: ...
@@ -92,9 +90,6 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
validation_steps: int | None = None,
validation_batch_size: int | None = None,
validation_freq: int | Container[int] = 1,
max_queue_size: int = 10,
workers: int = 1,
use_multiprocessing: bool = False,
) -> tf.keras.callbacks.History: ...
def test_step(self, data: TensorCompatible) -> dict[str, float]: ...
def make_test_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], dict[str, float]]: ...
@@ -107,9 +102,6 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
sample_weight: npt.NDArray[np.float_] | None = None,
steps: int | None = None,
callbacks: list[tf.keras.callbacks.Callback] | None = None,
max_queue_size: int = 10,
workers: int = 1,
use_multiprocessing: bool = False,
return_dict: bool = False,
**kwargs: Any,
) -> float | list[float]: ...
@@ -122,9 +114,6 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
verbose: Literal["auto", 0, 1, 2] = "auto",
steps: int | None = None,
callbacks: list[tf.keras.callbacks.Callback] | None = None,
max_queue_size: int = 10,
workers: int = 1,
use_multiprocessing: bool = False,
) -> _OutputT: ...
def reset_metrics(self) -> None: ...
def train_on_batch(
@@ -133,7 +122,6 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
y: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete] | None = None,
sample_weight: npt.NDArray[np.float_] | None = None,
class_weight: dict[int, float] | None = None,
reset_metrics: bool = True,
return_dict: bool = False,
) -> float | list[float]: ...
def test_on_batch(
@@ -141,77 +129,22 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
x: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete],
y: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete] | None = None,
sample_weight: npt.NDArray[np.float_] | None = None,
reset_metrics: bool = True,
return_dict: bool = False,
) -> float | list[float]: ...
def predict_on_batch(self, x: Iterator[_InputT]) -> npt.NDArray[Incomplete]: ...
def fit_generator(
self,
generator: Iterator[Incomplete],
steps_per_epoch: int | None = None,
epochs: int = 1,
verbose: Literal["auto", 0, 1, 2] = 1,
callbacks: list[tf.keras.callbacks.Callback] | None = None,
validation_data: TensorCompatible | tf.data.Dataset[Any] | None = None,
validation_steps: int | None = None,
validation_freq: int | Container[int] = 1,
class_weight: dict[int, float] | None = None,
max_queue_size: int = 10,
workers: int = 1,
use_multiprocessing: bool = False,
shuffle: bool = True,
initial_epoch: int = 0,
) -> tf.keras.callbacks.History: ...
def evaluate_generator(
self,
generator: Iterator[Incomplete],
steps: int | None = None,
callbacks: list[tf.keras.callbacks.Callback] | None = None,
max_queue_size: int = 10,
workers: int = 1,
use_multiprocessing: bool = False,
verbose: Literal["auto", 0, 1, 2] = 0,
) -> float | list[float]: ...
def predict_generator(
self,
generator: Iterator[Incomplete],
steps: int | None = None,
callbacks: list[tf.keras.callbacks.Callback] | None = None,
max_queue_size: int = 10,
workers: int = 1,
use_multiprocessing: bool = False,
verbose: Literal["auto", 0, 1, 2] = 0,
) -> _OutputT: ...
@property
def trainable_weights(self) -> list[Variable]: ...
@property
def non_trainable_weights(self) -> list[Variable]: ...
def get_weights(self): ...
def save(
self, filepath: str | Path, overwrite: bool = True, save_format: Literal["keras", "tf", "h5"] | None = None, **kwargs: Any
) -> None: ...
def save_weights(
self,
filepath: str | Path,
overwrite: bool = True,
save_format: Literal["tf", "h5"] | None = None,
options: tf.train.CheckpointOptions | None = None,
) -> None: ...
def load_weights(
self,
filepath: str | Path,
skip_mismatch: bool = False,
by_name: bool = False,
options: None | tensorflow.train.CheckpointOptions = None,
) -> None: ...
def save(self, filepath: str | Path, overwrite: bool = True) -> None: ...
def save_weights(self, filepath: str | Path, overwrite: bool = True) -> None: ...
# kwargs are from keras.saving.saving_api.load_weights
def load_weights(self, filepath: str | Path, skip_mismatch: bool = False, *, by_name: bool = False) -> None: ...
def get_config(self) -> dict[str, Any]: ...
@classmethod
def from_config(cls, config: dict[str, Any], custom_objects: Incomplete | None = None) -> Self: ...
def to_json(self, **kwargs: Any) -> str: ...
def to_yaml(self, **kwargs: Any) -> str: ...
def reset_states(self) -> None: ...
@property
def state_updates(self) -> list[Incomplete]: ...
@property
def weights(self) -> list[Variable]: ...
def summary(
@@ -226,10 +159,8 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
@property
def layers(self) -> list[Layer[Incomplete, Incomplete]]: ...
def get_layer(self, name: str | None = None, index: int | None = None) -> Layer[Incomplete, Incomplete]: ...
def get_weight_paths(self) -> dict[str, tf.Variable]: ...
def get_compile_config(self) -> dict[str, Any]: ...
def compile_from_config(self, config: dict[str, Any]) -> Self: ...
def export(self, filepath: str | Path) -> None: ...
def save_spec(self, dynamic_batch: bool = True) -> tuple[tuple[tf.TensorSpec, ...], dict[str, tf.TensorSpec]] | None: ...
def export(self, filepath: str | Path, format: str = "tf_saved_model") -> None: ...
def __getattr__(name: str) -> Incomplete: ...

View File

@@ -1,17 +1,13 @@
from _typeshed import Incomplete
from abc import abstractmethod
from collections.abc import Callable, Iterable
from typing import Any
from typing_extensions import Self, TypeAlias
from typing_extensions import TypeAlias
import tensorflow as tf
from tensorflow._aliases import Gradients
from tensorflow.keras.optimizers import schedules as schedules
from tensorflow.python.trackable.base import Trackable
_Initializer: TypeAlias = str | Callable[[], tf.Tensor] | dict[str, Any]
_Shape: TypeAlias = tf.TensorShape | Iterable[int | None]
_Dtype: TypeAlias = tf.DType | str | None
_LearningRate: TypeAlias = float | tf.Tensor | schedules.LearningRateSchedule | Callable[[], float | tf.Tensor]
_GradientAggregator: TypeAlias = Callable[[list[tuple[Gradients, tf.Variable]]], list[tuple[Gradients, tf.Variable]]] | None
_GradientTransformer: TypeAlias = (
@@ -33,49 +29,6 @@ class Optimizer(Trackable):
gradient_transformers: _GradientTransformer = None,
**kwargs: Any,
) -> None: ...
def _create_all_weights(self, var_list: Iterable[tf.Variable]) -> None: ...
@property
def iterations(self) -> tf.Variable: ...
@iterations.setter
def iterations(self, variable: tf.Variable) -> None: ...
def add_slot(
self, var: tf.Variable, slot_name: str, initializer: _Initializer = "zeros", shape: tf.TensorShape | None = None
) -> tf.Variable: ...
def add_weight(
self,
name: str,
shape: _Shape,
dtype: _Dtype = None,
initializer: _Initializer = "zeros",
trainable: None | bool = None,
synchronization: tf.VariableSynchronization = ...,
aggregation: tf.VariableAggregation = ...,
) -> tf.Variable: ...
def apply_gradients(
self,
grads_and_vars: Iterable[tuple[Gradients, tf.Variable]],
name: str | None = None,
experimental_aggregate_gradients: bool = True,
) -> tf.Operation | None: ...
@classmethod
def from_config(cls, config: dict[str, Any], custom_objects: dict[str, type] | None = None) -> Self: ...
# Missing ABC is intentional as class is not abstract at runtime.
@abstractmethod
def get_config(self) -> dict[str, Any]: ...
def get_slot(self, var: tf.Variable, slot_name: str) -> tf.Variable: ...
def get_slot_names(self) -> list[str]: ...
def get_gradients(self, loss: tf.Tensor, params: list[tf.Variable]) -> list[Gradients]: ...
def minimize(
self,
loss: tf.Tensor | Callable[[], tf.Tensor],
var_list: list[tf.Variable] | tuple[tf.Variable, ...] | Callable[[], list[tf.Variable] | tuple[tf.Variable, ...]],
grad_loss: tf.Tensor | None = None,
name: str | None = None,
tape: tf.GradientTape | None = None,
) -> tf.Operation: ...
def variables(self) -> list[tf.Variable]: ...
@property
def weights(self) -> list[tf.Variable]: ...
class Adam(Optimizer):
def __init__(
@@ -88,7 +41,6 @@ class Adam(Optimizer):
name: str = "Adam",
**kwargs: Any,
) -> None: ...
def get_config(self) -> dict[str, Any]: ...
class Adagrad(Optimizer):
_initial_accumulator_value: float
@@ -101,12 +53,10 @@ class Adagrad(Optimizer):
name: str = "Adagrad",
**kwargs: Any,
) -> None: ...
def get_config(self) -> dict[str, Any]: ...
class SGD(Optimizer):
def __init__(
self, learning_rate: _LearningRate = 0.01, momentum: float = 0.0, nesterov: bool = False, name: str = "SGD", **kwargs: Any
) -> None: ...
def get_config(self) -> dict[str, Any]: ...
def __getattr__(name: str) -> Incomplete: ...

View File

@@ -19,7 +19,7 @@ class PiecewiseConstantDecay(LearningRateSchedule):
self,
boundaries: Sequence[tf.Tensor] | Sequence[float],
values: Sequence[float] | Sequence[tf.Tensor],
name: str | None = None,
name: str = "PiecewiseConstant",
) -> None: ...
def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
def get_config(self) -> dict[str, Any]: ...
@@ -33,7 +33,7 @@ class InverseTimeDecay(LearningRateSchedule):
decay_steps: int,
decay_rate: float,
staircase: bool = False,
name: str | None = None,
name: str = "InverseTimeDecay",
) -> None: ...
def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
def get_config(self) -> dict[str, Any]: ...
@@ -48,7 +48,7 @@ class PolynomialDecay(LearningRateSchedule):
end_learning_rate: float | tf.Tensor = 0.0001,
power: float = 1.0,
cycle: bool = False,
name: str | None = None,
name: str = "PolynomialDecay",
) -> None: ...
def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
def get_config(self) -> dict[str, Any]: ...
@@ -61,7 +61,7 @@ class CosineDecay(LearningRateSchedule):
initial_learning_rate: float | tf.Tensor,
decay_steps: int,
alpha: float | tf.Tensor = 0.0,
name: str | None = None,
name: str = "CosineDecay",
warmup_target: int | tf.Tensor | None = None, # float32 or float64 Tensor
warmup_steps: int | tf.Tensor = 0, # int32 or int64 Tensor
) -> None: ...
@@ -78,7 +78,7 @@ class CosineDecayRestarts(LearningRateSchedule):
t_mul: float | tf.Tensor = 2.0,
m_mul: float | tf.Tensor = 1.0,
alpha: float | tf.Tensor = 0.0,
name: str | None = None,
name: str = "SGDRDecay",
) -> None: ...
def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
def get_config(self) -> dict[str, Any]: ...
@@ -92,14 +92,12 @@ class ExponentialDecay(LearningRateSchedule):
decay_steps: int | tf.Tensor,
decay_rate: float | tf.Tensor,
staircase: bool = False,
name: str | None = None,
name: str = "ExponentialDecay",
) -> None: ...
def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
def get_config(self) -> dict[str, Any]: ...
@classmethod
def from_config(cls, config: dict[str, Any]) -> Self: ...
def deserialize(
config: dict[str, Any], custom_objects: dict[str, type] | None = None, use_legacy_format: bool = False
) -> LearningRateSchedule: ...
def serialize(learning_rate_schedule: LearningRateSchedule, use_legacy_format: bool = False) -> dict[str, Any]: ...
def deserialize(config: dict[str, Any], custom_objects: dict[str, type] | None = None) -> LearningRateSchedule: ...
def serialize(learning_rate_schedule: LearningRateSchedule) -> dict[str, Any]: ...

View File

@@ -1,3 +1,4 @@
from _typeshed import Incomplete
from collections.abc import Callable
from typing import Any, overload
from typing_extensions import Self, TypeAlias
@@ -18,4 +19,4 @@ def get(identifier: None) -> None: ...
def get(identifier: str | dict[str, Any] | Regularizer) -> Regularizer: ...
@overload
def get(identifier: Callable[[Tensor], Tensor]) -> Callable[[Tensor], Tensor]: ...
def __getattr__(name: str) -> Any: ...
def __getattr__(name: str) -> Incomplete: ...

View File

@@ -19,6 +19,8 @@ def matmul(
a_is_sparse: _bool = False,
b_is_sparse: _bool = False,
output_type: DTypeLike | None = None,
grad_a: _bool = False,
grad_b: _bool = False,
name: str | None = None,
) -> Tensor: ...
@overload
@@ -32,6 +34,8 @@ def matmul(
a_is_sparse: _bool = False,
b_is_sparse: _bool = False,
output_type: DTypeLike | None = None,
grad_a: _bool = False,
grad_b: _bool = False,
name: str | None = None,
) -> RaggedTensor: ...
def set_diag(

View File

@@ -1,3 +1,4 @@
from _typeshed import Incomplete
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any, Generic, Literal, TypeVar
@@ -6,7 +7,7 @@ from typing_extensions import ParamSpec, TypeAlias
import tensorflow as tf
from tensorflow.python.training.tracking.autotrackable import AutoTrackable
from tensorflow.saved_model.experimental import VariablePolicy
from tensorflow.types.experimental import ConcreteFunction, GenericFunction
from tensorflow.types.experimental import ConcreteFunction, PolymorphicFunction
_P = ParamSpec("_P")
_R = TypeVar("_R", covariant=True)
@@ -39,15 +40,17 @@ class SaveOptions:
"namespace_whitelist",
"save_debug_info",
"function_aliases",
"experimental_debug_stripper",
"experimental_io_device",
"experimental_variable_policy",
"experimental_custom_gradients",
"experimental_image_format",
"experimental_skip_saver",
"experimental_sharding_callback",
)
namespace_whitelist: list[str]
save_debug_info: bool
function_aliases: dict[str, tf.types.experimental.GenericFunction[..., object]]
function_aliases: dict[str, PolymorphicFunction[..., object]]
experimental_io_device: str
experimental_variable_policy: VariablePolicy
experimental_custom_gradients: bool
@@ -57,12 +60,14 @@ class SaveOptions:
self,
namespace_whitelist: list[str] | None = None,
save_debug_info: bool = False,
function_aliases: Mapping[str, tf.types.experimental.GenericFunction[..., object]] | None = None,
function_aliases: Mapping[str, PolymorphicFunction[..., object]] | None = None,
experimental_debug_stripper: bool = False,
experimental_io_device: str | None = None,
experimental_variable_policy: str | VariablePolicy | None = None,
experimental_custom_gradients: bool = True,
experimental_image_format: bool = False,
experimental_skip_saver: bool = False,
experimental_sharding_callback: Incomplete | None = None,
) -> None: ...
def contains_saved_model(export_dir: str | Path) -> bool: ...
@@ -80,7 +85,7 @@ def load(
export_dir: str, tags: str | Sequence[str] | None = None, options: LoadOptions | None = None
) -> _LoadedModel[..., Any]: ...
_TF_Function: TypeAlias = ConcreteFunction[..., object] | GenericFunction[..., object]
_TF_Function: TypeAlias = ConcreteFunction[..., object] | PolymorphicFunction[..., object]
def save(
obj: tf.Module,

View File

@@ -7,6 +7,7 @@ from typing_extensions import Self
import tensorflow as tf
from tensorflow._aliases import FloatArray, IntArray
from tensorflow.core.framework.graph_pb2 import GraphDef
from tensorflow.experimental.dtensor import Mesh
class SummaryWriter(metaclass=abc.ABCMeta):
@@ -36,7 +37,7 @@ def create_file_writer(
) -> SummaryWriter: ...
def create_noop_writer() -> SummaryWriter: ...
def flush(writer: SummaryWriter | None = None, name: str | None = None) -> tf.Operation: ...
def graph(graph_data: tf.Graph | tf.compat.v1.GraphDef) -> bool: ...
def graph(graph_data: tf.Graph | GraphDef) -> bool: ...
def histogram(
name: str, data: tf.Tensor, step: int | None = None, buckets: int | None = None, description: str | None = None
) -> bool: ...
@@ -54,7 +55,7 @@ def should_record_summaries() -> bool: ...
def text(name: str, data: str | tf.Tensor, step: int | tf.Tensor | None = None, description: str | None = None) -> bool: ...
def trace_export(name: str, step: int | tf.Tensor | None = None, profiler_outdir: str | None = None) -> None: ...
def trace_off() -> None: ...
def trace_on(graph: bool = True, profiler: bool = False) -> None: ...
def trace_on(graph: bool = True, profiler: bool = False, profiler_outdir: str | None = None) -> None: ...
def write(
tag: str, tensor: tf.Tensor, step: int | tf.Tensor | None = None, metadata: Incomplete | None = None, name: str | None = None
) -> bool: ...

View File

@@ -30,6 +30,8 @@ class CheckpointOptions:
experimental_enable_async_checkpoint: bool = False,
experimental_write_callbacks: None | list[Callable[[str], object] | Callable[[], object]] = None,
enable_async: bool = False,
experimental_skip_slot_variables: bool = False,
experimental_sharding_callback: Incomplete | None = None,
) -> None: ...
_T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str])

View File

@@ -1,253 +0,0 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
This file defines protos that store the results of autotuning various
operations.
They are in proto format because we want to log them structured. They offer
tremendous statistical, testing, and debugging value.
"""
import builtins
import collections.abc
import sys
import typing
import google.protobuf.any_pb2
import google.protobuf.descriptor
import google.protobuf.duration_pb2
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import tensorflow.tsl.protobuf.dnn_pb2
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class CudnnVersion(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
MAJOR_FIELD_NUMBER: builtins.int
MINOR_FIELD_NUMBER: builtins.int
PATCH_FIELD_NUMBER: builtins.int
major: builtins.int
minor: builtins.int
patch: builtins.int
def __init__(
self,
*,
major: builtins.int | None = ...,
minor: builtins.int | None = ...,
patch: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["major", b"major", "minor", b"minor", "patch", b"patch"]) -> None: ...
global___CudnnVersion = CudnnVersion
@typing.final
class ComputeCapability(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
MAJOR_FIELD_NUMBER: builtins.int
MINOR_FIELD_NUMBER: builtins.int
major: builtins.int
minor: builtins.int
def __init__(
self,
*,
major: builtins.int | None = ...,
minor: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["major", b"major", "minor", b"minor"]) -> None: ...
global___ComputeCapability = ComputeCapability
@typing.final
class AutotuneResult(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _FailureKind:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _FailureKindEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[AutotuneResult._FailureKind.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
UNKNOWN: AutotuneResult._FailureKind.ValueType # 0
REDZONE_MODIFIED: AutotuneResult._FailureKind.ValueType # 1
"""Algorithm wrote memory outside its output buffers."""
WRONG_RESULT: AutotuneResult._FailureKind.ValueType # 2
"""Algorithm gave a different result from a reference algorithm."""
DISQUALIFIED: AutotuneResult._FailureKind.ValueType # 3
"""Algorithm was rejected for failing to run or for known bugs."""
class FailureKind(_FailureKind, metaclass=_FailureKindEnumTypeWrapper): ...
UNKNOWN: AutotuneResult.FailureKind.ValueType # 0
REDZONE_MODIFIED: AutotuneResult.FailureKind.ValueType # 1
"""Algorithm wrote memory outside its output buffers."""
WRONG_RESULT: AutotuneResult.FailureKind.ValueType # 2
"""Algorithm gave a different result from a reference algorithm."""
DISQUALIFIED: AutotuneResult.FailureKind.ValueType # 3
"""Algorithm was rejected for failing to run or for known bugs."""
@typing.final
class FailureResult(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KIND_FIELD_NUMBER: builtins.int
MSG_FIELD_NUMBER: builtins.int
REFERENCE_CONV_FIELD_NUMBER: builtins.int
REFERENCE_GEMM_FIELD_NUMBER: builtins.int
REFERENCE_CUDA_CONV_PLAN_FIELD_NUMBER: builtins.int
REFERENCE_ALGORITHM_FIELD_NUMBER: builtins.int
BUFFER_ADDRESS_FIELD_NUMBER: builtins.int
kind: global___AutotuneResult.FailureKind.ValueType
msg: builtins.str
buffer_address: builtins.int
@property
def reference_conv(self) -> global___AutotuneResult.ConvKey: ...
@property
def reference_gemm(self) -> global___AutotuneResult.GemmKey: ...
@property
def reference_cuda_conv_plan(self) -> global___AutotuneResult.CudaConvPlanKey: ...
@property
def reference_algorithm(self) -> tensorflow.tsl.protobuf.dnn_pb2.AlgorithmProto: ...
def __init__(
self,
*,
kind: global___AutotuneResult.FailureKind.ValueType | None = ...,
msg: builtins.str | None = ...,
reference_conv: global___AutotuneResult.ConvKey | None = ...,
reference_gemm: global___AutotuneResult.GemmKey | None = ...,
reference_cuda_conv_plan: global___AutotuneResult.CudaConvPlanKey | None = ...,
reference_algorithm: tensorflow.tsl.protobuf.dnn_pb2.AlgorithmProto | None = ...,
buffer_address: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["key", b"key", "reference_algorithm", b"reference_algorithm", "reference_conv", b"reference_conv", "reference_cuda_conv_plan", b"reference_cuda_conv_plan", "reference_gemm", b"reference_gemm"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["buffer_address", b"buffer_address", "key", b"key", "kind", b"kind", "msg", b"msg", "reference_algorithm", b"reference_algorithm", "reference_conv", b"reference_conv", "reference_cuda_conv_plan", b"reference_cuda_conv_plan", "reference_gemm", b"reference_gemm"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["key", b"key"]) -> typing.Literal["reference_conv", "reference_gemm", "reference_cuda_conv_plan", "reference_algorithm"] | None: ...
@typing.final
class ConvKey(google.protobuf.message.Message):
"""Legacy and unused in new data; superseded by AlgorithmProto."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ALGORITHM_FIELD_NUMBER: builtins.int
TENSOR_OPS_ENABLED_FIELD_NUMBER: builtins.int
algorithm: builtins.int
tensor_ops_enabled: builtins.bool
def __init__(
self,
*,
algorithm: builtins.int | None = ...,
tensor_ops_enabled: builtins.bool | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["algorithm", b"algorithm", "tensor_ops_enabled", b"tensor_ops_enabled"]) -> None: ...
@typing.final
class GemmKey(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ALGORITHM_FIELD_NUMBER: builtins.int
algorithm: builtins.int
def __init__(
self,
*,
algorithm: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["algorithm", b"algorithm"]) -> None: ...
@typing.final
class CudaConvPlanKey(google.protobuf.message.Message):
"""Legacy and unused in new data; superseded by AlgorithmProto."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
EXEC_PLAN_ID_FIELD_NUMBER: builtins.int
exec_plan_id: builtins.str
def __init__(
self,
*,
exec_plan_id: builtins.str | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["exec_plan_id", b"exec_plan_id"]) -> None: ...
SCRATCH_BYTES_FIELD_NUMBER: builtins.int
RUN_TIME_FIELD_NUMBER: builtins.int
FAILURE_FIELD_NUMBER: builtins.int
CONV_FIELD_NUMBER: builtins.int
GEMM_FIELD_NUMBER: builtins.int
CUDA_CONV_PLAN_FIELD_NUMBER: builtins.int
ALGORITHM_FIELD_NUMBER: builtins.int
scratch_bytes: builtins.int
@property
def run_time(self) -> google.protobuf.duration_pb2.Duration: ...
@property
def failure(self) -> global___AutotuneResult.FailureResult: ...
@property
def conv(self) -> global___AutotuneResult.ConvKey: ...
@property
def gemm(self) -> global___AutotuneResult.GemmKey: ...
@property
def cuda_conv_plan(self) -> global___AutotuneResult.CudaConvPlanKey: ...
@property
def algorithm(self) -> tensorflow.tsl.protobuf.dnn_pb2.AlgorithmProto: ...
def __init__(
self,
*,
scratch_bytes: builtins.int | None = ...,
run_time: google.protobuf.duration_pb2.Duration | None = ...,
failure: global___AutotuneResult.FailureResult | None = ...,
conv: global___AutotuneResult.ConvKey | None = ...,
gemm: global___AutotuneResult.GemmKey | None = ...,
cuda_conv_plan: global___AutotuneResult.CudaConvPlanKey | None = ...,
algorithm: tensorflow.tsl.protobuf.dnn_pb2.AlgorithmProto | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["algorithm", b"algorithm", "conv", b"conv", "cuda_conv_plan", b"cuda_conv_plan", "failure", b"failure", "gemm", b"gemm", "key", b"key", "run_time", b"run_time"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["algorithm", b"algorithm", "conv", b"conv", "cuda_conv_plan", b"cuda_conv_plan", "failure", b"failure", "gemm", b"gemm", "key", b"key", "run_time", b"run_time", "scratch_bytes", b"scratch_bytes"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["key", b"key"]) -> typing.Literal["conv", "gemm", "cuda_conv_plan", "algorithm"] | None: ...
global___AutotuneResult = AutotuneResult
@typing.final
class AutotuningLog(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
INSTR_FIELD_NUMBER: builtins.int
RESULTS_FIELD_NUMBER: builtins.int
CUDNN_VERSION_FIELD_NUMBER: builtins.int
COMPUTE_CAPABILITY_FIELD_NUMBER: builtins.int
DEVICE_PCI_BUS_ID_FIELD_NUMBER: builtins.int
BLAS_VERSION_FIELD_NUMBER: builtins.int
device_pci_bus_id: builtins.str
"""stream_executor::DeviceDescription::pci_bus_id."""
blas_version: builtins.str
@property
def instr(self) -> google.protobuf.any_pb2.Any: ...
@property
def results(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___AutotuneResult]:
"""Records all auto-tuning results per algorithm."""
@property
def cudnn_version(self) -> global___CudnnVersion: ...
@property
def compute_capability(self) -> global___ComputeCapability: ...
def __init__(
self,
*,
instr: google.protobuf.any_pb2.Any | None = ...,
results: collections.abc.Iterable[global___AutotuneResult] | None = ...,
cudnn_version: global___CudnnVersion | None = ...,
compute_capability: global___ComputeCapability | None = ...,
device_pci_bus_id: builtins.str | None = ...,
blas_version: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["compute_capability", b"compute_capability", "cudnn_version", b"cudnn_version", "instr", b"instr"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["blas_version", b"blas_version", "compute_capability", b"compute_capability", "cudnn_version", b"cudnn_version", "device_pci_bus_id", b"device_pci_bus_id", "instr", b"instr", "results", b"results"]) -> None: ...
global___AutotuningLog = AutotuningLog

View File

@@ -52,6 +52,8 @@ class CoordinationServiceConfig(google.protobuf.message.Message):
SHUTDOWN_BARRIER_TIMEOUT_IN_MS_FIELD_NUMBER: builtins.int
AGENT_DESTRUCTION_WITHOUT_SHUTDOWN_FIELD_NUMBER: builtins.int
RECOVERABLE_JOBS_FIELD_NUMBER: builtins.int
ALLOW_NEW_INCARNATION_TO_RECONNECT_FIELD_NUMBER: builtins.int
FORCE_DISABLE_FIELD_NUMBER: builtins.int
service_type: builtins.str
"""Type of coordination service implementation to enable.
For example, setting the service type as "standalone" starts a service
@@ -82,6 +84,17 @@ class CoordinationServiceConfig(google.protobuf.message.Message):
find out about the disconnecte agent via stale heartbeats. Used for
testing.
"""
allow_new_incarnation_to_reconnect: builtins.bool
"""If a task restarts with a new incarnation, we may allow it to reconnect
silently. This is useful when we know that a task can immediately resume
work upon re-connecting to the service.
"""
force_disable: builtins.bool
"""Disables coordination service.
Some libraries enable coordination service by default even if the user did
not specify any config. This field allows users to explicitly disable
coordination service under all situations.
"""
@property
def coordinated_job_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___CoordinatedJob]: ...
@property
@@ -104,7 +117,9 @@ class CoordinationServiceConfig(google.protobuf.message.Message):
shutdown_barrier_timeout_in_ms: builtins.int | None = ...,
agent_destruction_without_shutdown: builtins.bool | None = ...,
recoverable_jobs: collections.abc.Iterable[builtins.str] | None = ...,
allow_new_incarnation_to_reconnect: builtins.bool | None = ...,
force_disable: builtins.bool | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["agent_destruction_without_shutdown", b"agent_destruction_without_shutdown", "cluster_register_timeout_in_ms", b"cluster_register_timeout_in_ms", "coordinated_job_list", b"coordinated_job_list", "enable_health_check", b"enable_health_check", "heartbeat_timeout_in_ms", b"heartbeat_timeout_in_ms", "recoverable_jobs", b"recoverable_jobs", "service_leader", b"service_leader", "service_type", b"service_type", "shutdown_barrier_timeout_in_ms", b"shutdown_barrier_timeout_in_ms"]) -> None: ...
def ClearField(self, field_name: typing.Literal["agent_destruction_without_shutdown", b"agent_destruction_without_shutdown", "allow_new_incarnation_to_reconnect", b"allow_new_incarnation_to_reconnect", "cluster_register_timeout_in_ms", b"cluster_register_timeout_in_ms", "coordinated_job_list", b"coordinated_job_list", "enable_health_check", b"enable_health_check", "force_disable", b"force_disable", "heartbeat_timeout_in_ms", b"heartbeat_timeout_in_ms", "recoverable_jobs", b"recoverable_jobs", "service_leader", b"service_leader", "service_type", b"service_type", "shutdown_barrier_timeout_in_ms", b"shutdown_barrier_timeout_in_ms"]) -> None: ...
global___CoordinationServiceConfig = CoordinationServiceConfig

View File

@@ -37,6 +37,9 @@ class _DataTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumT
kBF16: _DataType.ValueType # 7
kF8E5M2: _DataType.ValueType # 8
kF8E4M3FN: _DataType.ValueType # 9
kF8E5M2FNUZ: _DataType.ValueType # 10
kF8E4M3FNUZ: _DataType.ValueType # 11
kInt64: _DataType.ValueType # 12
class DataType(_DataType, metaclass=_DataTypeEnumTypeWrapper):
"""Specifies the data type used by an operation."""
@@ -51,6 +54,9 @@ kComplexDouble: DataType.ValueType # 6
kBF16: DataType.ValueType # 7
kF8E5M2: DataType.ValueType # 8
kF8E4M3FN: DataType.ValueType # 9
kF8E5M2FNUZ: DataType.ValueType # 10
kF8E4M3FNUZ: DataType.ValueType # 11
kInt64: DataType.ValueType # 12
global___DataType = DataType
class _DataLayout:
@@ -132,6 +138,10 @@ class _FilterLayoutEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._E
"""cuDNN's NCHW_VECT_C layout with 4-elem vectors"""
kOutputInputYX32: _FilterLayout.ValueType # 5
"""cuDNN's NCHW_VECT_C layout with 32-elem vectors"""
kOutputInputYX32_CudnnReordered: _FilterLayout.ValueType # 6
"""cuDNN-specific filter reordering (using `cudnnReorderFilterAndBias`)
When the filter is reordered, so is the bias (if present).
"""
kInputYXOutput: _FilterLayout.ValueType # 3
kYXInputOutput: _FilterLayout.ValueType # 4
@@ -153,6 +163,10 @@ kOutputInputYX4: FilterLayout.ValueType # 2
"""cuDNN's NCHW_VECT_C layout with 4-elem vectors"""
kOutputInputYX32: FilterLayout.ValueType # 5
"""cuDNN's NCHW_VECT_C layout with 32-elem vectors"""
kOutputInputYX32_CudnnReordered: FilterLayout.ValueType # 6
"""cuDNN-specific filter reordering (using `cudnnReorderFilterAndBias`)
When the filter is reordered, so is the bias (if present).
"""
kInputYXOutput: FilterLayout.ValueType # 3
kYXInputOutput: FilterLayout.ValueType # 4
global___FilterLayout = FilterLayout
@@ -241,6 +255,7 @@ class _ConvolutionKindEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper
BACKWARD_FILTER: _ConvolutionKind.ValueType # 2
BACKWARD_DATA: _ConvolutionKind.ValueType # 3
FORWARD_BIAS_ACTIVATION: _ConvolutionKind.ValueType # 4
FORWARD_GRAPH: _ConvolutionKind.ValueType # 5
class ConvolutionKind(_ConvolutionKind, metaclass=_ConvolutionKindEnumTypeWrapper): ...
@@ -249,8 +264,27 @@ FORWARD: ConvolutionKind.ValueType # 1
BACKWARD_FILTER: ConvolutionKind.ValueType # 2
BACKWARD_DATA: ConvolutionKind.ValueType # 3
FORWARD_BIAS_ACTIVATION: ConvolutionKind.ValueType # 4
FORWARD_GRAPH: ConvolutionKind.ValueType # 5
global___ConvolutionKind = ConvolutionKind
class _FusedMHAKind:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _FusedMHAKindEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_FusedMHAKind.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
BMM1_OUTPUT_UNKNOWN: _FusedMHAKind.ValueType # 0
BMM1_OUTPUT_INPUT_TYPE: _FusedMHAKind.ValueType # 1
BMM1_OUTPUT_FLOAT: _FusedMHAKind.ValueType # 2
class FusedMHAKind(_FusedMHAKind, metaclass=_FusedMHAKindEnumTypeWrapper):
"""FusedMHAKind kind"""
BMM1_OUTPUT_UNKNOWN: FusedMHAKind.ValueType # 0
BMM1_OUTPUT_INPUT_TYPE: FusedMHAKind.ValueType # 1
BMM1_OUTPUT_FLOAT: FusedMHAKind.ValueType # 2
global___FusedMHAKind = FusedMHAKind
@typing.final
class TensorDescriptorProto(google.protobuf.message.Message):
"""Generic tensor representation."""

View File

@@ -0,0 +1,37 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import typing
import google.protobuf.descriptor
import google.protobuf.message
import tensorflow.tsl.protobuf.error_codes_pb2
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class StatusProto(google.protobuf.message.Message):
"""Wire-format for Status.
Next tag: 3
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
CODE_FIELD_NUMBER: builtins.int
MESSAGE_FIELD_NUMBER: builtins.int
code: tensorflow.tsl.protobuf.error_codes_pb2.Code.ValueType
"""Status code as defined in tensorflow/tsl/protobuf/error_codes.proto."""
message: builtins.str
"""Detail error message."""
def __init__(
self,
*,
code: tensorflow.tsl.protobuf.error_codes_pb2.Code.ValueType | None = ...,
message: builtins.str | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["code", b"code", "message", b"message"]) -> None: ...
global___StatusProto = StatusProto

View File

@@ -0,0 +1 @@
from tensorflow.types import experimental as experimental

View File

@@ -15,7 +15,7 @@ class Callable(Generic[_P, _R], metaclass=abc.ABCMeta):
class ConcreteFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
class GenericFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
class PolymorphicFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
@overload
@abc.abstractmethod
def get_concrete_function(self, *args: _P.args, **kwargs: _P.kwargs) -> ConcreteFunction[_P, _R]: ...
@@ -26,4 +26,6 @@ class GenericFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
) -> ConcreteFunction[_P, _R]: ...
def experimental_get_compiler_ir(self, *args, **kwargs): ...
GenericFunction = PolymorphicFunction
def __getattr__(name: str) -> Incomplete: ...