Bump tensorflow to 2.16.* (#11696)
@@ -49,15 +49,9 @@ tensorflow._aliases
# tf.initializers at runtime is <module 'keras.api._v2.keras.initializers' from '...'>
tensorflow.initializers
# Other cursed import magic similar to the one above.
tensorflow.keras.layers.preprocessing
tensorflow.keras.layers.preprocessing.index_lookup
# Another cursed import magic similar to the one above.
tensorflow.distribute.coordinator
tensorflow.distribute.experimental.coordinator

# Layer constructors always have **kwargs, but only allow a few specific values. PEP 692
# would allow us to specify this with **kwargs and remove the need for these exceptions.
tensorflow.keras.layers.*.__init__
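For context, a minimal sketch of the PEP 692 approach the comment refers to. The TypedDict and its keys below are illustrative assumptions, not the actual Keras signature:

from typing import TypedDict
from typing_extensions import Unpack  # typing.Unpack on Python 3.11+

class LayerKwargs(TypedDict, total=False):
    # Hypothetical subset of the keyword arguments a Layer accepts.
    name: str
    dtype: str
    trainable: bool

def layer_init(units: int, **kwargs: Unpack[LayerKwargs]) -> None:
    # With PEP 692, a type checker rejects unknown keys such as
    # layer_init(8, typo=1) while still allowing **kwargs in the signature.
    ...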
# __call__ in tensorflow classes often allows keyword usage, but
# when you subclass those classes it is not expected to handle the keyword case. As an example,
# class MyLayer(tf.keras.layers.Layer):
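The comment is cut off by the hunk boundary; a hedged sketch of the kind of mismatch it describes (class and parameter names are illustrative):

import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):
    # Subclasses typically define call() with their own positional
    # parameter names ...
    def call(self, x):
        return x

layer = MyLayer()
layer(tf.constant([1.0]))            # fine: positional use
# layer(inputs=tf.constant([1.0]))   # would be accepted by a keyword-friendly
# __call__ stub, but this subclass's call() only knows the parameter as `x`.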
@@ -1,8 +1,10 @@
version = "2.15.*"
# Whenever you update version here, TENSORFLOW_VERSION should be updated
# in scripts/sync_tensorflow_protobuf_stubs.sh and vice-versa.
version = "2.16.*"
upstream_repository = "https://github.com/tensorflow/tensorflow"
# requires a version of numpy with a `py.typed` file
requires = ["numpy>=1.20", "types-protobuf", "types-requests"]
extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 25.1 on tensorflow==2.12.1."
extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 25.1 on tensorflow==2.16.1."
partial_stub = true

[tool.stubtest]
@@ -1,3 +1,4 @@
import abc
from _typeshed import Incomplete, Unused
from abc import ABC, ABCMeta, abstractmethod
from builtins import bool as _bool
@@ -17,6 +18,7 @@ from tensorflow import (
io as io,
keras as keras,
math as math,
types as types,
)
from tensorflow._aliases import AnyArray, DTypeLike, ShapeLike, Slice, TensorCompatible
from tensorflow.autodiff import GradientTape as GradientTape
@@ -257,7 +259,7 @@ class IndexedSlices(metaclass=ABCMeta):
def __neg__(self) -> IndexedSlices: ...
def consumers(self) -> list[Operation]: ...

class name_scope:
class name_scope(metaclass=abc.ABCMeta):
def __init__(self, name: str) -> None: ...
def __enter__(self) -> str: ...
def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ...
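For context, tf.name_scope is a context manager whose __enter__ (typed above as returning str) yields the resolved scope name:

import tensorflow as tf

with tf.name_scope("outer") as scope:
    # scope is the string name of the active scope, e.g. "outer/"
    x = tf.constant(1.0, name="x")  # op is created under the "outer/" prefix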
@@ -1,92 +0,0 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
"""

import builtins
import collections.abc
import typing

import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import tensorflow.tsl.protobuf.autotuning_pb2

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

@typing.final
class AutotuneResults(google.protobuf.message.Message):
"""A collection of algorithms for particular dot/convs. Usually this is "the
best" algorithm for the particular dot/conv, although that's not strictly
required.

Users don't interact with this proto directly. It's used internally to
facilitate ahead-of-time autotuning -- The string used by
xla::{Serialize,Load}AutotuneResults is, internally, a serialization of this
proto.

LINT.IfChange
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

@typing.final
class Entry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

DEVICE_FIELD_NUMBER: builtins.int
HLO_FIELD_NUMBER: builtins.int
RESULT_FIELD_NUMBER: builtins.int
device: builtins.str
hlo: builtins.str
@property
def result(self) -> tensorflow.tsl.protobuf.autotuning_pb2.AutotuneResult:
"""nb: These results are always tied to a particular version of
cublas/cudnn, but this is *especially* true for cublasLt results. For
cublasLt gemms, the result is an index into the list of candidate
algorithms returned by cublasLt. Different version of cublasLt ->
different list of algos -> different interpretation of results!
"""

def __init__(
self,
*,
device: builtins.str | None = ...,
hlo: builtins.str | None = ...,
result: tensorflow.tsl.protobuf.autotuning_pb2.AutotuneResult | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["result", b"result"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["device", b"device", "hlo", b"hlo", "result", b"result"]) -> None: ...

VERSION_FIELD_NUMBER: builtins.int
DOTS_FIELD_NUMBER: builtins.int
CONVS_FIELD_NUMBER: builtins.int
version: builtins.int
@property
def dots(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___AutotuneResults.Entry]: ...
@property
def convs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___AutotuneResults.Entry]: ...
def __init__(
self,
*,
version: builtins.int | None = ...,
dots: collections.abc.Iterable[global___AutotuneResults.Entry] | None = ...,
convs: collections.abc.Iterable[global___AutotuneResults.Entry] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["convs", b"convs", "dots", b"dots", "version", b"version"]) -> None: ...

global___AutotuneResults = AutotuneResults
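These generated stubs all share one shape: a keyword-only __init__ plus HasField/ClearField narrowed with typing.Literal. A sketch of how this (now-removed) message would have been built; the import path is an assumption, since the diff does not show the file's location:

# Hypothetical path; requires a tensorflow release that still shipped this module.
from tensorflow.compiler.xla.autotune_results_pb2 import AutotuneResults

entry = AutotuneResults.Entry(device="gpu:0", hlo="dot")  # keyword-only, per the stub
results = AutotuneResults(version=1, dots=[entry])
entry.HasField("result")       # OK: "result" is in the Literal overload
results.ClearField("version")  # OK; any other field name is a static type error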
@@ -20,6 +20,7 @@ import collections.abc
import sys
import typing

import google.protobuf.any_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
@@ -218,7 +219,7 @@ global___Kind = Kind
@typing.final
class HloInstructionProto(google.protobuf.message.Message):
"""Serialization of HloInstruction.
Next ID: 81
Next ID: 87
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -310,8 +311,11 @@ class HloInstructionProto(google.protobuf.message.Message):
CROSS_PROGRAM_PREFETCH_INDEX_FIELD_NUMBER: builtins.int
PADDING_TYPE_FIELD_NUMBER: builtins.int
CUSTOM_CALL_API_VERSION_FIELD_NUMBER: builtins.int
ASYNC_GROUP_ID_FIELD_NUMBER: builtins.int
ASYNC_EXECUTION_THREAD_FIELD_NUMBER: builtins.int
K_FIELD_NUMBER: builtins.int
LARGEST_FIELD_NUMBER: builtins.int
STATISTICS_VIZ_FIELD_NUMBER: builtins.int
DOT_SPARSITY_FIELD_NUMBER: builtins.int
name: builtins.str
opcode: builtins.str
parameter_number: builtins.int
@@ -420,16 +424,15 @@ class HloInstructionProto(google.protobuf.message.Message):
TODO(b/189822916): Remove this field when all clients are migrated to the
status-returning API.
"""
async_group_id: builtins.int
"""Represents a unique identifier for an async group which consists of an
async start, async done, and zero or more async update operations.
Negative async_group_id is equivalent to no async group id.
"""
async_execution_thread: builtins.str
"""Represents a unique execution thread name for one or more async groups.
Each HLO module may contain a main thread and one or more parallel threads.
Empty async_execution_thread is equivalent to main thread.
"""
k: builtins.int
"""Represents the K value for top-k."""
largest: builtins.bool
"""Represents the largest flag for top-k."""
@property
def shape(self) -> tensorflow.compiler.xla.xla_data_pb2.ShapeProto: ...
@property
@@ -536,6 +539,16 @@ class HloInstructionProto(google.protobuf.message.Message):
def frontend_attributes(self) -> tensorflow.compiler.xla.xla_data_pb2.FrontendAttributes:
"""Frontend attributes to pass to the XLA backend."""

@property
def statistics_viz(self) -> tensorflow.compiler.xla.xla_data_pb2.StatisticsViz:
"""Represents the information for tracking propagation of values within HLO
graph.
"""

@property
def dot_sparsity(self) -> tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor:
"""Sparsity descriptor for dot operation."""

def __init__(
self,
*,
@@ -605,11 +618,14 @@ class HloInstructionProto(google.protobuf.message.Message):
cross_program_prefetch_index: builtins.int | None = ...,
padding_type: tensorflow.compiler.xla.xla_data_pb2.PaddingType.ValueType | None = ...,
custom_call_api_version: global___CustomCallApiVersion.ValueType | None = ...,
async_group_id: builtins.int | None = ...,
async_execution_thread: builtins.str | None = ...,
k: builtins.int | None = ...,
largest: builtins.bool | None = ...,
statistics_viz: tensorflow.compiler.xla.xla_data_pb2.StatisticsViz | None = ...,
dot_sparsity: tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "async_group_id", b"async_group_id", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "output_operand_aliasing", b"output_operand_aliasing", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ...
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "k", b"k", "largest", b"largest", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "output_operand_aliasing", b"output_operand_aliasing", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index"]) -> typing.Literal["cross_program_prefetch_index"] | None: ...
global___HloInstructionProto = HloInstructionProto
@@ -736,7 +752,7 @@ class HloInputOutputAliasProto(google.protobuf.message.Message):
parameter_shape_index={1, 2},
}

This entry indicates that the first paremter's {1, 2} element is
This entry indicates that the first parameter's {1, 2} element is
aliased with the {1} element of the root instruction.
"""
@@ -781,72 +797,54 @@
global___HloInputOutputAliasProto = HloInputOutputAliasProto

@typing.final
class DynamicParameterBindingProto(google.protobuf.message.Message):
class HloBufferDonorProto(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

@typing.final
class Binding(google.protobuf.message.Message):
"""A list of bindings which indicates that the `target_param_dim_num` in
the subshape `target_param_index` of parameter `target_param_num`
is a dynamic dimension and its real dynamic size is represented
by `dynamic_param_index` in parameter `dynamic_param_num`.
class BufferDonorEntryProto(google.protobuf.message.Message):
"""The following proto describes an input (described by parameter number and a
ShapeIndex of the parameter) that can donate its buffer to any output
tensor. It is similar to HloInputOutputAliasProto, but without a paired
output. For example:

As an example, imagine we have a program:

ENTRY main {
a = f32[] parameter(0)
b = f32[10] parameter(1)
ROOT root = (f32[], f32[10]) tuple(%a, %b)
entry = {
parameter_number=0,
parameter_shape_index={1, 2},
}

Let's say 'b' (param index 1) is a dynamic shape whose input has
an upperbound of 10 and real size is determined at runtime. 'a'
represents the real size of b's first dimension.

In this case, the fields are set in the following way:
dynamic_param_num = 1
dynamic_param_index = {}
target_param_num = 0
target_param_index = {}
target_param_dim_num = 0
This entry indicates that the first parameter's {1, 2} element can donate
its buffer.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

DYNAMIC_PARAM_NUM_FIELD_NUMBER: builtins.int
DYNAMIC_PARAM_INDEX_FIELD_NUMBER: builtins.int
TARGET_PARAM_NUM_FIELD_NUMBER: builtins.int
TARGET_PARAM_INDEX_FIELD_NUMBER: builtins.int
TARGET_PARAM_DIM_NUM_FIELD_NUMBER: builtins.int
dynamic_param_num: builtins.int
target_param_num: builtins.int
target_param_dim_num: builtins.int
PARAMETER_NUMBER_FIELD_NUMBER: builtins.int
PARAMETER_SHAPE_INDEX_FIELD_NUMBER: builtins.int
parameter_number: builtins.int
"""Number of the parameter in entry computation."""
@property
def dynamic_param_index(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
@property
def target_param_index(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
def parameter_shape_index(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""ShapeIndex of the parameter instruction."""

def __init__(
self,
*,
dynamic_param_num: builtins.int | None = ...,
dynamic_param_index: collections.abc.Iterable[builtins.int] | None = ...,
target_param_num: builtins.int | None = ...,
target_param_index: collections.abc.Iterable[builtins.int] | None = ...,
target_param_dim_num: builtins.int | None = ...,
parameter_number: builtins.int | None = ...,
parameter_shape_index: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["dynamic_param_index", b"dynamic_param_index", "dynamic_param_num", b"dynamic_param_num", "target_param_dim_num", b"target_param_dim_num", "target_param_index", b"target_param_index", "target_param_num", b"target_param_num"]) -> None: ...
def ClearField(self, field_name: typing.Literal["parameter_number", b"parameter_number", "parameter_shape_index", b"parameter_shape_index"]) -> None: ...

ENTRIES_FIELD_NUMBER: builtins.int
@property
def entries(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___DynamicParameterBindingProto.Binding]: ...
def entries(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___HloBufferDonorProto.BufferDonorEntryProto]: ...
def __init__(
self,
*,
entries: collections.abc.Iterable[global___DynamicParameterBindingProto.Binding] | None = ...,
entries: collections.abc.Iterable[global___HloBufferDonorProto.BufferDonorEntryProto] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["entries", b"entries"]) -> None: ...

global___DynamicParameterBindingProto = DynamicParameterBindingProto
global___HloBufferDonorProto = HloBufferDonorProto
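The docstring's worked example translates directly into a message; a sketch, assuming the module path used for these stubs elsewhere in this diff (tensorflow.compiler.xla.service.hlo_pb2):

from tensorflow.compiler.xla.service.hlo_pb2 import HloBufferDonorProto

# "The first parameter's {1, 2} element can donate its buffer."
entry = HloBufferDonorProto.BufferDonorEntryProto(
    parameter_number=0,
    parameter_shape_index=[1, 2],
)
donors = HloBufferDonorProto(entries=[entry])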
@typing.final
class CrossProgramPrefetch(google.protobuf.message.Message):
@@ -870,6 +868,101 @@ class CrossProgramPrefetch(google.protobuf.message.Message):

global___CrossProgramPrefetch = CrossProgramPrefetch

@typing.final
class StackFrameIndexProto(google.protobuf.message.Message):
"""Serialization of stack frames index representations.
Stack frames index presented in four flat arrays:
1. File names array.
2. Function names array.
3. File location array.
4. Frame array.
All reference ids in sub-protos are 1-based positions of the
entity in the flat array.
Ids are 1-based to keep 0 value as representation of non-set property.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

@typing.final
class FileLocation(google.protobuf.message.Message):
"""Serialization of file position."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

FILE_NAME_ID_FIELD_NUMBER: builtins.int
FUNCTION_NAME_ID_FIELD_NUMBER: builtins.int
LINE_FIELD_NUMBER: builtins.int
COLUMN_FIELD_NUMBER: builtins.int
file_name_id: builtins.int
"""1-based position of file name."""
function_name_id: builtins.int
"""1-based position of function name."""
line: builtins.int
"""Line number."""
column: builtins.int
"""Column number."""
def __init__(
self,
*,
file_name_id: builtins.int | None = ...,
function_name_id: builtins.int | None = ...,
line: builtins.int | None = ...,
column: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["column", b"column", "file_name_id", b"file_name_id", "function_name_id", b"function_name_id", "line", b"line"]) -> None: ...

@typing.final
class StackFrame(google.protobuf.message.Message):
"""Serialization of frame."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

FILE_LOCATION_ID_FIELD_NUMBER: builtins.int
PARENT_FRAME_ID_FIELD_NUMBER: builtins.int
file_location_id: builtins.int
"""1-based position of file location."""
parent_frame_id: builtins.int
"""1-based position of the parent frame."""
def __init__(
self,
*,
file_location_id: builtins.int | None = ...,
parent_frame_id: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["file_location_id", b"file_location_id", "parent_frame_id", b"parent_frame_id"]) -> None: ...

FILE_NAMES_FIELD_NUMBER: builtins.int
FUNCTION_NAMES_FIELD_NUMBER: builtins.int
FILE_LOCATIONS_FIELD_NUMBER: builtins.int
STACK_FRAMES_FIELD_NUMBER: builtins.int
@property
def file_names(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""Flat index array of file names."""

@property
def function_names(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""Flat index array of function names."""

@property
def file_locations(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___StackFrameIndexProto.FileLocation]:
"""Flat index array of file locations."""

@property
def stack_frames(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___StackFrameIndexProto.StackFrame]:
"""Flat index array of frames."""

def __init__(
self,
*,
file_names: collections.abc.Iterable[builtins.str] | None = ...,
function_names: collections.abc.Iterable[builtins.str] | None = ...,
file_locations: collections.abc.Iterable[global___StackFrameIndexProto.FileLocation] | None = ...,
stack_frames: collections.abc.Iterable[global___StackFrameIndexProto.StackFrame] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["file_locations", b"file_locations", "file_names", b"file_names", "function_names", b"function_names", "stack_frames", b"stack_frames"]) -> None: ...

global___StackFrameIndexProto = StackFrameIndexProto
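Since all ids here are 1-based (0 means unset), a small sketch of a consistent index, using the same assumed module path as above:

from tensorflow.compiler.xla.service.hlo_pb2 import StackFrameIndexProto

index = StackFrameIndexProto(
    file_names=["model.py"],
    function_names=["forward"],
    # file_name_id / function_name_id are 1-based positions in the arrays above.
    file_locations=[
        StackFrameIndexProto.FileLocation(file_name_id=1, function_name_id=1, line=42, column=7)
    ],
    # parent_frame_id left at 0 (unset) marks the outermost frame.
    stack_frames=[StackFrameIndexProto.StackFrame(file_location_id=1)],
)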
@typing.final
class HloModuleProto(google.protobuf.message.Message):
"""Serialization of HloModule."""
@@ -907,6 +1000,7 @@ class HloModuleProto(google.protobuf.message.Message):
RELATIVE_SPEEDUP_FIELD_NUMBER: builtins.int
PROFILE_SOURCE_FIELD_NUMBER: builtins.int
COMPILATION_EVENT_FIELD_NUMBER: builtins.int
FINGERPRINT_FIELD_NUMBER: builtins.int
profile_type: global___HloModuleProto.ProfileType.ValueType
"""The optimization profiles that this module contains."""
relative_speedup: builtins.float
@@ -915,6 +1009,8 @@ class HloModuleProto(google.protobuf.message.Message):
"""The source of the optimization profile that this module contains."""
compilation_event: tensorflow.compiler.xla.xla_data_pb2.CompilationEvent.ValueType
"""The compilation event that triggered the use of the profile."""
fingerprint: builtins.str
"""The fingerprint of the unoptimized module this profile was applied to."""
def __init__(
self,
*,
@@ -922,8 +1018,9 @@ class HloModuleProto(google.protobuf.message.Message):
relative_speedup: builtins.float | None = ...,
profile_source: tensorflow.compiler.xla.xla_data_pb2.ProfileSource.ValueType | None = ...,
compilation_event: tensorflow.compiler.xla.xla_data_pb2.CompilationEvent.ValueType | None = ...,
fingerprint: builtins.str | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["compilation_event", b"compilation_event", "profile_source", b"profile_source", "profile_type", b"profile_type", "relative_speedup", b"relative_speedup"]) -> None: ...
def ClearField(self, field_name: typing.Literal["compilation_event", b"compilation_event", "fingerprint", b"fingerprint", "profile_source", b"profile_source", "profile_type", b"profile_type", "relative_speedup", b"relative_speedup"]) -> None: ...

NAME_FIELD_NUMBER: builtins.int
ENTRY_COMPUTATION_NAME_FIELD_NUMBER: builtins.int
@@ -933,7 +1030,7 @@ class HloModuleProto(google.protobuf.message.Message):
ID_FIELD_NUMBER: builtins.int
SCHEDULE_FIELD_NUMBER: builtins.int
INPUT_OUTPUT_ALIAS_FIELD_NUMBER: builtins.int
DYNAMIC_PARAMETER_BINDING_FIELD_NUMBER: builtins.int
BUFFER_DONOR_FIELD_NUMBER: builtins.int
CROSS_PROGRAM_PREFETCHES_FIELD_NUMBER: builtins.int
IS_DYNAMIC_FIELD_NUMBER: builtins.int
SPMD_OUTPUT_SHARDING_FIELD_NUMBER: builtins.int
@@ -941,6 +1038,8 @@ class HloModuleProto(google.protobuf.message.Message):
USE_AUTO_SPMD_PARTITIONING_FIELD_NUMBER: builtins.int
PROFILE_INFO_FIELD_NUMBER: builtins.int
DEVICE_ASSIGNMENT_FIELD_NUMBER: builtins.int
STACK_FRAME_INDEX_FIELD_NUMBER: builtins.int
FRONTEND_ATTRIBUTES_FIELD_NUMBER: builtins.int
name: builtins.str
entry_computation_name: builtins.str
entry_computation_id: builtins.int
@@ -969,7 +1068,9 @@ class HloModuleProto(google.protobuf.message.Message):
"""Describes alias information between inputs and outputs."""

@property
def dynamic_parameter_binding(self) -> global___DynamicParameterBindingProto: ...
def buffer_donor(self) -> global___HloBufferDonorProto:
"""Describes the information of input buffer donors."""

@property
def cross_program_prefetches(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___CrossProgramPrefetch]: ...
@property
@@ -984,6 +1085,14 @@ class HloModuleProto(google.protobuf.message.Message):
def device_assignment(self) -> tensorflow.compiler.xla.xla_data_pb2.DeviceAssignmentProto:
"""DeviceAssignment object information."""

@property
def stack_frame_index(self) -> global___StackFrameIndexProto:
"""Stack frames index."""

@property
def frontend_attributes(self) -> tensorflow.compiler.xla.xla_data_pb2.FrontendAttributes:
"""Frontend attributes to pass to the XLA backend."""

def __init__(
self,
*,
@@ -995,7 +1104,7 @@ class HloModuleProto(google.protobuf.message.Message):
id: builtins.int | None = ...,
schedule: global___HloScheduleProto | None = ...,
input_output_alias: global___HloInputOutputAliasProto | None = ...,
dynamic_parameter_binding: global___DynamicParameterBindingProto | None = ...,
buffer_donor: global___HloBufferDonorProto | None = ...,
cross_program_prefetches: collections.abc.Iterable[global___CrossProgramPrefetch] | None = ...,
is_dynamic: builtins.bool | None = ...,
spmd_output_sharding: tensorflow.compiler.xla.xla_data_pb2.OpSharding | None = ...,
@@ -1003,9 +1112,11 @@ class HloModuleProto(google.protobuf.message.Message):
use_auto_spmd_partitioning: builtins.bool | None = ...,
profile_info: collections.abc.Iterable[global___HloModuleProto.ProfileInfo] | None = ...,
device_assignment: tensorflow.compiler.xla.xla_data_pb2.DeviceAssignmentProto | None = ...,
stack_frame_index: global___StackFrameIndexProto | None = ...,
frontend_attributes: tensorflow.compiler.xla.xla_data_pb2.FrontendAttributes | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["device_assignment", b"device_assignment", "dynamic_parameter_binding", b"dynamic_parameter_binding", "host_program_shape", b"host_program_shape", "input_output_alias", b"input_output_alias", "schedule", b"schedule", "spmd_output_sharding", b"spmd_output_sharding"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["computations", b"computations", "cross_program_prefetches", b"cross_program_prefetches", "device_assignment", b"device_assignment", "dynamic_parameter_binding", b"dynamic_parameter_binding", "entry_computation_id", b"entry_computation_id", "entry_computation_name", b"entry_computation_name", "host_program_shape", b"host_program_shape", "id", b"id", "input_output_alias", b"input_output_alias", "is_dynamic", b"is_dynamic", "name", b"name", "profile_info", b"profile_info", "schedule", b"schedule", "spmd_output_sharding", b"spmd_output_sharding", "spmd_parameters_shardings", b"spmd_parameters_shardings", "use_auto_spmd_partitioning", b"use_auto_spmd_partitioning"]) -> None: ...
def HasField(self, field_name: typing.Literal["buffer_donor", b"buffer_donor", "device_assignment", b"device_assignment", "frontend_attributes", b"frontend_attributes", "host_program_shape", b"host_program_shape", "input_output_alias", b"input_output_alias", "schedule", b"schedule", "spmd_output_sharding", b"spmd_output_sharding", "stack_frame_index", b"stack_frame_index"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["buffer_donor", b"buffer_donor", "computations", b"computations", "cross_program_prefetches", b"cross_program_prefetches", "device_assignment", b"device_assignment", "entry_computation_id", b"entry_computation_id", "entry_computation_name", b"entry_computation_name", "frontend_attributes", b"frontend_attributes", "host_program_shape", b"host_program_shape", "id", b"id", "input_output_alias", b"input_output_alias", "is_dynamic", b"is_dynamic", "name", b"name", "profile_info", b"profile_info", "schedule", b"schedule", "spmd_output_sharding", b"spmd_output_sharding", "spmd_parameters_shardings", b"spmd_parameters_shardings", "stack_frame_index", b"stack_frame_index", "use_auto_spmd_partitioning", b"use_auto_spmd_partitioning"]) -> None: ...
global___HloModuleProto = HloModuleProto
@@ -1435,6 +1546,7 @@ class HloPassMetadata(google.protobuf.message.Message):
MODULE_GROUP_MODULE_IDS_FIELD_NUMBER: builtins.int
START_TIMESTAMP_USEC_FIELD_NUMBER: builtins.int
END_TIMESTAMP_USEC_FIELD_NUMBER: builtins.int
CUSTOM_METADATA_FIELD_NUMBER: builtins.int
pass_id: builtins.int
"""For a given module, pass_id uniquely identifies a run of an HLO pass on
that module. Note that a pass_id may not always refer to the same pass
@@ -1470,6 +1582,10 @@ class HloPassMetadata(google.protobuf.message.Message):
set as the ids of all the modules in the module group. Empty otherwise.
"""

@property
def custom_metadata(self) -> google.protobuf.any_pb2.Any:
"""Custom metadata for the pass."""

def __init__(
self,
*,
@@ -1482,92 +1598,13 @@ class HloPassMetadata(google.protobuf.message.Message):
module_group_module_ids: collections.abc.Iterable[builtins.int] | None = ...,
start_timestamp_usec: builtins.int | None = ...,
end_timestamp_usec: builtins.int | None = ...,
custom_metadata: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["dump_filenames", b"dump_filenames", "end_timestamp_usec", b"end_timestamp_usec", "module_changed", b"module_changed", "module_group_module_ids", b"module_group_module_ids", "module_id", b"module_id", "pass_id", b"pass_id", "pass_name", b"pass_name", "pipeline_name", b"pipeline_name", "start_timestamp_usec", b"start_timestamp_usec"]) -> None: ...
def HasField(self, field_name: typing.Literal["custom_metadata", b"custom_metadata"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["custom_metadata", b"custom_metadata", "dump_filenames", b"dump_filenames", "end_timestamp_usec", b"end_timestamp_usec", "module_changed", b"module_changed", "module_group_module_ids", b"module_group_module_ids", "module_id", b"module_id", "pass_id", b"pass_id", "pass_name", b"pass_name", "pipeline_name", b"pipeline_name", "start_timestamp_usec", b"start_timestamp_usec"]) -> None: ...
global___HloPassMetadata = HloPassMetadata

@typing.final
class EntryFunctionAttributes(google.protobuf.message.Message):
"""Encodes attributes for an entry function."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

@typing.final
class ShapeIndex(google.protobuf.message.Message):
"""Acts as the underlying container for an xla::ShapeIndex."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

INDICES_FIELD_NUMBER: builtins.int
@property
def indices(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
def __init__(
self,
*,
indices: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["indices", b"indices"]) -> None: ...

@typing.final
class BufferParameterAttributes(google.protobuf.message.Message):
"""Encodes attributes for a single buffer parameter."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

LMHLO_PARAMS_FIELD_NUMBER: builtins.int
LMHLO_PARAMS_PRESENT_FIELD_NUMBER: builtins.int
LMHLO_PARAM_SHAPE_INDEX_FIELD_NUMBER: builtins.int
LMHLO_CONSTANT_NAME_FIELD_NUMBER: builtins.int
LMHLO_MUST_ALIAS_FIELD_NUMBER: builtins.int
LMHLO_OUTPUT_INDEX_FIELD_NUMBER: builtins.int
lmhlo_params: builtins.int
"""Represents an lmhlo.params function argument attribute."""
lmhlo_params_present: builtins.bool
"""TODO(hanbinyoon): Deprecate when optional fields are available in proto3
(Protocol Buffers v3.15.0).
"""
lmhlo_constant_name: builtins.str
"""Represents an lmhlo.constant_name function argument attribute."""
lmhlo_must_alias: builtins.bool
"""Represents an lmhlo.must_alias function argument attribute."""
@property
def lmhlo_param_shape_index(self) -> global___EntryFunctionAttributes.ShapeIndex:
"""Represents an lmhlo.param_shape_index function argument attribute."""

@property
def lmhlo_output_index(self) -> global___EntryFunctionAttributes.ShapeIndex:
"""Represents an lmhlo.params function argument attribute."""

def __init__(
self,
*,
lmhlo_params: builtins.int | None = ...,
lmhlo_params_present: builtins.bool | None = ...,
lmhlo_param_shape_index: global___EntryFunctionAttributes.ShapeIndex | None = ...,
lmhlo_constant_name: builtins.str | None = ...,
lmhlo_must_alias: builtins.bool | None = ...,
lmhlo_output_index: global___EntryFunctionAttributes.ShapeIndex | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["lmhlo_output_index", b"lmhlo_output_index", "lmhlo_param_shape_index", b"lmhlo_param_shape_index"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["lmhlo_constant_name", b"lmhlo_constant_name", "lmhlo_must_alias", b"lmhlo_must_alias", "lmhlo_output_index", b"lmhlo_output_index", "lmhlo_param_shape_index", b"lmhlo_param_shape_index", "lmhlo_params", b"lmhlo_params", "lmhlo_params_present", b"lmhlo_params_present"]) -> None: ...

BUFFERS_FIELD_NUMBER: builtins.int
RESULT_XLA_SHAPE_FIELD_NUMBER: builtins.int
result_xla_shape: builtins.str
"""xla::Shape in string format."""
@property
def buffers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___EntryFunctionAttributes.BufferParameterAttributes]: ...
def __init__(
self,
*,
buffers: collections.abc.Iterable[global___EntryFunctionAttributes.BufferParameterAttributes] | None = ...,
result_xla_shape: builtins.str | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["buffers", b"buffers", "result_xla_shape", b"result_xla_shape"]) -> None: ...

global___EntryFunctionAttributes = EntryFunctionAttributes

@typing.final
class XlaRuntimeExecutableProto(google.protobuf.message.Message):
"""Encodes the underlying Xla runtime executable compiled from the XLA module."""
@@ -0,0 +1,149 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Copyright 2018 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
"""

import builtins
import collections.abc
import typing

import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

@typing.final
class HloProfilePrinterData(google.protobuf.message.Message):
"""Describes how to pretty-print a profile counter array gathered for a specific
HloModule.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

@typing.final
class HloInstructionInfo(google.protobuf.message.Message):
"""Pretty-printer information about an HloInstruction."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

LONG_NAME_FIELD_NUMBER: builtins.int
SHORT_NAME_FIELD_NUMBER: builtins.int
CATEGORY_FIELD_NUMBER: builtins.int
FLOP_COUNT_FIELD_NUMBER: builtins.int
TRANSCENDENTAL_COUNT_FIELD_NUMBER: builtins.int
BYTES_ACCESSED_FIELD_NUMBER: builtins.int
OPTIMAL_SECONDS_FIELD_NUMBER: builtins.int
PROFILE_INDEX_FIELD_NUMBER: builtins.int
long_name: builtins.str
short_name: builtins.str
category: builtins.str
flop_count: builtins.float
"""Metrics computed by HloCostAnalysis."""
transcendental_count: builtins.float
bytes_accessed: builtins.int
optimal_seconds: builtins.float
profile_index: builtins.int
"""The index into the profile counters array for the HloInstruction
corresponding to this HloInstructionInfo.
"""
def __init__(
self,
*,
long_name: builtins.str | None = ...,
short_name: builtins.str | None = ...,
category: builtins.str | None = ...,
flop_count: builtins.float | None = ...,
transcendental_count: builtins.float | None = ...,
bytes_accessed: builtins.int | None = ...,
optimal_seconds: builtins.float | None = ...,
profile_index: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["bytes_accessed", b"bytes_accessed", "category", b"category", "flop_count", b"flop_count", "long_name", b"long_name", "optimal_seconds", b"optimal_seconds", "profile_index", b"profile_index", "short_name", b"short_name", "transcendental_count", b"transcendental_count"]) -> None: ...

@typing.final
class HloComputationInfo(google.protobuf.message.Message):
"""Pretty-printer information about an HloComputation."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

NAME_FIELD_NUMBER: builtins.int
PROFILE_INDEX_FIELD_NUMBER: builtins.int
INSTRUCTION_INFOS_FIELD_NUMBER: builtins.int
name: builtins.str
profile_index: builtins.int
"""The index into the profile counters array for the HloComputation
corresponding to this HloComputationInfo.
"""
@property
def instruction_infos(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___HloProfilePrinterData.HloInstructionInfo]:
"""HloInstructionInfos for every HloInstruction in the HloComputation
corresponding to this HloComputationInfo.
"""

def __init__(
self,
*,
name: builtins.str | None = ...,
profile_index: builtins.int | None = ...,
instruction_infos: collections.abc.Iterable[global___HloProfilePrinterData.HloInstructionInfo] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["instruction_infos", b"instruction_infos", "name", b"name", "profile_index", b"profile_index"]) -> None: ...

@typing.final
class ExtraMetricsEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
value: builtins.int
def __init__(
self,
*,
key: builtins.str | None = ...,
value: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...

COMPUTATION_INFOS_FIELD_NUMBER: builtins.int
PROFILE_COUNTERS_SIZE_FIELD_NUMBER: builtins.int
EXTRA_METRICS_FIELD_NUMBER: builtins.int
ENTRY_COMPUTATION_FIELD_NUMBER: builtins.int
profile_counters_size: builtins.int
"""The size of the profile counters array we will pretty-print."""
entry_computation: builtins.str
"""Name of the entry computation."""
@property
def computation_infos(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___HloProfilePrinterData.HloComputationInfo]:
"""HloComputationInfos for every HloComputation in the HloModule."""

@property
def extra_metrics(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.int]:
"""Maps extra metric name to the index into the profile counters array."""

def __init__(
self,
*,
computation_infos: collections.abc.Iterable[global___HloProfilePrinterData.HloComputationInfo] | None = ...,
profile_counters_size: builtins.int | None = ...,
extra_metrics: collections.abc.Mapping[builtins.str, builtins.int] | None = ...,
entry_computation: builtins.str | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["computation_infos", b"computation_infos", "entry_computation", b"entry_computation", "extra_metrics", b"extra_metrics", "profile_counters_size", b"profile_counters_size"]) -> None: ...

global___HloProfilePrinterData = HloProfilePrinterData
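A sketch tying the pieces together: every profile_index points into one flat counter array of size profile_counters_size. The module path is an assumption (the diff does not show the file name):

from tensorflow.compiler.xla.service.hlo_profile_printer_data_pb2 import HloProfilePrinterData  # assumed path

info = HloProfilePrinterData.HloInstructionInfo(long_name="%add.1", short_name="add", profile_index=1)
comp = HloProfilePrinterData.HloComputationInfo(name="main", profile_index=0, instruction_infos=[info])
data = HloProfilePrinterData(
    computation_infos=[comp],
    profile_counters_size=2,  # counters[0] -> the computation, counters[1] -> the instruction
    entry_computation="main",
)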
@@ -4,11 +4,14 @@ isort:skip_file
"""

import builtins
import collections.abc
import sys
import typing

import google.protobuf.any_pb2
import google.protobuf.descriptor
import google.protobuf.duration_pb2
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import google.protobuf.timestamp_pb2
@@ -20,6 +23,44 @@ else:

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

@typing.final
class PassMetrics(google.protobuf.message.Message):
"""Defines pass specific metrics."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

MODULE_ID_FIELD_NUMBER: builtins.int
PASS_NAME_FIELD_NUMBER: builtins.int
PASS_DURATION_FIELD_NUMBER: builtins.int
CUSTOM_METRICS_FIELD_NUMBER: builtins.int
module_id: builtins.int
"""Unique ID of the module on which the pass was run."""
pass_name: builtins.str
"""The name of the pass."""
@property
def pass_duration(self) -> google.protobuf.duration_pb2.Duration:
"""Duration of the pass."""

@property
def custom_metrics(self) -> google.protobuf.any_pb2.Any:
"""Custom pass metrics. This is kept opaque, via `google.protobuf.Any`, in
order to decouple pass agnostic compilation logs from possibly proprietary
compiler passes.
"""

def __init__(
self,
*,
module_id: builtins.int | None = ...,
pass_name: builtins.str | None = ...,
pass_duration: google.protobuf.duration_pb2.Duration | None = ...,
custom_metrics: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["custom_metrics", b"custom_metrics", "pass_duration", b"pass_duration"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["custom_metrics", b"custom_metrics", "module_id", b"module_id", "pass_duration", b"pass_duration", "pass_name", b"pass_name"]) -> None: ...

global___PassMetrics = PassMetrics
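Because custom_metrics is an opaque google.protobuf.Any, producers Pack() an arbitrary message into it. A sketch, with the module path assumed (not shown in the diff):

from google.protobuf import any_pb2, duration_pb2

from tensorflow.compiler.xla.service.metrics_pb2 import PassMetrics  # assumed path

custom = any_pb2.Any()  # Pack() would wrap any concrete metrics message; left empty here
metrics = PassMetrics(
    module_id=1,
    pass_name="fusion",
    pass_duration=duration_pb2.Duration(seconds=2),
    custom_metrics=custom,
)
metrics.HasField("pass_duration")  # True at runtime; "pass_duration" is in the Literal set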
@typing.final
class CompilationLogEntry(google.protobuf.message.Message):
"""Defines XLA compilation metrics."""
@@ -51,10 +92,13 @@ class CompilationLogEntry(google.protobuf.message.Message):
STAGE_FIELD_NUMBER: builtins.int
DURATION_FIELD_NUMBER: builtins.int
TASK_INDEX_FIELD_NUMBER: builtins.int
PASS_METRICS_FIELD_NUMBER: builtins.int
stage: global___CompilationLogEntry.CompilationStage.ValueType
"""Compilation stage recorded by this log entry."""
task_index: builtins.int
"""Task index from which this log entry was recorded."""
"""Task index from which this log entry was recorded or
-1 if the task index could not be fetched.
"""
@property
def timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp:
"""Time when the event captured by this log entry occurred."""
@@ -63,6 +107,10 @@ class CompilationLogEntry(google.protobuf.message.Message):
def duration(self) -> google.protobuf.duration_pb2.Duration:
"""Duration of the given compilation stage."""

@property
def pass_metrics(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PassMetrics]:
"""Pass specific metrics."""

def __init__(
self,
*,
@@ -70,8 +118,9 @@ class CompilationLogEntry(google.protobuf.message.Message):
stage: global___CompilationLogEntry.CompilationStage.ValueType | None = ...,
duration: google.protobuf.duration_pb2.Duration | None = ...,
task_index: builtins.int | None = ...,
pass_metrics: collections.abc.Iterable[global___PassMetrics] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["duration", b"duration", "timestamp", b"timestamp"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["duration", b"duration", "stage", b"stage", "task_index", b"task_index", "timestamp", b"timestamp"]) -> None: ...
def ClearField(self, field_name: typing.Literal["duration", b"duration", "pass_metrics", b"pass_metrics", "stage", b"stage", "task_index", b"task_index", "timestamp", b"timestamp"]) -> None: ...

global___CompilationLogEntry = CompilationLogEntry
@@ -0,0 +1,71 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Copyright 2022 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
"""

import builtins
import typing

import google.protobuf.descriptor
import google.protobuf.message

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

@typing.final
class TestCompilationEnvironment1(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

SOME_FLAG_FIELD_NUMBER: builtins.int
some_flag: builtins.int
def __init__(
self,
*,
some_flag: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["some_flag", b"some_flag"]) -> None: ...

global___TestCompilationEnvironment1 = TestCompilationEnvironment1

@typing.final
class TestCompilationEnvironment2(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

SOME_OTHER_FLAG_FIELD_NUMBER: builtins.int
some_other_flag: builtins.int
def __init__(
self,
*,
some_other_flag: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["some_other_flag", b"some_other_flag"]) -> None: ...

global___TestCompilationEnvironment2 = TestCompilationEnvironment2

@typing.final
class TestCompilationEnvironment3(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

A_THIRD_FLAG_FIELD_NUMBER: builtins.int
a_third_flag: builtins.int
def __init__(
self,
*,
a_third_flag: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["a_third_flag", b"a_third_flag"]) -> None: ...

global___TestCompilationEnvironment3 = TestCompilationEnvironment3
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
Copyright 2023 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import collections.abc
|
||||
import typing
|
||||
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.duration_pb2
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.message
|
||||
import tensorflow.compiler.xla.service.hlo_pb2
|
||||
import tensorflow.tsl.protobuf.status_pb2
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
|
||||
class CompilerPerfStats(google.protobuf.message.Message):
|
||||
"""Statistics on how long various parts of compilation took.
Not all durations may be relevant for all producers of this message, in
which case irrelevant fields should simply be skipped.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

INIT_DURATION_FIELD_NUMBER: builtins.int
HLO_VERIFICATION_DURATION_FIELD_NUMBER: builtins.int
COMPILATION_PROLOGUE_DURATION_FIELD_NUMBER: builtins.int
COMPILATION_DURATION_FIELD_NUMBER: builtins.int
TOTAL_DURATION_FIELD_NUMBER: builtins.int
@property
def init_duration(self) -> google.protobuf.duration_pb2.Duration:
"""How long did it take to initialize the compiler?"""

@property
def hlo_verification_duration(self) -> google.protobuf.duration_pb2.Duration:
"""How long did it take to verify the HLO?"""

@property
def compilation_prologue_duration(self) -> google.protobuf.duration_pb2.Duration:
"""How long did it take to prepare for compilation after verification?"""

@property
def compilation_duration(self) -> google.protobuf.duration_pb2.Duration:
"""How long did it take to compile?"""

@property
def total_duration(self) -> google.protobuf.duration_pb2.Duration:
"""How long did everything take?"""

def __init__(
self,
*,
init_duration: google.protobuf.duration_pb2.Duration | None = ...,
hlo_verification_duration: google.protobuf.duration_pb2.Duration | None = ...,
compilation_prologue_duration: google.protobuf.duration_pb2.Duration | None = ...,
compilation_duration: google.protobuf.duration_pb2.Duration | None = ...,
total_duration: google.protobuf.duration_pb2.Duration | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["compilation_duration", b"compilation_duration", "compilation_prologue_duration", b"compilation_prologue_duration", "hlo_verification_duration", b"hlo_verification_duration", "init_duration", b"init_duration", "total_duration", b"total_duration"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["compilation_duration", b"compilation_duration", "compilation_prologue_duration", b"compilation_prologue_duration", "hlo_verification_duration", b"hlo_verification_duration", "init_duration", b"init_duration", "total_duration", b"total_duration"]) -> None: ...

global___CompilerPerfStats = CompilerPerfStats

@typing.final
class CompilationResult(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

@typing.final
class CountersEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
value: builtins.int
def __init__(
self,
*,
key: builtins.str | None = ...,
value: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...

HLO_MODULE_FIELD_NUMBER: builtins.int
PERF_STATS_FIELD_NUMBER: builtins.int
STATUS_FIELD_NUMBER: builtins.int
COUNTERS_FIELD_NUMBER: builtins.int
@property
def hlo_module(self) -> tensorflow.compiler.xla.service.hlo_pb2.HloModuleProto:
"""The compiled HLO. Only set when compilation succeeds."""

@property
def perf_stats(self) -> global___CompilerPerfStats:
"""Always set when compilation succeeds. May or may not be set when
compilation fails.
"""

@property
def status(self) -> tensorflow.tsl.protobuf.status_pb2.StatusProto:
"""Always set when compilation fails; never set when compilation succeeds."""

@property
def counters(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.int]:
"""Collects counters collected during compilation. Not every producer may
include counter support at all or any particular counter.
"""

def __init__(
self,
*,
hlo_module: tensorflow.compiler.xla.service.hlo_pb2.HloModuleProto | None = ...,
perf_stats: global___CompilerPerfStats | None = ...,
status: tensorflow.tsl.protobuf.status_pb2.StatusProto | None = ...,
counters: collections.abc.Mapping[builtins.str, builtins.int] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["hlo_module", b"hlo_module", "perf_stats", b"perf_stats", "status", b"status"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["counters", b"counters", "hlo_module", b"hlo_module", "perf_stats", b"perf_stats", "status", b"status"]) -> None: ...

global___CompilationResult = CompilationResult
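
Taken together, the field docstrings above pin down a simple success/failure contract: status is set only when compilation fails, hlo_module only when it succeeds, and perf_stats is guaranteed only on success. A hedged consumption sketch (the import path is assumed, not shown in this diff):

    def summarize(result):  # result: CompilationResult
        # Per the field docs: status set on failure, hlo_module set on success.
        if result.HasField("status"):
            return "compilation failed: %s" % result.status
        total = result.perf_stats.total_duration.ToTimedelta()  # Duration -> datetime.timedelta
        return "compiled %s in %s" % (result.hlo_module.name, total)
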
@@ -1,7 +1,7 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Copyright 2017 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -44,13 +44,15 @@ class _PrimitiveTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._
"""Invalid primitive type to serve as default."""
PRED: _PrimitiveType.ValueType # 1
"""Predicates are two-state booleans."""
S8: _PrimitiveType.ValueType # 2
S4: _PrimitiveType.ValueType # 21
"""Signed integral values of fixed width."""
S8: _PrimitiveType.ValueType # 2
S16: _PrimitiveType.ValueType # 3
S32: _PrimitiveType.ValueType # 4
S64: _PrimitiveType.ValueType # 5
U8: _PrimitiveType.ValueType # 6
U4: _PrimitiveType.ValueType # 22
"""Unsigned integral values of fixed width."""
U8: _PrimitiveType.ValueType # 6
U16: _PrimitiveType.ValueType # 7
U32: _PrimitiveType.ValueType # 8
U64: _PrimitiveType.ValueType # 9
@@ -78,11 +80,35 @@ class _PrimitiveTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._
supported. NaN is represented when the exponent and mantissa bits are all
1s. All other values are finite.

F8E4M3B11FNUZ has 4 exponent bits and 3 mantissa bits and a bias of 11. The
"FNUZ" means only Finite and NaN values are supported; zero is unsigned.
Unlike IEEE types, infinities are not supported. NaN is represented when
the exponent and mantissa bits are all 0s with a sign bit of 1. All other
values are finite.

Support for these dtypes is under development. They do not yet work
properly in most cases.
TODO(b/259609697): Fully support FP8.
"""
F8E4M3FN: _PrimitiveType.ValueType # 20
F8E4M3B11FNUZ: _PrimitiveType.ValueType # 23
F8E5M2FNUZ: _PrimitiveType.ValueType # 24
"""FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915

F8E5M2FNUZ has 5 exponent bits and 2 mantissa bits.
F8E4M3FNUZ has 4 exponent bits and 3 mantissa bits.

The "FNUZ" means only Finite and NaN values are supported; zero is
unsigned. Unlike IEEE types, infinities are not supported. NaN is
represented when the exponent and mantissa bits are all 0s with a sign bit
of 1. All other values are finite.

These differences mean there's an additional exponent value available. To
keep the same dynamic range as an IEEE-like FP8 type, the exponent is
biased one more than would be expected given the number of exponent bits
(8 for Float8E4M3FNUZ and 16 for Float8E5M2FNUZ).
"""
F8E4M3FNUZ: _PrimitiveType.ValueType # 25
C64: _PrimitiveType.ValueType # 15
"""Complex values of fixed width.
Paired F32 (real, imag), as in std::complex<float>.
@@ -123,13 +149,15 @@ PRIMITIVE_TYPE_INVALID: PrimitiveType.ValueType # 0
"""Invalid primitive type to serve as default."""
PRED: PrimitiveType.ValueType # 1
"""Predicates are two-state booleans."""
S8: PrimitiveType.ValueType # 2
S4: PrimitiveType.ValueType # 21
"""Signed integral values of fixed width."""
S8: PrimitiveType.ValueType # 2
S16: PrimitiveType.ValueType # 3
S32: PrimitiveType.ValueType # 4
S64: PrimitiveType.ValueType # 5
U8: PrimitiveType.ValueType # 6
U4: PrimitiveType.ValueType # 22
"""Unsigned integral values of fixed width."""
U8: PrimitiveType.ValueType # 6
U16: PrimitiveType.ValueType # 7
U32: PrimitiveType.ValueType # 8
U64: PrimitiveType.ValueType # 9
@@ -157,11 +185,35 @@ Finite and NaN values are supported. Unlike IEEE types, infinities are not
supported. NaN is represented when the exponent and mantissa bits are all
1s. All other values are finite.

F8E4M3B11FNUZ has 4 exponent bits and 3 mantissa bits and a bias of 11. The
"FNUZ" means only Finite and NaN values are supported; zero is unsigned.
Unlike IEEE types, infinities are not supported. NaN is represented when
the exponent and mantissa bits are all 0s with a sign bit of 1. All other
values are finite.

Support for these dtypes is under development. They do not yet work
properly in most cases.
TODO(b/259609697): Fully support FP8.
"""
F8E4M3FN: PrimitiveType.ValueType # 20
F8E4M3B11FNUZ: PrimitiveType.ValueType # 23
F8E5M2FNUZ: PrimitiveType.ValueType # 24
"""FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915

F8E5M2FNUZ has 5 exponent bits and 2 mantissa bits.
F8E4M3FNUZ has 4 exponent bits and 3 mantissa bits.

The "FNUZ" means only Finite and NaN values are supported; zero is
unsigned. Unlike IEEE types, infinities are not supported. NaN is
represented when the exponent and mantissa bits are all 0s with a sign bit
of 1. All other values are finite.

These differences mean there's an additional exponent value available. To
keep the same dynamic range as an IEEE-like FP8 type, the exponent is
biased one more than would be expected given the number of exponent bits
(8 for Float8E4M3FNUZ and 16 for Float8E5M2FNUZ).
"""
F8E4M3FNUZ: PrimitiveType.ValueType # 25
C64: PrimitiveType.ValueType # 15
"""Complex values of fixed width.
Paired F32 (real, imag), as in std::complex<float>.
@@ -205,6 +257,11 @@ class _DimLevelTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._E
"""The corresponding dimension contains a single coordinate, no sibling
elements for each parent.
"""
DIM_LOOSE_COMPRESSED: _DimLevelType.ValueType # 3
"""The corresponding dimension is Compressed, but with potential trailing
zeros, thus an extra upper bound (high) is used to exclude those zeros.
E.g., indices = [1, 2, 0, 0, 3, 4, 0, 0], position = [(0, 2), (4, 6)].
"""

class DimLevelType(_DimLevelType, metaclass=_DimLevelTypeEnumTypeWrapper):
"""A DimLevelType indicates the encoding method for a dimension in an array.
@@ -222,6 +279,11 @@ DIM_SINGLETON: DimLevelType.ValueType # 2
"""The corresponding dimension contains a single coordinate, no sibling
elements for each parent.
"""
DIM_LOOSE_COMPRESSED: DimLevelType.ValueType # 3
"""The corresponding dimension is Compressed, but with potential trailing
zeros, thus an extra upper bound (high) is used to exclude those zeros.
E.g., indices = [1, 2, 0, 0, 3, 4, 0, 0], position = [(0, 2), (4, 6)].
"""
global___DimLevelType = DimLevelType

class _ProfileType:
@@ -328,6 +390,23 @@ IRFFT: FftType.ValueType # 3
"""Inverse real FFT; fft_length / 2 + 1 complex in,"""
global___FftType = FftType

class _SparsityType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType

class _SparsityTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_SparsityType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
SPARSITY_INVALID: _SparsityType.ValueType # 0
SPARSITY_STRUCTURED_N_M: _SparsityType.ValueType # 1
"""Structured N:M sparsity."""

class SparsityType(_SparsityType, metaclass=_SparsityTypeEnumTypeWrapper): ...

SPARSITY_INVALID: SparsityType.ValueType # 0
SPARSITY_STRUCTURED_N_M: SparsityType.ValueType # 1
"""Structured N:M sparsity."""
global___SparsityType = SparsityType

class _RandomDistribution:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
@@ -474,11 +553,26 @@ class LayoutProto(google.protobuf.message.Message):
DIM_ORDERED_FIELD_NUMBER: builtins.int
MINOR_TO_MAJOR_FIELD_NUMBER: builtins.int
TILES_FIELD_NUMBER: builtins.int
TAIL_PADDING_ALIGNMENT_IN_ELEMENTS_FIELD_NUMBER: builtins.int
ELEMENT_SIZE_IN_BITS_FIELD_NUMBER: builtins.int
MEMORY_SPACE_FIELD_NUMBER: builtins.int
INDEX_PRIMITIVE_TYPE_FIELD_NUMBER: builtins.int
POINTER_PRIMITIVE_TYPE_FIELD_NUMBER: builtins.int
PHYSICAL_SHAPE_FIELD_NUMBER: builtins.int
DYNAMIC_SHAPE_METADATA_PREFIX_BYTES_FIELD_NUMBER: builtins.int
tail_padding_alignment_in_elements: builtins.int
"""The shape is padded at the end to a multiple of this value, in terms of the
number of elements. This is useful when tiling does not bring the shape to certain
desired granules. Tiling effectively pads/reshapes/transposes the shape
to another shape. This field pads the total number of elements of that
new shape to a multiple of a certain number of elements. This is useful when,
for example, we want a layout which does not tile the data but still requires it
to be padded to a certain number of elements.
"""
element_size_in_bits: builtins.int
"""(Optional) Bit size of each element. When unspecified or being 0, default
to ShapeUtil::ByteSizeOfPrimitiveType.
"""
memory_space: builtins.int
"""Memory space where this array resides. The integer field is interpreted in
a backend-specific manner.
@@ -546,6 +640,8 @@ class LayoutProto(google.protobuf.message.Message):
dim_ordered: collections.abc.Iterable[builtins.bool] | None = ...,
minor_to_major: collections.abc.Iterable[builtins.int] | None = ...,
tiles: collections.abc.Iterable[global___TileProto] | None = ...,
tail_padding_alignment_in_elements: builtins.int | None = ...,
element_size_in_bits: builtins.int | None = ...,
memory_space: builtins.int | None = ...,
index_primitive_type: global___PrimitiveType.ValueType | None = ...,
pointer_primitive_type: global___PrimitiveType.ValueType | None = ...,
@@ -553,7 +649,7 @@ ...
dynamic_shape_metadata_prefix_bytes: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["physical_shape", b"physical_shape"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["dim_level_types", b"dim_level_types", "dim_ordered", b"dim_ordered", "dim_unique", b"dim_unique", "dynamic_shape_metadata_prefix_bytes", b"dynamic_shape_metadata_prefix_bytes", "index_primitive_type", b"index_primitive_type", "memory_space", b"memory_space", "minor_to_major", b"minor_to_major", "physical_shape", b"physical_shape", "pointer_primitive_type", b"pointer_primitive_type", "tiles", b"tiles"]) -> None: ...
def ClearField(self, field_name: typing.Literal["dim_level_types", b"dim_level_types", "dim_ordered", b"dim_ordered", "dim_unique", b"dim_unique", "dynamic_shape_metadata_prefix_bytes", b"dynamic_shape_metadata_prefix_bytes", "element_size_in_bits", b"element_size_in_bits", "index_primitive_type", b"index_primitive_type", "memory_space", b"memory_space", "minor_to_major", b"minor_to_major", "physical_shape", b"physical_shape", "pointer_primitive_type", b"pointer_primitive_type", "tail_padding_alignment_in_elements", b"tail_padding_alignment_in_elements", "tiles", b"tiles"]) -> None: ...

global___LayoutProto = LayoutProto

@@ -724,6 +820,9 @@ class OpMetadata(google.protobuf.message.Message):
SIZE_OF_GENERATED_CODE_IN_BYTES_FIELD_NUMBER: builtins.int
SIZE_OF_MEMORY_WORKING_SET_IN_BYTES_FIELD_NUMBER: builtins.int
PROFILE_INFO_FIELD_NUMBER: builtins.int
DEDUPLICATED_NAME_FIELD_NUMBER: builtins.int
PRESERVE_LAYOUT_FIELD_NUMBER: builtins.int
STACK_FRAME_ID_FIELD_NUMBER: builtins.int
op_type: builtins.str
"""The framework op name that generated this XLA op.

@@ -762,6 +861,20 @@ ...
"""The size of the working set, i.e., the amount of memory, used by the
instruction in a compiler-managed fast device memory.
"""
deduplicated_name: builtins.str
"""Deduplicated HLO name for this op. In some cases, we can have multiple
instructions (e.g. fusions) that are considered duplicates. We want to
group them together under the same name so that we can group them together
during analysis (e.g. HLO Op Profile tool in Xprof).
E.g. If we have fusion.1, fusion.2, and fusion.3 marked as duplicates,
fusion.2 and fusion.3 will have deduplicated_name = fusion.1
"""
preserve_layout: builtins.bool
"""Whether to preserve the layout of the HLO op."""
stack_frame_id: builtins.int
"""1-based position of the frame in frames flat array.
Ids are 1-based to keep 0 value as representation of non-set property.
"""
@property
def profile_type(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[global___ProfileType.ValueType]:
"""Deprecated, use [ProfileInfo][profile_type] instead."""
@@ -783,9 +896,12 @@ ...
size_of_generated_code_in_bytes: builtins.int | None = ...,
size_of_memory_working_set_in_bytes: builtins.int | None = ...,
profile_info: global___OpMetadata.ProfileInfo | None = ...,
deduplicated_name: builtins.str | None = ...,
preserve_layout: builtins.bool | None = ...,
stack_frame_id: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["profile_info", b"profile_info"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["creation_pass_id", b"creation_pass_id", "logical_creation_pass_id", b"logical_creation_pass_id", "op_name", b"op_name", "op_type", b"op_type", "profile_info", b"profile_info", "profile_type", b"profile_type", "size_of_generated_code_in_bytes", b"size_of_generated_code_in_bytes", "size_of_memory_working_set_in_bytes", b"size_of_memory_working_set_in_bytes", "source_file", b"source_file", "source_line", b"source_line"]) -> None: ...
def ClearField(self, field_name: typing.Literal["creation_pass_id", b"creation_pass_id", "deduplicated_name", b"deduplicated_name", "logical_creation_pass_id", b"logical_creation_pass_id", "op_name", b"op_name", "op_type", b"op_type", "preserve_layout", b"preserve_layout", "profile_info", b"profile_info", "profile_type", b"profile_type", "size_of_generated_code_in_bytes", b"size_of_generated_code_in_bytes", "size_of_memory_working_set_in_bytes", b"size_of_memory_working_set_in_bytes", "source_file", b"source_file", "source_line", b"source_line", "stack_frame_id", b"stack_frame_id"]) -> None: ...

global___OpMetadata = OpMetadata

@@ -1023,6 +1139,8 @@ class LiteralProto(google.protobuf.message.Message):

SHAPE_FIELD_NUMBER: builtins.int
PREDS_FIELD_NUMBER: builtins.int
S4S_FIELD_NUMBER: builtins.int
U4S_FIELD_NUMBER: builtins.int
S8S_FIELD_NUMBER: builtins.int
U8S_FIELD_NUMBER: builtins.int
S32S_FIELD_NUMBER: builtins.int
@@ -1040,7 +1158,12 @@ class LiteralProto(google.protobuf.message.Message):
S16S_FIELD_NUMBER: builtins.int
F8E5M2S_FIELD_NUMBER: builtins.int
F8E4M3FNS_FIELD_NUMBER: builtins.int
F8E4M3B11FNUZS_FIELD_NUMBER: builtins.int
F8E5M2FNUZS_FIELD_NUMBER: builtins.int
F8E4M3FNUZS_FIELD_NUMBER: builtins.int
SPARSE_INDICES_FIELD_NUMBER: builtins.int
s4s: builtins.bytes
u4s: builtins.bytes
s8s: builtins.bytes
u8s: builtins.bytes
f16s: builtins.bytes
@@ -1050,6 +1173,9 @@ class LiteralProto(google.protobuf.message.Message):
s16s: builtins.bytes
f8e5m2s: builtins.bytes
f8e4m3fns: builtins.bytes
f8e4m3b11fnuzs: builtins.bytes
f8e5m2fnuzs: builtins.bytes
f8e4m3fnuzs: builtins.bytes
@property
def shape(self) -> global___ShapeProto: ...
@property
@@ -1078,13 +1204,15 @@ class LiteralProto(google.protobuf.message.Message):
def tuple_literals(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___LiteralProto]: ...
@property
def sparse_indices(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""Next = 21"""
"""Next = 26"""

def __init__(
self,
*,
shape: global___ShapeProto | None = ...,
preds: collections.abc.Iterable[builtins.bool] | None = ...,
s4s: builtins.bytes | None = ...,
u4s: builtins.bytes | None = ...,
s8s: builtins.bytes | None = ...,
u8s: builtins.bytes | None = ...,
s32s: collections.abc.Iterable[builtins.int] | None = ...,
@@ -1102,10 +1230,13 @@ class LiteralProto(google.protobuf.message.Message):
s16s: builtins.bytes | None = ...,
f8e5m2s: builtins.bytes | None = ...,
f8e4m3fns: builtins.bytes | None = ...,
f8e4m3b11fnuzs: builtins.bytes | None = ...,
f8e5m2fnuzs: builtins.bytes | None = ...,
f8e4m3fnuzs: builtins.bytes | None = ...,
sparse_indices: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["shape", b"shape"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["bf16s", b"bf16s", "c128s", b"c128s", "c64s", b"c64s", "f16s", b"f16s", "f32s", b"f32s", "f64s", b"f64s", "f8e4m3fns", b"f8e4m3fns", "f8e5m2s", b"f8e5m2s", "preds", b"preds", "s16s", b"s16s", "s32s", b"s32s", "s64s", b"s64s", "s8s", b"s8s", "shape", b"shape", "sparse_indices", b"sparse_indices", "tuple_literals", b"tuple_literals", "u16s", b"u16s", "u32s", b"u32s", "u64s", b"u64s", "u8s", b"u8s"]) -> None: ...
def ClearField(self, field_name: typing.Literal["bf16s", b"bf16s", "c128s", b"c128s", "c64s", b"c64s", "f16s", b"f16s", "f32s", b"f32s", "f64s", b"f64s", "f8e4m3b11fnuzs", b"f8e4m3b11fnuzs", "f8e4m3fns", b"f8e4m3fns", "f8e4m3fnuzs", b"f8e4m3fnuzs", "f8e5m2fnuzs", b"f8e5m2fnuzs", "f8e5m2s", b"f8e5m2s", "preds", b"preds", "s16s", b"s16s", "s32s", b"s32s", "s4s", b"s4s", "s64s", b"s64s", "s8s", b"s8s", "shape", b"shape", "sparse_indices", b"sparse_indices", "tuple_literals", b"tuple_literals", "u16s", b"u16s", "u32s", b"u32s", "u4s", b"u4s", "u64s", b"u64s", "u8s", b"u8s"]) -> None: ...

global___LiteralProto = LiteralProto

@@ -1392,6 +1523,44 @@ class DotDimensionNumbers(google.protobuf.message.Message):

global___DotDimensionNumbers = DotDimensionNumbers

@typing.final
class SparsityDescriptor(google.protobuf.message.Message):
"""Contains sparsity metadata for a sparse dot operation.
The only supported type at the moment is structured 2:4 sparsity, which is
natively supported on NVidia GPUs.
Restrictions:
- only one operand of the dot operation may be sparse;
- only the contracting dimension may be sparse.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

TYPE_FIELD_NUMBER: builtins.int
INDEX_FIELD_NUMBER: builtins.int
DIMENSION_FIELD_NUMBER: builtins.int
N_FIELD_NUMBER: builtins.int
M_FIELD_NUMBER: builtins.int
type: global___SparsityType.ValueType
index: builtins.int
"""Sparse operand index (0 or 1)."""
dimension: builtins.int
"""Sparse dimension number."""
n: builtins.int
"""Structured N:M sparsity (N < M)."""
m: builtins.int
def __init__(
self,
*,
type: global___SparsityType.ValueType | None = ...,
index: builtins.int | None = ...,
dimension: builtins.int | None = ...,
n: builtins.int | None = ...,
m: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["dimension", b"dimension", "index", b"index", "m", b"m", "n", b"n", "type", b"type"]) -> None: ...

global___SparsityDescriptor = SparsityDescriptor
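
For concreteness, a sketch of how the one configuration the docstring says is supported (2:4 structured sparsity on one dot operand) would be encoded; the dimension value is illustrative, since it must match the dot's contracting dimension:

    desc = SparsityDescriptor(
        type=SPARSITY_STRUCTURED_N_M,  # the only non-invalid SparsityType above
        index=0,       # operand 0 of the dot is the sparse one
        dimension=1,   # illustrative: must be the contracting dimension
        n=2,
        m=4,           # 2:4 means 2 non-zero values in every group of 4
    )
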

@typing.final
class TriangularSolveOptions(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -1462,6 +1631,23 @@ class CholeskyOptions(google.protobuf.message.Message):

global___CholeskyOptions = CholeskyOptions

@typing.final
class SortOptions(google.protobuf.message.Message):
"""Attributes of the sort custom call (cub::DeviceRadixSort)."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

DESCENDING_FIELD_NUMBER: builtins.int
descending: builtins.bool
def __init__(
self,
*,
descending: builtins.bool | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["descending", b"descending"]) -> None: ...

global___SortOptions = SortOptions

@typing.final
class FrontendAttributes(google.protobuf.message.Message):
"""Generic map of attributes used to pass hints / configuration options from
@@ -1498,6 +1684,54 @@ class FrontendAttributes(google.protobuf.message.Message):

global___FrontendAttributes = FrontendAttributes

@typing.final
class Statistic(google.protobuf.message.Message):
"""Represents a single statistic to track."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

STAT_NAME_FIELD_NUMBER: builtins.int
STAT_VAL_FIELD_NUMBER: builtins.int
stat_name: builtins.str
"""Must be a single word consisting of any alphanumeric characters"""
stat_val: builtins.float
"""Must be within a range of [0, 100], in order for the graph dumper to
properly render the statistic onto the graph.
"""
def __init__(
self,
*,
stat_name: builtins.str | None = ...,
stat_val: builtins.float | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["stat_name", b"stat_name", "stat_val", b"stat_val"]) -> None: ...

global___Statistic = Statistic

@typing.final
class StatisticsViz(google.protobuf.message.Message):
"""Represents the information needed to visualize propagation statistics when
rendering an HLO graph. This includes an array of statistics as well as the
index of the statistic to render.
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

STAT_INDEX_TO_VISUALIZE_FIELD_NUMBER: builtins.int
STATISTICS_FIELD_NUMBER: builtins.int
stat_index_to_visualize: builtins.int
@property
def statistics(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Statistic]: ...
def __init__(
self,
*,
stat_index_to_visualize: builtins.int | None = ...,
statistics: collections.abc.Iterable[global___Statistic] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["stat_index_to_visualize", b"stat_index_to_visualize", "statistics", b"statistics"]) -> None: ...

global___StatisticsViz = StatisticsViz

@typing.final
class OpSharding(google.protobuf.message.Message):
"""LINT.IfChange"""
@@ -1524,6 +1758,10 @@ class OpSharding(google.protobuf.message.Message):
"""This op is manually sharded: the shapes are already partitioned and the
partitioner should not change this op.
"""
UNKNOWN: OpSharding._Type.ValueType # 5
"""This sharding is a placeholder sharding with lowest precedence; it can be
overwritten by any other shardings.
"""

class Type(_Type, metaclass=_TypeEnumTypeWrapper): ...
REPLICATED: OpSharding.Type.ValueType # 0
@@ -1540,6 +1778,43 @@ ...
"""This op is manually sharded: the shapes are already partitioned and the
partitioner should not change this op.
"""
UNKNOWN: OpSharding.Type.ValueType # 5
"""This sharding is a placeholder sharding with lowest precedence; it can be
overwritten by any other shardings.
"""

class _ShardGroupType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType

class _ShardGroupTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[OpSharding._ShardGroupType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
AS: OpSharding._ShardGroupType.ValueType # 0
"""This op will be sharded exactly the same as the other op. (hard
restriction)
"""
LIKE: OpSharding._ShardGroupType.ValueType # 1
"""This op will try to allow sharding propagation within the same group even
if there are no data dependencies among them, but there is no guarantee that
the final shardings within the same group will be exactly the same. (soft
restriction)
"""

class ShardGroupType(_ShardGroupType, metaclass=_ShardGroupTypeEnumTypeWrapper):
"""Used to decide whether this op is to be sharded like some other ops, or to
which other ops will be sharded like.
"""

AS: OpSharding.ShardGroupType.ValueType # 0
"""This op will be sharded exactly the same as the other op. (hard
restriction)
"""
LIKE: OpSharding.ShardGroupType.ValueType # 1
"""This op will try to allow sharding propagation within the same group even
if there are no data dependencies among them, but there is no guarantee that
the final shardings within the same group will be exactly the same. (soft
restriction)
"""

TYPE_FIELD_NUMBER: builtins.int
TILE_SHAPE_FIELD_NUMBER: builtins.int
@@ -1549,12 +1824,22 @@ ...
REPLICATE_ON_LAST_TILE_DIM_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
LAST_TILE_DIMS_FIELD_NUMBER: builtins.int
IOTA_RESHAPE_DIMS_FIELD_NUMBER: builtins.int
IOTA_TRANSPOSE_PERM_FIELD_NUMBER: builtins.int
IS_SHARD_GROUP_FIELD_NUMBER: builtins.int
SHARD_GROUP_ID_FIELD_NUMBER: builtins.int
SHARD_GROUP_TYPE_FIELD_NUMBER: builtins.int
type: global___OpSharding.Type.ValueType
replicate_on_last_tile_dim: builtins.bool
"""Only used for OTHER type. If true, data is sharded according to other
dimensions of tile_assignment(), but replicated across devices along the
last dimension. (Experimental)
"""
is_shard_group: builtins.bool
"""This field decides whether this op is in a shard group."""
shard_group_id: builtins.int
"""This field is used to store the unique id of the shard group."""
shard_group_type: global___OpSharding.ShardGroupType.ValueType
@property
def tile_shape(self) -> global___ShapeProto:
"""The shape of the sharded tile."""
@@ -1570,6 +1855,7 @@ ...
def tile_assignment_devices(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""Flattened list of device IDs. The order of flattening is the same as used
by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
Only one of tile_assignment_devices and iota_dimensions shall be non-empty.
"""

@property
@@ -1600,6 +1886,19 @@ ...
unreduced sharding type respectively.
"""

@property
def iota_reshape_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""Dimensions used to reshape the 1D iota array of device IDs.
Only one of tile_assignment_devices and iota_reshape_dims shall be
non-empty.
"""

@property
def iota_transpose_perm(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
"""Dimension permutations to transpose the iota array reshaped to
iota_reshape_dims. This must have the same size as iota_reshape_dims.
"""

def __init__(
self,
*,
@@ -1611,9 +1910,14 @@ ...
replicate_on_last_tile_dim: builtins.bool | None = ...,
metadata: collections.abc.Iterable[global___OpMetadata] | None = ...,
last_tile_dims: collections.abc.Iterable[global___OpSharding.Type.ValueType] | None = ...,
iota_reshape_dims: collections.abc.Iterable[builtins.int] | None = ...,
iota_transpose_perm: collections.abc.Iterable[builtins.int] | None = ...,
is_shard_group: builtins.bool | None = ...,
shard_group_id: builtins.int | None = ...,
shard_group_type: global___OpSharding.ShardGroupType.ValueType | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["tile_shape", b"tile_shape"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["last_tile_dims", b"last_tile_dims", "metadata", b"metadata", "replicate_on_last_tile_dim", b"replicate_on_last_tile_dim", "tile_assignment_devices", b"tile_assignment_devices", "tile_assignment_dimensions", b"tile_assignment_dimensions", "tile_shape", b"tile_shape", "tuple_shardings", b"tuple_shardings", "type", b"type"]) -> None: ...
def ClearField(self, field_name: typing.Literal["iota_reshape_dims", b"iota_reshape_dims", "iota_transpose_perm", b"iota_transpose_perm", "is_shard_group", b"is_shard_group", "last_tile_dims", b"last_tile_dims", "metadata", b"metadata", "replicate_on_last_tile_dim", b"replicate_on_last_tile_dim", "shard_group_id", b"shard_group_id", "shard_group_type", b"shard_group_type", "tile_assignment_devices", b"tile_assignment_devices", "tile_assignment_dimensions", b"tile_assignment_dimensions", "tile_shape", b"tile_shape", "tuple_shardings", b"tuple_shardings", "type", b"type"]) -> None: ...

global___OpSharding = OpSharding
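
The iota_reshape_dims / iota_transpose_perm pair added in this hunk is the compact way to describe a tile assignment without materializing every device ID; per the docstrings above, it is mutually exclusive with tile_assignment_devices. A hedged sketch of the two equivalent encodings of a 2x4 assignment over devices 0..7 (field values are illustrative):

    explicit = OpSharding(
        type=OpSharding.OTHER,
        tile_assignment_dimensions=[2, 4],
        tile_assignment_devices=[0, 1, 2, 3, 4, 5, 6, 7],  # flattened device grid
    )
    iota_form = OpSharding(
        type=OpSharding.OTHER,
        tile_assignment_dimensions=[2, 4],
        iota_reshape_dims=[2, 4],    # reshape iota(8) into the device grid
        iota_transpose_perm=[0, 1],  # identity permutation; must match len(iota_reshape_dims)
    )
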

2245 stubs/tensorflow/tensorflow/compiler/xla/xla_pb2.pyi (new file)
File diff suppressed because one or more lines are too long
@@ -0,0 +1,99 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""

import builtins
import collections.abc
import typing

import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import tensorflow.core.framework.full_type_pb2
import tensorflow.core.framework.tensor_shape_pb2
import tensorflow.core.framework.types_pb2

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

@typing.final
class CppShapeInferenceResult(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

@typing.final
class HandleShapeAndType(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

SHAPE_FIELD_NUMBER: builtins.int
DTYPE_FIELD_NUMBER: builtins.int
TYPE_FIELD_NUMBER: builtins.int
dtype: tensorflow.core.framework.types_pb2.DataType.ValueType
@property
def shape(self) -> tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto: ...
@property
def type(self) -> tensorflow.core.framework.full_type_pb2.FullTypeDef: ...
def __init__(
self,
*,
shape: tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto | None = ...,
dtype: tensorflow.core.framework.types_pb2.DataType.ValueType | None = ...,
type: tensorflow.core.framework.full_type_pb2.FullTypeDef | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["shape", b"shape", "type", b"type"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["dtype", b"dtype", "shape", b"shape", "type", b"type"]) -> None: ...

@typing.final
class HandleData(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

IS_SET_FIELD_NUMBER: builtins.int
SHAPE_AND_TYPE_FIELD_NUMBER: builtins.int
is_set: builtins.bool
@property
def shape_and_type(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___CppShapeInferenceResult.HandleShapeAndType]:
"""Only valid if <is_set>."""

def __init__(
self,
*,
is_set: builtins.bool | None = ...,
shape_and_type: collections.abc.Iterable[global___CppShapeInferenceResult.HandleShapeAndType] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["is_set", b"is_set", "shape_and_type", b"shape_and_type"]) -> None: ...

SHAPE_FIELD_NUMBER: builtins.int
HANDLE_DATA_FIELD_NUMBER: builtins.int
@property
def shape(self) -> tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto: ...
@property
def handle_data(self) -> global___CppShapeInferenceResult.HandleData: ...
def __init__(
self,
*,
shape: tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto | None = ...,
handle_data: global___CppShapeInferenceResult.HandleData | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["handle_data", b"handle_data", "shape", b"shape"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["handle_data", b"handle_data", "shape", b"shape"]) -> None: ...

global___CppShapeInferenceResult = CppShapeInferenceResult

@typing.final
class CppShapeInferenceInputsNeeded(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

INPUT_TENSORS_NEEDED_FIELD_NUMBER: builtins.int
INPUT_TENSORS_AS_SHAPES_NEEDED_FIELD_NUMBER: builtins.int
@property
def input_tensors_needed(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
@property
def input_tensors_as_shapes_needed(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
def __init__(
self,
*,
input_tensors_needed: collections.abc.Iterable[builtins.int] | None = ...,
input_tensors_as_shapes_needed: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["input_tensors_as_shapes_needed", b"input_tensors_as_shapes_needed", "input_tensors_needed", b"input_tensors_needed"]) -> None: ...

global___CppShapeInferenceInputsNeeded = CppShapeInferenceInputsNeeded
@@ -4,10 +4,12 @@ isort:skip_file
"""

import builtins
import collections.abc
import sys
import typing

import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import tensorflow.core.framework.model_pb2
@@ -198,7 +200,7 @@ global___DistributeOptions = DistributeOptions

@typing.final
class OptimizationOptions(google.protobuf.message.Message):
"""next: 20"""
"""next: 22"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

@@ -213,6 +215,7 @@ class OptimizationOptions(google.protobuf.message.Message):
SHUFFLE_AND_REPEAT_FUSION_FIELD_NUMBER: builtins.int
FILTER_PARALLELIZATION_FIELD_NUMBER: builtins.int
INJECT_PREFETCH_FIELD_NUMBER: builtins.int
SEQ_INTERLEAVE_PREFETCH_FIELD_NUMBER: builtins.int
apply_default_optimizations: builtins.bool
filter_fusion: builtins.bool
map_and_batch_fusion: builtins.bool
@@ -224,6 +227,7 @@ ...
shuffle_and_repeat_fusion: builtins.bool
filter_parallelization: builtins.bool
inject_prefetch: builtins.bool
seq_interleave_prefetch: builtins.bool
def __init__(
self,
*,
@@ -238,9 +242,10 @@ ...
shuffle_and_repeat_fusion: builtins.bool | None = ...,
filter_parallelization: builtins.bool | None = ...,
inject_prefetch: builtins.bool | None = ...,
seq_interleave_prefetch: builtins.bool | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["apply_default_optimizations", b"apply_default_optimizations", "filter_fusion", b"filter_fusion", "filter_parallelization", b"filter_parallelization", "inject_prefetch", b"inject_prefetch", "map_and_batch_fusion", b"map_and_batch_fusion", "map_and_filter_fusion", b"map_and_filter_fusion", "map_fusion", b"map_fusion", "map_parallelization", b"map_parallelization", "noop_elimination", b"noop_elimination", "optional_apply_default_optimizations", b"optional_apply_default_optimizations", "optional_filter_fusion", b"optional_filter_fusion", "optional_filter_parallelization", b"optional_filter_parallelization", "optional_inject_prefetch", b"optional_inject_prefetch", "optional_map_and_batch_fusion", b"optional_map_and_batch_fusion", "optional_map_and_filter_fusion", b"optional_map_and_filter_fusion", "optional_map_fusion", b"optional_map_fusion", "optional_map_parallelization", b"optional_map_parallelization", "optional_noop_elimination", b"optional_noop_elimination", "optional_parallel_batch", b"optional_parallel_batch", "optional_shuffle_and_repeat_fusion", b"optional_shuffle_and_repeat_fusion", "parallel_batch", b"parallel_batch", "shuffle_and_repeat_fusion", b"shuffle_and_repeat_fusion"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["apply_default_optimizations", b"apply_default_optimizations", "filter_fusion", b"filter_fusion", "filter_parallelization", b"filter_parallelization", "inject_prefetch", b"inject_prefetch", "map_and_batch_fusion", b"map_and_batch_fusion", "map_and_filter_fusion", b"map_and_filter_fusion", "map_fusion", b"map_fusion", "map_parallelization", b"map_parallelization", "noop_elimination", b"noop_elimination", "optional_apply_default_optimizations", b"optional_apply_default_optimizations", "optional_filter_fusion", b"optional_filter_fusion", "optional_filter_parallelization", b"optional_filter_parallelization", "optional_inject_prefetch", b"optional_inject_prefetch", "optional_map_and_batch_fusion", b"optional_map_and_batch_fusion", "optional_map_and_filter_fusion", b"optional_map_and_filter_fusion", "optional_map_fusion", b"optional_map_fusion", "optional_map_parallelization", b"optional_map_parallelization", "optional_noop_elimination", b"optional_noop_elimination", "optional_parallel_batch", b"optional_parallel_batch", "optional_shuffle_and_repeat_fusion", b"optional_shuffle_and_repeat_fusion", "parallel_batch", b"parallel_batch", "shuffle_and_repeat_fusion", b"shuffle_and_repeat_fusion"]) -> None: ...
def HasField(self, field_name: typing.Literal["apply_default_optimizations", b"apply_default_optimizations", "filter_fusion", b"filter_fusion", "filter_parallelization", b"filter_parallelization", "inject_prefetch", b"inject_prefetch", "map_and_batch_fusion", b"map_and_batch_fusion", "map_and_filter_fusion", b"map_and_filter_fusion", "map_fusion", b"map_fusion", "map_parallelization", b"map_parallelization", "noop_elimination", b"noop_elimination", "optional_apply_default_optimizations", b"optional_apply_default_optimizations", "optional_filter_fusion", b"optional_filter_fusion", "optional_filter_parallelization", b"optional_filter_parallelization", "optional_inject_prefetch", b"optional_inject_prefetch", "optional_map_and_batch_fusion", b"optional_map_and_batch_fusion", "optional_map_and_filter_fusion", b"optional_map_and_filter_fusion", "optional_map_fusion", b"optional_map_fusion", "optional_map_parallelization", b"optional_map_parallelization", "optional_noop_elimination", b"optional_noop_elimination", "optional_parallel_batch", b"optional_parallel_batch", "optional_seq_interleave_prefetch", b"optional_seq_interleave_prefetch", "optional_shuffle_and_repeat_fusion", b"optional_shuffle_and_repeat_fusion", "parallel_batch", b"parallel_batch", "seq_interleave_prefetch", b"seq_interleave_prefetch", "shuffle_and_repeat_fusion", b"shuffle_and_repeat_fusion"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["apply_default_optimizations", b"apply_default_optimizations", "filter_fusion", b"filter_fusion", "filter_parallelization", b"filter_parallelization", "inject_prefetch", b"inject_prefetch", "map_and_batch_fusion", b"map_and_batch_fusion", "map_and_filter_fusion", b"map_and_filter_fusion", "map_fusion", b"map_fusion", "map_parallelization", b"map_parallelization", "noop_elimination", b"noop_elimination", "optional_apply_default_optimizations", b"optional_apply_default_optimizations", "optional_filter_fusion", b"optional_filter_fusion", "optional_filter_parallelization", b"optional_filter_parallelization", "optional_inject_prefetch", b"optional_inject_prefetch", "optional_map_and_batch_fusion", b"optional_map_and_batch_fusion", "optional_map_and_filter_fusion", b"optional_map_and_filter_fusion", "optional_map_fusion", b"optional_map_fusion", "optional_map_parallelization", b"optional_map_parallelization", "optional_noop_elimination", b"optional_noop_elimination", "optional_parallel_batch", b"optional_parallel_batch", "optional_seq_interleave_prefetch", b"optional_seq_interleave_prefetch", "optional_shuffle_and_repeat_fusion", b"optional_shuffle_and_repeat_fusion", "parallel_batch", b"parallel_batch", "seq_interleave_prefetch", b"seq_interleave_prefetch", "shuffle_and_repeat_fusion", b"shuffle_and_repeat_fusion"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_apply_default_optimizations", b"optional_apply_default_optimizations"]) -> typing.Literal["apply_default_optimizations"] | None: ...
@typing.overload
@@ -262,6 +267,8 @@ ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_parallel_batch", b"optional_parallel_batch"]) -> typing.Literal["parallel_batch"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_seq_interleave_prefetch", b"optional_seq_interleave_prefetch"]) -> typing.Literal["seq_interleave_prefetch"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_shuffle_and_repeat_fusion", b"optional_shuffle_and_repeat_fusion"]) -> typing.Literal["shuffle_and_repeat_fusion"] | None: ...

global___OptimizationOptions = OptimizationOptions
@@ -296,11 +303,13 @@ class Options(google.protobuf.message.Message):
"""Message stored with Dataset objects to control how datasets are processed and
optimized.

next: 9
next: 12
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

DATASET_NAME_FIELD_NUMBER: builtins.int
FRAMEWORK_TYPE_FIELD_NUMBER: builtins.int
DETERMINISTIC_FIELD_NUMBER: builtins.int
AUTOTUNE_OPTIONS_FIELD_NUMBER: builtins.int
DISTRIBUTE_OPTIONS_FIELD_NUMBER: builtins.int
@@ -309,13 +318,20 @@ ...
THREADING_OPTIONS_FIELD_NUMBER: builtins.int
EXTERNAL_STATE_POLICY_FIELD_NUMBER: builtins.int
SYMBOLIC_CHECKPOINT_FIELD_NUMBER: builtins.int
WARM_START_FIELD_NUMBER: builtins.int
dataset_name: builtins.str
deterministic: builtins.bool
slack: builtins.bool
external_state_policy: global___ExternalStatePolicy.ValueType
symbolic_checkpoint: builtins.bool
warm_start: builtins.bool
@property
def framework_type(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""List of frameworks used to generate this dataset."""

@property
def autotune_options(self) -> global___AutotuneOptions:
"""The distribution strategy options associated with the dataset."""
"""The autotune options associated with the dataset."""

@property
def distribute_options(self) -> global___DistributeOptions:
@@ -332,6 +348,8 @@ ...
def __init__(
self,
*,
dataset_name: builtins.str | None = ...,
framework_type: collections.abc.Iterable[builtins.str] | None = ...,
deterministic: builtins.bool | None = ...,
autotune_options: global___AutotuneOptions | None = ...,
distribute_options: global___DistributeOptions | None = ...,
@@ -340,9 +358,12 @@ ...
threading_options: global___ThreadingOptions | None = ...,
external_state_policy: global___ExternalStatePolicy.ValueType | None = ...,
symbolic_checkpoint: builtins.bool | None = ...,
warm_start: builtins.bool | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options"]) -> None: ...
def HasField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "dataset_name", b"dataset_name", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_dataset_name", b"optional_dataset_name", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "optional_warm_start", b"optional_warm_start", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options", "warm_start", b"warm_start"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["autotune_options", b"autotune_options", "dataset_name", b"dataset_name", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "framework_type", b"framework_type", "optimization_options", b"optimization_options", "optional_dataset_name", b"optional_dataset_name", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "optional_warm_start", b"optional_warm_start", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options", "warm_start", b"warm_start"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_dataset_name", b"optional_dataset_name"]) -> typing.Literal["dataset_name"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_deterministic", b"optional_deterministic"]) -> typing.Literal["deterministic"] | None: ...
@typing.overload
@@ -351,5 +372,7 @@ ...
def WhichOneof(self, oneof_group: typing.Literal["optional_slack", b"optional_slack"]) -> typing.Literal["slack"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_symbolic_checkpoint", b"optional_symbolic_checkpoint"]) -> typing.Literal["symbolic_checkpoint"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["optional_warm_start", b"optional_warm_start"]) -> typing.Literal["warm_start"] | None: ...

global___Options = Options
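
Each scalar option on this message sits inside an optional_* oneof wrapper, which is how the stubs distinguish "explicitly set to the default" from "never set"; WhichOneof is the presence test. A short sketch using the newly added warm_start field, assuming the module is importable as dataset_options_pb2 (the path is not shown here):

    opts = dataset_options_pb2.Options(warm_start=True)
    assert opts.WhichOneof("optional_warm_start") == "warm_start"  # field is present
    opts.ClearField("warm_start")
    assert opts.WhichOneof("optional_warm_start") is None  # back to unset
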

@@ -49,6 +49,7 @@ class GraphDebugInfo(google.protobuf.message.Message):
func: builtins.str | None = ...,
code: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["code", b"code", "col", b"col", "file_index", b"file_index", "func", b"func", "line", b"line"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["code", b"code", "col", b"col", "file_index", b"file_index", "func", b"func", "line", b"line"]) -> None: ...

@typing.final
@@ -58,16 +59,56 @@ class GraphDebugInfo(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

FILE_LINE_COLS_FIELD_NUMBER: builtins.int
FRAME_ID_FIELD_NUMBER: builtins.int
@property
def file_line_cols(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___GraphDebugInfo.FileLineCol]:
"""Each line in the stack trace."""
"""Deprecated."""

@property
def frame_id(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
def __init__(
self,
*,
file_line_cols: collections.abc.Iterable[global___GraphDebugInfo.FileLineCol] | None = ...,
frame_id: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["file_line_cols", b"file_line_cols"]) -> None: ...
def ClearField(self, field_name: typing.Literal["file_line_cols", b"file_line_cols", "frame_id", b"frame_id"]) -> None: ...

@typing.final
class FramesByIdEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.int
@property
def value(self) -> global___GraphDebugInfo.FileLineCol: ...
def __init__(
self,
*,
key: builtins.int | None = ...,
value: global___GraphDebugInfo.FileLineCol | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...

@typing.final
class TracesByIdEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.int
@property
def value(self) -> global___GraphDebugInfo.StackTrace: ...
def __init__(
self,
*,
key: builtins.int | None = ...,
value: global___GraphDebugInfo.StackTrace | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...

@typing.final
class TracesEntry(google.protobuf.message.Message):
@@ -84,24 +125,58 @@ class GraphDebugInfo(google.protobuf.message.Message):
key: builtins.str | None = ...,
value: global___GraphDebugInfo.StackTrace | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ...
def HasField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...

@typing.final
class NameToTraceIdEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
value: builtins.int
def __init__(
self,
*,
key: builtins.str | None = ...,
value: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...

FILES_FIELD_NUMBER: builtins.int
FRAMES_BY_ID_FIELD_NUMBER: builtins.int
TRACES_BY_ID_FIELD_NUMBER: builtins.int
TRACES_FIELD_NUMBER: builtins.int
NAME_TO_TRACE_ID_FIELD_NUMBER: builtins.int
@property
def files(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""This stores all the source code file names and can be indexed by the
`file_index`.
"""

@property
def frames_by_id(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, global___GraphDebugInfo.FileLineCol]:
"""Stack traces and frames are uniqueified during construction. These maps
index from the unique id for a frame/trace to the value.
"""

@property
def traces_by_id(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, global___GraphDebugInfo.StackTrace]: ...
@property
def traces(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___GraphDebugInfo.StackTrace]:
"""This maps a node name to a stack trace in the source code.
"""Deprecated."""

@property
def name_to_trace_id(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.int]:
"""This maps a node name to a trace id contained in `traces_by_id`.

The map key is a mangling of the containing function and op name with
syntax:
op.name '@' func_name
For ops in the top-level graph, the func_name is the empty string.
For ops in the top-level graph, the func_name is the empty string and hence
the `@` may be omitted.
Note that op names are restricted to a small number of characters which
exclude '@', making it impossible to collide keys of this form. Function
names accept a much wider set of characters.
@@ -113,8 +188,11 @@ class GraphDebugInfo(google.protobuf.message.Message):
self,
*,
files: collections.abc.Iterable[builtins.str] | None = ...,
frames_by_id: collections.abc.Mapping[builtins.int, global___GraphDebugInfo.FileLineCol] | None = ...,
traces_by_id: collections.abc.Mapping[builtins.int, global___GraphDebugInfo.StackTrace] | None = ...,
traces: collections.abc.Mapping[builtins.str, global___GraphDebugInfo.StackTrace] | None = ...,
name_to_trace_id: collections.abc.Mapping[builtins.str, builtins.int] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["files", b"files", "traces", b"traces"]) -> None: ...
def ClearField(self, field_name: typing.Literal["files", b"files", "frames_by_id", b"frames_by_id", "name_to_trace_id", b"name_to_trace_id", "traces", b"traces", "traces_by_id", b"traces_by_id"]) -> None: ...

global___GraphDebugInfo = GraphDebugInfo
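
For reference, a small sketch of how the new deduplicated representation fits together; this is a sketch only, with semantics taken from the docstrings above (the module path is confirmed by the imports later in this diff):

    # Hypothetical construction of the deduplicated GraphDebugInfo layout.
    from tensorflow.core.framework import graph_debug_info_pb2

    info = graph_debug_info_pb2.GraphDebugInfo()
    info.files.append("model.py")        # indexed by FileLineCol.file_index
    frame = info.frames_by_id[1]         # MessageMap access creates the entry
    frame.file_index = 0
    frame.line = 42
    trace = info.traces_by_id[7]         # unique trace id -> StackTrace
    trace.frame_id.append(1)             # trace is a list of frame ids
    info.name_to_trace_id["MatMul"] = 7  # key is op.name '@' func_name; '@' omitted at top level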
@@ -11,6 +11,7 @@ import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import tensorflow.core.framework.function_pb2
import tensorflow.core.framework.graph_debug_info_pb2
import tensorflow.core.framework.node_def_pb2
import tensorflow.core.framework.versions_pb2

@@ -26,6 +27,7 @@ class GraphDef(google.protobuf.message.Message):
VERSIONS_FIELD_NUMBER: builtins.int
VERSION_FIELD_NUMBER: builtins.int
LIBRARY_FIELD_NUMBER: builtins.int
DEBUG_INFO_FIELD_NUMBER: builtins.int
version: builtins.int
"""Deprecated single version field; use versions above instead. Since all
GraphDef changes before "versions" was introduced were forward
@@ -70,6 +72,10 @@ class GraphDef(google.protobuf.message.Message):
function are ready.
"""

@property
def debug_info(self) -> tensorflow.core.framework.graph_debug_info_pb2.GraphDebugInfo:
"""Stack traces for the nodes in this graph."""

def __init__(
self,
*,
@@ -77,8 +83,9 @@ class GraphDef(google.protobuf.message.Message):
versions: tensorflow.core.framework.versions_pb2.VersionDef | None = ...,
version: builtins.int | None = ...,
library: tensorflow.core.framework.function_pb2.FunctionDefLibrary | None = ...,
debug_info: tensorflow.core.framework.graph_debug_info_pb2.GraphDebugInfo | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["library", b"library", "versions", b"versions"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["library", b"library", "node", b"node", "version", b"version", "versions", b"versions"]) -> None: ...
def HasField(self, field_name: typing.Literal["debug_info", b"debug_info", "library", b"library", "versions", b"versions"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["debug_info", b"debug_info", "library", b"library", "node", b"node", "version", b"version", "versions", b"versions"]) -> None: ...

global___GraphDef = GraphDef
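
A quick sketch of the new singular `debug_info` field; this is generic generated-proto behavior, not anything TF-specific:

    from tensorflow.core.framework import graph_pb2

    g = graph_pb2.GraphDef()
    assert not g.HasField("debug_info")    # singular message field, unset by default
    g.debug_info.files.append("train.py")  # mutating it lazily creates the submessage
    assert g.HasField("debug_info")
    g.ClearField("debug_info")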

@@ -251,10 +251,14 @@ class ModelProto(google.protobuf.message.Message):
) -> None: ...
def ClearField(self, field_name: typing.Literal["algorithm", b"algorithm", "cpu_budget", b"cpu_budget", "model_input_time", b"model_input_time", "ram_budget", b"ram_budget"]) -> None: ...

DATASET_NAME_FIELD_NUMBER: builtins.int
NODES_FIELD_NUMBER: builtins.int
OUTPUT_FIELD_NUMBER: builtins.int
ID_COUNTER_FIELD_NUMBER: builtins.int
OPTIMIZATION_PARAMS_FIELD_NUMBER: builtins.int
GAP_TIMES_FIELD_NUMBER: builtins.int
dataset_name: builtins.str
"""User-defined name for the dataset. Empty if no name was set."""
output: builtins.int
"""ID of the output node of this model."""
id_counter: builtins.int
@@ -265,15 +269,19 @@ class ModelProto(google.protobuf.message.Message):

@property
def optimization_params(self) -> global___ModelProto.OptimizationParams: ...
@property
def gap_times(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
def __init__(
self,
*,
dataset_name: builtins.str | None = ...,
nodes: collections.abc.Mapping[builtins.int, global___ModelProto.Node] | None = ...,
output: builtins.int | None = ...,
id_counter: builtins.int | None = ...,
optimization_params: global___ModelProto.OptimizationParams | None = ...,
gap_times: collections.abc.Iterable[builtins.int] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["optimization_params", b"optimization_params"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["id_counter", b"id_counter", "nodes", b"nodes", "optimization_params", b"optimization_params", "output", b"output"]) -> None: ...
def ClearField(self, field_name: typing.Literal["dataset_name", b"dataset_name", "gap_times", b"gap_times", "id_counter", b"id_counter", "nodes", b"nodes", "optimization_params", b"optimization_params", "output", b"output"]) -> None: ...

global___ModelProto = ModelProto
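
A sketch of the two fields added to ModelProto. The module path `tensorflow.core.framework.model_pb2` is an assumption; the field names are from the stubs above:

    from tensorflow.core.framework import model_pb2  # assumed module path

    m = model_pb2.ModelProto(dataset_name="train_pipeline")  # new in 2.16
    m.gap_times.extend([120, 95, 130])  # repeated scalar field, also new
    m.ClearField("gap_times")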

@@ -5,14 +5,21 @@ isort:skip_file

import builtins
import collections.abc
import sys
import typing

import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import tensorflow.core.framework.graph_pb2
import tensorflow.core.framework.types_pb2

if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

@typing.final
@@ -25,6 +32,30 @@ class OptimizedFunctionGraph(google.protobuf.message.Message):

DESCRIPTOR: google.protobuf.descriptor.Descriptor

class _OptimizationSource:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType

class _OptimizationSourceEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[OptimizedFunctionGraph._OptimizationSource.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
SOURCE_UNSPECIFIED: OptimizedFunctionGraph._OptimizationSource.ValueType # 0
AOT: OptimizedFunctionGraph._OptimizationSource.ValueType # 1
JIT: OptimizedFunctionGraph._OptimizationSource.ValueType # 2

class OptimizationSource(_OptimizationSource, metaclass=_OptimizationSourceEnumTypeWrapper):
"""Enum for distinguishing the origin where the proto is created.

AOT: proto is created in ahead-of-time environment, which can be different
from the environment where the graph is actually executed.

JIT: proto is created in just-in-time execution, which has the same
environment as the one the graph is actually executed.
"""

SOURCE_UNSPECIFIED: OptimizedFunctionGraph.OptimizationSource.ValueType # 0
AOT: OptimizedFunctionGraph.OptimizationSource.ValueType # 1
JIT: OptimizedFunctionGraph.OptimizationSource.ValueType # 2

@typing.final
class NodeNameToControlRetEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -46,12 +77,20 @@ class OptimizedFunctionGraph(google.protobuf.message.Message):
NODE_NAME_TO_CONTROL_RET_FIELD_NUMBER: builtins.int
RET_TYPES_FIELD_NUMBER: builtins.int
NUM_RETURN_NODES_FIELD_NUMBER: builtins.int
SOURCE_FIELD_NUMBER: builtins.int
OPTIMIZATION_TIME_USECS_FIELD_NUMBER: builtins.int
name: builtins.str
"""Function name. It can be a human-readable SignatureDef's method name, or a
FunctionDef name.
"""
num_return_nodes: builtins.int
"""Number of return nodes. This is an output of graph preprocessing."""
source: global___OptimizedFunctionGraph.OptimizationSource.ValueType
"""Indicates the source environment where this proto is generated."""
optimization_time_usecs: builtins.int
"""Time (in microseconds) spent on running the graph optimization passes for
this function.
"""
@property
def function_graph(self) -> tensorflow.core.framework.graph_pb2.GraphDef:
"""Optimized function graph."""
@@ -76,8 +115,14 @@ class OptimizedFunctionGraph(google.protobuf.message.Message):
node_name_to_control_ret: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
ret_types: collections.abc.Iterable[tensorflow.core.framework.types_pb2.DataType.ValueType] | None = ...,
num_return_nodes: builtins.int | None = ...,
source: global___OptimizedFunctionGraph.OptimizationSource.ValueType | None = ...,
optimization_time_usecs: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["function_graph", b"function_graph"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["function_graph", b"function_graph", "name", b"name", "node_name_to_control_ret", b"node_name_to_control_ret", "num_return_nodes", b"num_return_nodes", "ret_types", b"ret_types"]) -> None: ...
def HasField(self, field_name: typing.Literal["_optimization_time_usecs", b"_optimization_time_usecs", "_source", b"_source", "function_graph", b"function_graph", "optimization_time_usecs", b"optimization_time_usecs", "source", b"source"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_optimization_time_usecs", b"_optimization_time_usecs", "_source", b"_source", "function_graph", b"function_graph", "name", b"name", "node_name_to_control_ret", b"node_name_to_control_ret", "num_return_nodes", b"num_return_nodes", "optimization_time_usecs", b"optimization_time_usecs", "ret_types", b"ret_types", "source", b"source"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_optimization_time_usecs", b"_optimization_time_usecs"]) -> typing.Literal["optimization_time_usecs"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ...

global___OptimizedFunctionGraph = OptimizedFunctionGraph
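
The `_source` and `_optimization_time_usecs` groups above are the synthetic oneofs that protobuf generates for proto3 `optional` fields. A minimal sketch (the module path is an assumption inferred from the proto location):

    from tensorflow.core.framework import optimized_function_graph_pb2 as ofg_pb2  # assumed path

    ofg = ofg_pb2.OptimizedFunctionGraph(name="serving_default")
    assert not ofg.HasField("_source")            # optional field not yet set
    ofg.source = ofg_pb2.OptimizedFunctionGraph.OptimizationSource.JIT
    assert ofg.WhichOneof("_source") == "source"  # presence is now tracked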

@@ -67,11 +67,18 @@ class _DataTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumT
"""5 exponent bits, 2 mantissa bits."""
DT_FLOAT8_E4M3FN: _DataType.ValueType # 25
"""4 exponent bits, 3 mantissa bits, finite-only, with"""
DT_FLOAT_REF: _DataType.ValueType # 101
DT_INT4: _DataType.ValueType # 29
"""2 NaNs (0bS1111111).

Do not use! These are only for parameters. Every enum above
should have a corresponding value below (verified by types_test).
TODO - b/299182407: Leaving room for remaining float8 types.
DT_FLOAT8_E4M3FNUZ = 26;
DT_FLOAT8_E4M3B11FNUZ = 27;
DT_FLOAT8_E5M2FNUZ = 28;
"""
DT_UINT4: _DataType.ValueType # 30
DT_FLOAT_REF: _DataType.ValueType # 101
"""Do not use! These are only for TF1's obsolete reference Variables.
Every enum above should have a corresponding value below (verified by
types_test).
"""
DT_DOUBLE_REF: _DataType.ValueType # 102
DT_INT32_REF: _DataType.ValueType # 103
@@ -97,6 +104,13 @@ class _DataTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumT
DT_UINT64_REF: _DataType.ValueType # 123
DT_FLOAT8_E5M2_REF: _DataType.ValueType # 124
DT_FLOAT8_E4M3FN_REF: _DataType.ValueType # 125
DT_INT4_REF: _DataType.ValueType # 129
"""TODO - b/299182407: Leaving room for remaining float8 types.
DT_FLOAT8_E4M3FNUZ_REF = 126;
DT_FLOAT8_E4M3B11FNUZ_REF = 127;
DT_FLOAT8_E5M2FNUZ_REF = 128;
"""
DT_UINT4_REF: _DataType.ValueType # 130

class DataType(_DataType, metaclass=_DataTypeEnumTypeWrapper):
"""(== suppress_warning documentation-presence ==)
@@ -146,11 +160,18 @@ DT_FLOAT8_E5M2: DataType.ValueType # 24
"""5 exponent bits, 2 mantissa bits."""
DT_FLOAT8_E4M3FN: DataType.ValueType # 25
"""4 exponent bits, 3 mantissa bits, finite-only, with"""
DT_FLOAT_REF: DataType.ValueType # 101
DT_INT4: DataType.ValueType # 29
"""2 NaNs (0bS1111111).

Do not use! These are only for parameters. Every enum above
should have a corresponding value below (verified by types_test).
TODO - b/299182407: Leaving room for remaining float8 types.
DT_FLOAT8_E4M3FNUZ = 26;
DT_FLOAT8_E4M3B11FNUZ = 27;
DT_FLOAT8_E5M2FNUZ = 28;
"""
DT_UINT4: DataType.ValueType # 30
DT_FLOAT_REF: DataType.ValueType # 101
"""Do not use! These are only for TF1's obsolete reference Variables.
Every enum above should have a corresponding value below (verified by
types_test).
"""
DT_DOUBLE_REF: DataType.ValueType # 102
DT_INT32_REF: DataType.ValueType # 103
@@ -176,6 +197,13 @@ DT_UINT32_REF: DataType.ValueType # 122
DT_UINT64_REF: DataType.ValueType # 123
DT_FLOAT8_E5M2_REF: DataType.ValueType # 124
DT_FLOAT8_E4M3FN_REF: DataType.ValueType # 125
DT_INT4_REF: DataType.ValueType # 129
"""TODO - b/299182407: Leaving room for remaining float8 types.
DT_FLOAT8_E4M3FNUZ_REF = 126;
DT_FLOAT8_E4M3B11FNUZ_REF = 127;
DT_FLOAT8_E5M2FNUZ_REF = 128;
"""
DT_UINT4_REF: DataType.ValueType # 130
global___DataType = DataType
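
A small sketch of the new int4/uint4 enum members at module level; this is generic generated-enum usage, with the numeric values taken from the comments above:

    from tensorflow.core.framework import types_pb2

    assert types_pb2.DT_INT4 == 29
    assert types_pb2.DT_UINT4 == 30
    assert types_pb2.DataType.Name(29) == "DT_INT4"  # enum wrapper reverse lookup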

@typing.final

@@ -102,6 +102,9 @@ class JobDef(google.protobuf.message.Message):
If the `name` field contains "worker", and the `tasks` map contains a
mapping from 7 to "example.org:2222", then the device prefix
"/job:worker/task:7" will be assigned to "example.org:2222".

If a job has multiple replicas, host-ports will be comma-delimited, with
one entry for each replica.
"""

def __init__(

@@ -52,8 +52,8 @@ class GPUOptions(google.protobuf.message.Message):
"""Per "virtual" device memory limit, in MB. The number of elements in
the list is the number of virtual devices to create on the
corresponding visible GPU (see "virtual_devices" below).
If empty, it will create single virtual device taking all available
memory from the device.
If empty and `num_virtual_devices_per_gpu` is not set, it will create
single virtual device taking all available memory from the device.

For the concept of "visible" and "virtual" GPU, see the comments for
"visible_device_list" above for more information.
@@ -91,6 +91,7 @@ class GPUOptions(google.protobuf.message.Message):
def ClearField(self, field_name: typing.Literal["device_ordinal", b"device_ordinal", "memory_limit_mb", b"memory_limit_mb", "priority", b"priority"]) -> None: ...

VIRTUAL_DEVICES_FIELD_NUMBER: builtins.int
NUM_VIRTUAL_DEVICES_PER_GPU_FIELD_NUMBER: builtins.int
USE_UNIFIED_MEMORY_FIELD_NUMBER: builtins.int
NUM_DEV_TO_DEV_COPY_STREAMS_FIELD_NUMBER: builtins.int
COLLECTIVE_RING_ORDER_FIELD_NUMBER: builtins.int
@@ -103,6 +104,13 @@ class GPUOptions(google.protobuf.message.Message):
DISALLOW_RETRY_ON_ALLOCATION_FAILURE_FIELD_NUMBER: builtins.int
GPU_HOST_MEM_LIMIT_IN_MB_FIELD_NUMBER: builtins.int
GPU_HOST_MEM_DISALLOW_GROWTH_FIELD_NUMBER: builtins.int
GPU_SYSTEM_MEMORY_SIZE_IN_MB_FIELD_NUMBER: builtins.int
num_virtual_devices_per_gpu: builtins.int
"""The number of virtual devices to create on each visible GPU. The
available memory will be split equally among all virtual devices. If the
field `memory_limit_mb` in `VirtualDevices` is not empty, this field will
be ignored.
"""
use_unified_memory: builtins.bool
"""If true, uses CUDA unified memory for memory allocations. If
per_process_gpu_memory_fraction option is greater than 1.0, then unified
@@ -185,6 +193,13 @@ class GPUOptions(google.protobuf.message.Message):
gpu_host_mem_limit_in_mb, because the default GPU host memory limit is
quite high.
"""
gpu_system_memory_size_in_mb: builtins.int
"""Memory limit for gpu system. This can also be set by
TF_DEVICE_MIN_SYS_MEMORY_IN_MB, which takes precedence over
gpu_system_memory_size_in_mb. With this, user can configure the gpu
system memory size for better resource estimation of multi-tenancy(one
gpu with multiple model) use case.
"""
@property
def virtual_devices(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___GPUOptions.Experimental.VirtualDevices]:
"""The multi virtual device settings. If empty (not set), it will create
@@ -231,6 +246,7 @@ class GPUOptions(google.protobuf.message.Message):
self,
*,
virtual_devices: collections.abc.Iterable[global___GPUOptions.Experimental.VirtualDevices] | None = ...,
num_virtual_devices_per_gpu: builtins.int | None = ...,
use_unified_memory: builtins.bool | None = ...,
num_dev_to_dev_copy_streams: builtins.int | None = ...,
collective_ring_order: builtins.str | None = ...,
@@ -243,8 +259,9 @@ class GPUOptions(google.protobuf.message.Message):
disallow_retry_on_allocation_failure: builtins.bool | None = ...,
gpu_host_mem_limit_in_mb: builtins.float | None = ...,
gpu_host_mem_disallow_growth: builtins.bool | None = ...,
gpu_system_memory_size_in_mb: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["collective_ring_order", b"collective_ring_order", "disallow_retry_on_allocation_failure", b"disallow_retry_on_allocation_failure", "gpu_host_mem_disallow_growth", b"gpu_host_mem_disallow_growth", "gpu_host_mem_limit_in_mb", b"gpu_host_mem_limit_in_mb", "internal_fragmentation_fraction", b"internal_fragmentation_fraction", "kernel_tracker_max_bytes", b"kernel_tracker_max_bytes", "kernel_tracker_max_interval", b"kernel_tracker_max_interval", "kernel_tracker_max_pending", b"kernel_tracker_max_pending", "num_dev_to_dev_copy_streams", b"num_dev_to_dev_copy_streams", "timestamped_allocator", b"timestamped_allocator", "use_cuda_malloc_async", b"use_cuda_malloc_async", "use_unified_memory", b"use_unified_memory", "virtual_devices", b"virtual_devices"]) -> None: ...
def ClearField(self, field_name: typing.Literal["collective_ring_order", b"collective_ring_order", "disallow_retry_on_allocation_failure", b"disallow_retry_on_allocation_failure", "gpu_host_mem_disallow_growth", b"gpu_host_mem_disallow_growth", "gpu_host_mem_limit_in_mb", b"gpu_host_mem_limit_in_mb", "gpu_system_memory_size_in_mb", b"gpu_system_memory_size_in_mb", "internal_fragmentation_fraction", b"internal_fragmentation_fraction", "kernel_tracker_max_bytes", b"kernel_tracker_max_bytes", "kernel_tracker_max_interval", b"kernel_tracker_max_interval", "kernel_tracker_max_pending", b"kernel_tracker_max_pending", "num_dev_to_dev_copy_streams", b"num_dev_to_dev_copy_streams", "num_virtual_devices_per_gpu", b"num_virtual_devices_per_gpu", "timestamped_allocator", b"timestamped_allocator", "use_cuda_malloc_async", b"use_cuda_malloc_async", "use_unified_memory", b"use_unified_memory", "virtual_devices", b"virtual_devices"]) -> None: ...

PER_PROCESS_GPU_MEMORY_FRACTION_FIELD_NUMBER: builtins.int
ALLOW_GROWTH_FIELD_NUMBER: builtins.int
@@ -695,10 +712,16 @@ class ConfigProto(google.protobuf.message.Message):
DISABLE_OUTPUT_PARTITION_GRAPHS_FIELD_NUMBER: builtins.int
XLA_FUSION_AUTOTUNER_THRESH_FIELD_NUMBER: builtins.int
USE_TFRT_FIELD_NUMBER: builtins.int
ENABLE_MULTI_HOST_FIELD_NUMBER: builtins.int
BACKEND_SERVER_PORT_FIELD_NUMBER: builtins.int
TARGET_TPU_FIELD_NUMBER: builtins.int
TARGET_GPU_FIELD_NUMBER: builtins.int
STREAM_MERGE_THRESHOLD_FIELD_NUMBER: builtins.int
DISABLE_FUNCTIONAL_OPS_LOWERING_FIELD_NUMBER: builtins.int
XLA_PREFER_SINGLE_GRAPH_CLUSTER_FIELD_NUMBER: builtins.int
COORDINATION_CONFIG_FIELD_NUMBER: builtins.int
DISABLE_OPTIMIZE_FOR_STATIC_GRAPH_FIELD_NUMBER: builtins.int
DISABLE_EAGER_EXECUTOR_STREAMING_ENQUEUE_FIELD_NUMBER: builtins.int
collective_group_leader: builtins.str
"""Task name for group resolution."""
executor_type: builtins.str
@@ -801,6 +824,23 @@ class ConfigProto(google.protobuf.message.Message):
"""
use_tfrt: builtins.bool
"""Whether runtime execution uses TFRT."""
enable_multi_host: builtins.bool
"""If true, use Pathways with TFRT API for multi host support."""
backend_server_port: builtins.int
"""Port for the Pathways server. Ignored if enable_multi_host=false."""
target_tpu: builtins.bool
"""If true, TFRT will use TPU specific compiler passes and perform TPU
specific initialization.
"""
target_gpu: builtins.bool
"""If true, TFRT will use GPU specific compiler passes and perform GPU
specific initialization.
"""
stream_merge_threshold: builtins.int
"""The threshold to merge small streams in TFRT. The stream with cost
smaller than the threshold will be merged. Setting it to value 1
disables all merges.
"""
disable_functional_ops_lowering: builtins.bool
"""Whether functional control flow op lowering should be disabled. This is
useful when executing within a portable runtime where control flow op
@@ -821,6 +861,11 @@ class ConfigProto(google.protobuf.message.Message):
This option is meant to replace `optimize_for_static_graph` and it
aims to negate its value.
"""
disable_eager_executor_streaming_enqueue: builtins.bool
"""Whether eager remote execution will stream all the function calls or
allow them to happen in parallel. When true, streaming execution is
disabled, and parallel execution is allowed.
"""
@property
def session_metadata(self) -> global___SessionMetadata:
"""Metadata about the session.
@@ -856,13 +901,19 @@ class ConfigProto(google.protobuf.message.Message):
disable_output_partition_graphs: builtins.bool | None = ...,
xla_fusion_autotuner_thresh: builtins.int | None = ...,
use_tfrt: builtins.bool | None = ...,
enable_multi_host: builtins.bool | None = ...,
backend_server_port: builtins.int | None = ...,
target_tpu: builtins.bool | None = ...,
target_gpu: builtins.bool | None = ...,
stream_merge_threshold: builtins.int | None = ...,
disable_functional_ops_lowering: builtins.bool | None = ...,
xla_prefer_single_graph_cluster: builtins.bool | None = ...,
coordination_config: tensorflow.tsl.protobuf.coordination_config_pb2.CoordinationServiceConfig | None = ...,
disable_optimize_for_static_graph: builtins.bool | None = ...,
disable_eager_executor_streaming_enqueue: builtins.bool | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["coordination_config", b"coordination_config", "session_metadata", b"session_metadata"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["collective_deterministic_sequential_execution", b"collective_deterministic_sequential_execution", "collective_group_leader", b"collective_group_leader", "collective_nccl", b"collective_nccl", "coordination_config", b"coordination_config", "disable_functional_ops_lowering", b"disable_functional_ops_lowering", "disable_optimize_for_static_graph", b"disable_optimize_for_static_graph", "disable_output_partition_graphs", b"disable_output_partition_graphs", "disable_thread_spinning", b"disable_thread_spinning", "enable_mlir_bridge", b"enable_mlir_bridge", "enable_mlir_graph_optimization", b"enable_mlir_graph_optimization", "executor_type", b"executor_type", "mlir_bridge_rollout", b"mlir_bridge_rollout", "optimize_for_static_graph", b"optimize_for_static_graph", "recv_buf_max_chunk", b"recv_buf_max_chunk", "session_metadata", b"session_metadata", "share_cluster_devices_in_session", b"share_cluster_devices_in_session", "share_session_state_in_clusterspec_propagation", b"share_session_state_in_clusterspec_propagation", "use_numa_affinity", b"use_numa_affinity", "use_tfrt", b"use_tfrt", "xla_fusion_autotuner_thresh", b"xla_fusion_autotuner_thresh", "xla_prefer_single_graph_cluster", b"xla_prefer_single_graph_cluster"]) -> None: ...
def ClearField(self, field_name: typing.Literal["backend_server_port", b"backend_server_port", "collective_deterministic_sequential_execution", b"collective_deterministic_sequential_execution", "collective_group_leader", b"collective_group_leader", "collective_nccl", b"collective_nccl", "coordination_config", b"coordination_config", "disable_eager_executor_streaming_enqueue", b"disable_eager_executor_streaming_enqueue", "disable_functional_ops_lowering", b"disable_functional_ops_lowering", "disable_optimize_for_static_graph", b"disable_optimize_for_static_graph", "disable_output_partition_graphs", b"disable_output_partition_graphs", "disable_thread_spinning", b"disable_thread_spinning", "enable_mlir_bridge", b"enable_mlir_bridge", "enable_mlir_graph_optimization", b"enable_mlir_graph_optimization", "enable_multi_host", b"enable_multi_host", "executor_type", b"executor_type", "mlir_bridge_rollout", b"mlir_bridge_rollout", "optimize_for_static_graph", b"optimize_for_static_graph", "recv_buf_max_chunk", b"recv_buf_max_chunk", "session_metadata", b"session_metadata", "share_cluster_devices_in_session", b"share_cluster_devices_in_session", "share_session_state_in_clusterspec_propagation", b"share_session_state_in_clusterspec_propagation", "stream_merge_threshold", b"stream_merge_threshold", "target_gpu", b"target_gpu", "target_tpu", b"target_tpu", "use_numa_affinity", b"use_numa_affinity", "use_tfrt", b"use_tfrt", "xla_fusion_autotuner_thresh", b"xla_fusion_autotuner_thresh", "xla_prefer_single_graph_cluster", b"xla_prefer_single_graph_cluster"]) -> None: ...

DEVICE_COUNT_FIELD_NUMBER: builtins.int
INTRA_OP_PARALLELISM_THREADS_FIELD_NUMBER: builtins.int
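
A sketch of the new ConfigProto.Experimental and GPUOptions.Experimental knobs as plain proto field access; the runtime effect of each flag is whatever the docstrings above describe, this only shows how the fields are reached:

    from tensorflow.core.protobuf import config_pb2

    cfg = config_pb2.ConfigProto()
    cfg.experimental.use_tfrt = True
    cfg.experimental.stream_merge_threshold = 1  # 1 disables all stream merges
    cfg.gpu_options.experimental.num_virtual_devices_per_gpu = 2  # new in 2.16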

@@ -12,8 +12,8 @@ import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import tensorflow.core.framework.graph_debug_info_pb2
import tensorflow.core.framework.tensor_pb2
import tensorflow.core.protobuf.graph_debug_info_pb2

if sys.version_info >= (3, 10):
import typing as typing_extensions
@@ -288,7 +288,7 @@ class StackFrameWithId(google.protobuf.message.Message):
id: builtins.str
"""A unique ID for the stack frame: A UUID-like string."""
@property
def file_line_col(self) -> tensorflow.core.protobuf.graph_debug_info_pb2.GraphDebugInfo.FileLineCol:
def file_line_col(self) -> tensorflow.core.framework.graph_debug_info_pb2.GraphDebugInfo.FileLineCol:
"""Stack frame, i.e., a frame of a stack trace, containing information
regarding the file name, line number, function name, code content
of the line, and column number (if available).
@@ -298,7 +298,7 @@ class StackFrameWithId(google.protobuf.message.Message):
self,
*,
id: builtins.str | None = ...,
file_line_col: tensorflow.core.protobuf.graph_debug_info_pb2.GraphDebugInfo.FileLineCol | None = ...,
file_line_col: tensorflow.core.framework.graph_debug_info_pb2.GraphDebugInfo.FileLineCol | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["file_line_col", b"file_line_col"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["file_line_col", b"file_line_col", "id", b"id"]) -> None: ...

@@ -40,7 +40,9 @@ class FingerprintDef(google.protobuf.message.Message):
"""Hash of the checkpoint."""
@property
def version(self) -> tensorflow.core.framework.versions_pb2.VersionDef:
"""Version specification of the fingerprint."""
"""Version specification of the fingerprint.
TODO(b/290068219): add USM version when GA
"""

def __init__(
self,

@@ -13,6 +13,7 @@ import google.protobuf.internal.containers
import google.protobuf.message
import tensorflow.core.framework.graph_pb2
import tensorflow.core.framework.op_def_pb2
import tensorflow.core.framework.tensor_pb2
import tensorflow.core.framework.tensor_shape_pb2
import tensorflow.core.framework.types_pb2
import tensorflow.core.protobuf.saved_object_graph_pb2
@@ -621,9 +622,28 @@ class SignatureDef(google.protobuf.message.Message):
def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...

@typing.final
class DefaultsEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
@property
def value(self) -> tensorflow.core.framework.tensor_pb2.TensorProto: ...
def __init__(
self,
*,
key: builtins.str | None = ...,
value: tensorflow.core.framework.tensor_pb2.TensorProto | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["value", b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ...

INPUTS_FIELD_NUMBER: builtins.int
OUTPUTS_FIELD_NUMBER: builtins.int
METHOD_NAME_FIELD_NUMBER: builtins.int
DEFAULTS_FIELD_NUMBER: builtins.int
method_name: builtins.str
"""Extensible method_name information enabling third-party users to mark a
SignatureDef as supporting a particular method. This enables producers and
@@ -642,14 +662,19 @@ class SignatureDef(google.protobuf.message.Message):
def outputs(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___TensorInfo]:
"""Named output parameters."""

@property
def defaults(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, tensorflow.core.framework.tensor_pb2.TensorProto]:
"""Named input to corresponding default values if any."""

def __init__(
self,
*,
inputs: collections.abc.Mapping[builtins.str, global___TensorInfo] | None = ...,
outputs: collections.abc.Mapping[builtins.str, global___TensorInfo] | None = ...,
method_name: builtins.str | None = ...,
defaults: collections.abc.Mapping[builtins.str, tensorflow.core.framework.tensor_pb2.TensorProto] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["inputs", b"inputs", "method_name", b"method_name", "outputs", b"outputs"]) -> None: ...
def ClearField(self, field_name: typing.Literal["defaults", b"defaults", "inputs", b"inputs", "method_name", b"method_name", "outputs", b"outputs"]) -> None: ...

global___SignatureDef = SignatureDef
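
A sketch of populating the new `defaults` map on SignatureDef; the input name and tensor payload here are purely illustrative:

    from tensorflow.core.framework import tensor_pb2, types_pb2
    from tensorflow.core.protobuf import meta_graph_pb2

    sig = meta_graph_pb2.SignatureDef(method_name="tensorflow/serving/predict")
    default = tensor_pb2.TensorProto(dtype=types_pb2.DT_FLOAT, float_val=[0.5])
    sig.defaults["dropout_rate"].CopyFrom(default)  # named input -> default value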

@@ -266,6 +266,7 @@ class RewriterConfig(google.protobuf.message.Message):
AUTO_MIXED_PRECISION_ONEDNN_BFLOAT16_FIELD_NUMBER: builtins.int
AUTO_MIXED_PRECISION_CPU_FIELD_NUMBER: builtins.int
DISABLE_META_OPTIMIZER_FIELD_NUMBER: builtins.int
DISABLE_TFG_OPTIMIZER_FIELD_NUMBER: builtins.int
USE_PLUGIN_OPTIMIZERS_FIELD_NUMBER: builtins.int
EXPERIMENTAL_CONDITIONAL_CODE_MOTION_FIELD_NUMBER: builtins.int
META_OPTIMIZER_ITERATIONS_FIELD_NUMBER: builtins.int
@@ -359,6 +360,8 @@ class RewriterConfig(google.protobuf.message.Message):
"""
disable_meta_optimizer: builtins.bool
"""Disable the entire meta optimizer (off by default)."""
disable_tfg_optimizer: builtins.bool
"""Disable the TFG optimizer (off by default)."""
use_plugin_optimizers: global___RewriterConfig.Toggle.ValueType
"""Optimizers registered by plugin (default is ON)"""
experimental_conditional_code_motion: global___RewriterConfig.Toggle.ValueType
@@ -471,6 +474,7 @@ class RewriterConfig(google.protobuf.message.Message):
auto_mixed_precision_onednn_bfloat16: global___RewriterConfig.Toggle.ValueType | None = ...,
auto_mixed_precision_cpu: global___RewriterConfig.Toggle.ValueType | None = ...,
disable_meta_optimizer: builtins.bool | None = ...,
disable_tfg_optimizer: builtins.bool | None = ...,
use_plugin_optimizers: global___RewriterConfig.Toggle.ValueType | None = ...,
experimental_conditional_code_motion: global___RewriterConfig.Toggle.ValueType | None = ...,
meta_optimizer_iterations: global___RewriterConfig.NumIterationsType.ValueType | None = ...,
@@ -489,6 +493,6 @@ class RewriterConfig(google.protobuf.message.Message):
post_optimization_verifier_config: tensorflow.core.protobuf.verifier_config_pb2.VerifierConfig | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["auto_parallel", b"auto_parallel", "inter_optimizer_verifier_config", b"inter_optimizer_verifier_config", "post_optimization_verifier_config", b"post_optimization_verifier_config", "scoped_allocator_opts", b"scoped_allocator_opts"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["arithmetic_optimization", b"arithmetic_optimization", "auto_mixed_precision", b"auto_mixed_precision", "auto_mixed_precision_cpu", b"auto_mixed_precision_cpu", "auto_mixed_precision_mkl", b"auto_mixed_precision_mkl", "auto_mixed_precision_onednn_bfloat16", b"auto_mixed_precision_onednn_bfloat16", "auto_parallel", b"auto_parallel", "common_subgraph_elimination", b"common_subgraph_elimination", "constant_folding", b"constant_folding", "cpu_layout_conversion", b"cpu_layout_conversion", "custom_optimizers", b"custom_optimizers", "debug_stripper", b"debug_stripper", "dependency_optimization", b"dependency_optimization", "disable_meta_optimizer", b"disable_meta_optimizer", "disable_model_pruning", b"disable_model_pruning", "experimental_conditional_code_motion", b"experimental_conditional_code_motion", "experimental_disable_compressed_tensor_optimization", b"experimental_disable_compressed_tensor_optimization", "experimental_disable_folding_quantization_emulation", b"experimental_disable_folding_quantization_emulation", "fail_on_optimizer_errors", b"fail_on_optimizer_errors", "function_optimization", b"function_optimization", "implementation_selector", b"implementation_selector", "inter_optimizer_verifier_config", b"inter_optimizer_verifier_config", "layout_optimizer", b"layout_optimizer", "loop_optimization", b"loop_optimization", "memory_optimization", b"memory_optimization", "memory_optimizer_target_node_name_scope", b"memory_optimizer_target_node_name_scope", "meta_optimizer_iterations", b"meta_optimizer_iterations", "meta_optimizer_timeout_ms", b"meta_optimizer_timeout_ms", "min_graph_nodes", b"min_graph_nodes", "optimizers", b"optimizers", "pin_to_host_optimization", b"pin_to_host_optimization", "post_optimization_verifier_config", b"post_optimization_verifier_config", "remapping", b"remapping", "scoped_allocator_optimization", b"scoped_allocator_optimization", "scoped_allocator_opts", b"scoped_allocator_opts", "shape_optimization", b"shape_optimization", "use_plugin_optimizers", b"use_plugin_optimizers"]) -> None: ...
def ClearField(self, field_name: typing.Literal["arithmetic_optimization", b"arithmetic_optimization", "auto_mixed_precision", b"auto_mixed_precision", "auto_mixed_precision_cpu", b"auto_mixed_precision_cpu", "auto_mixed_precision_mkl", b"auto_mixed_precision_mkl", "auto_mixed_precision_onednn_bfloat16", b"auto_mixed_precision_onednn_bfloat16", "auto_parallel", b"auto_parallel", "common_subgraph_elimination", b"common_subgraph_elimination", "constant_folding", b"constant_folding", "cpu_layout_conversion", b"cpu_layout_conversion", "custom_optimizers", b"custom_optimizers", "debug_stripper", b"debug_stripper", "dependency_optimization", b"dependency_optimization", "disable_meta_optimizer", b"disable_meta_optimizer", "disable_model_pruning", b"disable_model_pruning", "disable_tfg_optimizer", b"disable_tfg_optimizer", "experimental_conditional_code_motion", b"experimental_conditional_code_motion", "experimental_disable_compressed_tensor_optimization", b"experimental_disable_compressed_tensor_optimization", "experimental_disable_folding_quantization_emulation", b"experimental_disable_folding_quantization_emulation", "fail_on_optimizer_errors", b"fail_on_optimizer_errors", "function_optimization", b"function_optimization", "implementation_selector", b"implementation_selector", "inter_optimizer_verifier_config", b"inter_optimizer_verifier_config", "layout_optimizer", b"layout_optimizer", "loop_optimization", b"loop_optimization", "memory_optimization", b"memory_optimization", "memory_optimizer_target_node_name_scope", b"memory_optimizer_target_node_name_scope", "meta_optimizer_iterations", b"meta_optimizer_iterations", "meta_optimizer_timeout_ms", b"meta_optimizer_timeout_ms", "min_graph_nodes", b"min_graph_nodes", "optimizers", b"optimizers", "pin_to_host_optimization", b"pin_to_host_optimization", "post_optimization_verifier_config", b"post_optimization_verifier_config", "remapping", b"remapping", "scoped_allocator_optimization", b"scoped_allocator_optimization", "scoped_allocator_opts", b"scoped_allocator_opts", "shape_optimization", b"shape_optimization", "use_plugin_optimizers", b"use_plugin_optimizers"]) -> None: ...

global___RewriterConfig = RewriterConfig

@@ -17,7 +17,7 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class DispatcherConfig(google.protobuf.message.Message):
"""Configuration for a tf.data service DispatchServer.
Next id: 11
Next id: 13
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -30,8 +30,10 @@ class DispatcherConfig(google.protobuf.message.Message):
DEPLOYMENT_MODE_FIELD_NUMBER: builtins.int
JOB_GC_CHECK_INTERVAL_MS_FIELD_NUMBER: builtins.int
JOB_GC_TIMEOUT_MS_FIELD_NUMBER: builtins.int
GC_DYNAMIC_SHARDING_JOBS_FIELD_NUMBER: builtins.int
CLIENT_TIMEOUT_MS_FIELD_NUMBER: builtins.int
WORKER_TIMEOUT_MS_FIELD_NUMBER: builtins.int
WORKER_MAX_CONCURRENT_SNAPSHOTS_FIELD_NUMBER: builtins.int
port: builtins.int
"""The port for the dispatcher to bind to. A value of 0 indicates that the
dispatcher may bind to any available port.
@@ -59,7 +61,15 @@ class DispatcherConfig(google.protobuf.message.Message):
"""How long a job needs to be unused before it becomes a candidate for garbage
collection. A value of -1 indicates that jobs should never be garbage
collected. A value of 0 indicates that the decision should be left up to
the runtime.
the runtime. Note: This does not apply to dynamic sharding unless users
explicitly opt-in by enabling `gc_dynamic_sharding_jobs` below.
"""
gc_dynamic_sharding_jobs: builtins.bool
"""Whether dynamically sharded jobs should be eligible for garbage collection.
These jobs are not garbage collected by default, since if a job is garbage
collected and then re-created, it will revisit all data from the start. If
revisiting data is acceptable and you want automatic reclamation of
iterator memory, set `gc_dynamic_sharding_jobs` to `true`.
"""
client_timeout_ms: builtins.int
"""How long to wait before garbage-collecting a client that hasn't
@@ -70,6 +80,12 @@ class DispatcherConfig(google.protobuf.message.Message):
"""How long to wait for a worker to heartbeat before considering it missing.
A value of 0 indicates that the timeout should be left to the runtime.
"""
worker_max_concurrent_snapshots: builtins.int
"""The maximum number of snapshots that a worker can concurrently process at a
given point in time. This is a tradeoff between worker resource usage and
snapshot wall time. A value of 0 indicates that the decision should be left
up to the runtime.
"""
@property
def worker_addresses(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""(Optional.) If the job uses auto-sharding, it needs to specify a fixed list
@@ -89,17 +105,19 @@ class DispatcherConfig(google.protobuf.message.Message):
deployment_mode: tensorflow.core.protobuf.data_service_pb2.DeploymentMode.ValueType | None = ...,
job_gc_check_interval_ms: builtins.int | None = ...,
job_gc_timeout_ms: builtins.int | None = ...,
gc_dynamic_sharding_jobs: builtins.bool | None = ...,
client_timeout_ms: builtins.int | None = ...,
worker_timeout_ms: builtins.int | None = ...,
worker_max_concurrent_snapshots: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["client_timeout_ms", b"client_timeout_ms", "deployment_mode", b"deployment_mode", "fault_tolerant_mode", b"fault_tolerant_mode", "job_gc_check_interval_ms", b"job_gc_check_interval_ms", "job_gc_timeout_ms", b"job_gc_timeout_ms", "port", b"port", "protocol", b"protocol", "work_dir", b"work_dir", "worker_addresses", b"worker_addresses", "worker_timeout_ms", b"worker_timeout_ms"]) -> None: ...
def ClearField(self, field_name: typing.Literal["client_timeout_ms", b"client_timeout_ms", "deployment_mode", b"deployment_mode", "fault_tolerant_mode", b"fault_tolerant_mode", "gc_dynamic_sharding_jobs", b"gc_dynamic_sharding_jobs", "job_gc_check_interval_ms", b"job_gc_check_interval_ms", "job_gc_timeout_ms", b"job_gc_timeout_ms", "port", b"port", "protocol", b"protocol", "work_dir", b"work_dir", "worker_addresses", b"worker_addresses", "worker_max_concurrent_snapshots", b"worker_max_concurrent_snapshots", "worker_timeout_ms", b"worker_timeout_ms"]) -> None: ...

global___DispatcherConfig = DispatcherConfig

@typing.final
class WorkerConfig(google.protobuf.message.Message):
"""Configuration for a tf.data service WorkerServer.
Next id: 12
Next id: 13
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -114,6 +132,7 @@ class WorkerConfig(google.protobuf.message.Message):
DATA_TRANSFER_PROTOCOL_FIELD_NUMBER: builtins.int
DATA_TRANSFER_ADDRESS_FIELD_NUMBER: builtins.int
CROSS_TRAINER_CACHE_SIZE_BYTES_FIELD_NUMBER: builtins.int
SNAPSHOT_MAX_CHUNK_SIZE_BYTES_FIELD_NUMBER: builtins.int
SHUTDOWN_QUIET_PERIOD_MS_FIELD_NUMBER: builtins.int
port: builtins.int
"""The port for the worker to bind to. A value of 0 indicates that the
@@ -148,6 +167,10 @@ class WorkerConfig(google.protobuf.message.Message):
"""Maximum size of the cross-trainer cache in bytes. If enabled, make sure
your training job provides sufficient memory resources.
"""
snapshot_max_chunk_size_bytes: builtins.int
"""The maximum size of a distributed snapshot chunk file. A value of 0
indicates that the decision should be left up to the runtime.
"""
shutdown_quiet_period_ms: builtins.int
"""When shutting down a worker, how long to wait for the gRPC server to
process the final requests. This is used to achieve clean shutdown in unit
@@ -174,8 +197,9 @@ class WorkerConfig(google.protobuf.message.Message):
data_transfer_protocol: builtins.str | None = ...,
data_transfer_address: builtins.str | None = ...,
cross_trainer_cache_size_bytes: builtins.int | None = ...,
snapshot_max_chunk_size_bytes: builtins.int | None = ...,
shutdown_quiet_period_ms: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["cross_trainer_cache_size_bytes", b"cross_trainer_cache_size_bytes", "data_transfer_address", b"data_transfer_address", "data_transfer_protocol", b"data_transfer_protocol", "dispatcher_address", b"dispatcher_address", "dispatcher_timeout_ms", b"dispatcher_timeout_ms", "heartbeat_interval_ms", b"heartbeat_interval_ms", "port", b"port", "protocol", b"protocol", "shutdown_quiet_period_ms", b"shutdown_quiet_period_ms", "worker_address", b"worker_address", "worker_tags", b"worker_tags"]) -> None: ...
def ClearField(self, field_name: typing.Literal["cross_trainer_cache_size_bytes", b"cross_trainer_cache_size_bytes", "data_transfer_address", b"data_transfer_address", "data_transfer_protocol", b"data_transfer_protocol", "dispatcher_address", b"dispatcher_address", "dispatcher_timeout_ms", b"dispatcher_timeout_ms", "heartbeat_interval_ms", b"heartbeat_interval_ms", "port", b"port", "protocol", b"protocol", "shutdown_quiet_period_ms", b"shutdown_quiet_period_ms", "snapshot_max_chunk_size_bytes", b"snapshot_max_chunk_size_bytes", "worker_address", b"worker_address", "worker_tags", b"worker_tags"]) -> None: ...

global___WorkerConfig = WorkerConfig
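
A sketch of the new tf.data service garbage-collection and snapshot knobs. The module path `tensorflow.core.protobuf.service_config_pb2` is an assumption inferred from the proto location; the field names come from the stubs above:

    from tensorflow.core.protobuf import service_config_pb2  # assumed module path

    dispatcher = service_config_pb2.DispatcherConfig(
        port=0,                          # 0 = bind to any available port
        gc_dynamic_sharding_jobs=True,   # opt in to GC for dynamically sharded jobs
        worker_max_concurrent_snapshots=2,
    )
    worker = service_config_pb2.WorkerConfig(snapshot_max_chunk_size_bytes=64 << 20)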

13 stubs/tensorflow/tensorflow/core/protobuf/status_pb2.pyi Normal file
@@ -0,0 +1,13 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Add a dummy package name. Having no package like
core/lib/core/error_codes.proto, or having tensorflow like
tsl/protobuf/error_codes.proto, results in name collision errors in generated
code for some users that use JS through J2CL.
"""

import google.protobuf.descriptor
from tensorflow.tsl.protobuf.status_pb2 import StatusProto as StatusProto

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@@ -67,6 +67,8 @@ class StructuredValue(google.protobuf.message.Message):
TUPLE_VALUE_FIELD_NUMBER: builtins.int
DICT_VALUE_FIELD_NUMBER: builtins.int
NAMED_TUPLE_VALUE_FIELD_NUMBER: builtins.int
TENSOR_VALUE_FIELD_NUMBER: builtins.int
NUMPY_VALUE_FIELD_NUMBER: builtins.int
float64_value: builtins.float
"""Represents a double-precision floating-point value (a Python `float`)."""
int64_value: builtins.int
@@ -121,6 +123,14 @@ class StructuredValue(google.protobuf.message.Message):
def named_tuple_value(self) -> global___NamedTupleValue:
"""Represents Python's namedtuple."""

@property
def tensor_value(self) -> tensorflow.core.framework.tensor_pb2.TensorProto:
"""Represents a value for tf.Tensor."""

@property
def numpy_value(self) -> tensorflow.core.framework.tensor_pb2.TensorProto:
"""Represents a value for np.ndarray."""

def __init__(
self,
*,
@@ -138,10 +148,12 @@ class StructuredValue(google.protobuf.message.Message):
tuple_value: global___TupleValue | None = ...,
dict_value: global___DictValue | None = ...,
named_tuple_value: global___NamedTupleValue | None = ...,
tensor_value: tensorflow.core.framework.tensor_pb2.TensorProto | None = ...,
numpy_value: tensorflow.core.framework.tensor_pb2.TensorProto | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["bool_value", b"bool_value", "bounded_tensor_spec_value", b"bounded_tensor_spec_value", "dict_value", b"dict_value", "float64_value", b"float64_value", "int64_value", b"int64_value", "kind", b"kind", "list_value", b"list_value", "named_tuple_value", b"named_tuple_value", "none_value", b"none_value", "string_value", b"string_value", "tensor_dtype_value", b"tensor_dtype_value", "tensor_shape_value", b"tensor_shape_value", "tensor_spec_value", b"tensor_spec_value", "tuple_value", b"tuple_value", "type_spec_value", b"type_spec_value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["bool_value", b"bool_value", "bounded_tensor_spec_value", b"bounded_tensor_spec_value", "dict_value", b"dict_value", "float64_value", b"float64_value", "int64_value", b"int64_value", "kind", b"kind", "list_value", b"list_value", "named_tuple_value", b"named_tuple_value", "none_value", b"none_value", "string_value", b"string_value", "tensor_dtype_value", b"tensor_dtype_value", "tensor_shape_value", b"tensor_shape_value", "tensor_spec_value", b"tensor_spec_value", "tuple_value", b"tuple_value", "type_spec_value", b"type_spec_value"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["kind", b"kind"]) -> typing.Literal["none_value", "float64_value", "int64_value", "string_value", "bool_value", "tensor_shape_value", "tensor_dtype_value", "tensor_spec_value", "type_spec_value", "bounded_tensor_spec_value", "list_value", "tuple_value", "dict_value", "named_tuple_value"] | None: ...
def HasField(self, field_name: typing.Literal["bool_value", b"bool_value", "bounded_tensor_spec_value", b"bounded_tensor_spec_value", "dict_value", b"dict_value", "float64_value", b"float64_value", "int64_value", b"int64_value", "kind", b"kind", "list_value", b"list_value", "named_tuple_value", b"named_tuple_value", "none_value", b"none_value", "numpy_value", b"numpy_value", "string_value", b"string_value", "tensor_dtype_value", b"tensor_dtype_value", "tensor_shape_value", b"tensor_shape_value", "tensor_spec_value", b"tensor_spec_value", "tensor_value", b"tensor_value", "tuple_value", b"tuple_value", "type_spec_value", b"type_spec_value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["bool_value", b"bool_value", "bounded_tensor_spec_value", b"bounded_tensor_spec_value", "dict_value", b"dict_value", "float64_value", b"float64_value", "int64_value", b"int64_value", "kind", b"kind", "list_value", b"list_value", "named_tuple_value", b"named_tuple_value", "none_value", b"none_value", "numpy_value", b"numpy_value", "string_value", b"string_value", "tensor_dtype_value", b"tensor_dtype_value", "tensor_shape_value", b"tensor_shape_value", "tensor_spec_value", b"tensor_spec_value", "tensor_value", b"tensor_value", "tuple_value", b"tuple_value", "type_spec_value", b"type_spec_value"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["kind", b"kind"]) -> typing.Literal["none_value", "float64_value", "int64_value", "string_value", "bool_value", "tensor_shape_value", "tensor_dtype_value", "tensor_spec_value", "type_spec_value", "bounded_tensor_spec_value", "list_value", "tuple_value", "dict_value", "named_tuple_value", "tensor_value", "numpy_value"] | None: ...

global___StructuredValue = StructuredValue
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ class ServerDef(google.protobuf.message.Message):

    CLUSTER_FIELD_NUMBER: builtins.int
    JOB_NAME_FIELD_NUMBER: builtins.int
    REPLICA_FIELD_NUMBER: builtins.int
    TASK_INDEX_FIELD_NUMBER: builtins.int
    DEFAULT_SESSION_CONFIG_FIELD_NUMBER: builtins.int
    PROTOCOL_FIELD_NUMBER: builtins.int

@@ -47,6 +48,8 @@ class ServerDef(google.protobuf.message.Message):

    NOTE(mrry): The `cluster` field must contain a `JobDef` with a `name` field
    that matches this name.
    """
    replica: builtins.int
    """Replica this server manages."""
    task_index: builtins.int
    """The task index of this server in its job.

@@ -79,6 +82,7 @@ class ServerDef(google.protobuf.message.Message):

        *,
        cluster: tensorflow.core.protobuf.cluster_pb2.ClusterDef | None = ...,
        job_name: builtins.str | None = ...,
        replica: builtins.int | None = ...,
        task_index: builtins.int | None = ...,
        default_session_config: tensorflow.core.protobuf.config_pb2.ConfigProto | None = ...,
        protocol: builtins.str | None = ...,

@@ -86,6 +90,6 @@ class ServerDef(google.protobuf.message.Message):

        cluster_device_filters: tensorflow.core.protobuf.device_filters_pb2.ClusterDeviceFilters | None = ...,
    ) -> None: ...
    def HasField(self, field_name: typing.Literal["cluster", b"cluster", "cluster_device_filters", b"cluster_device_filters", "default_session_config", b"default_session_config"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing.Literal["cluster", b"cluster", "cluster_device_filters", b"cluster_device_filters", "default_session_config", b"default_session_config", "job_name", b"job_name", "port", b"port", "protocol", b"protocol", "task_index", b"task_index"]) -> None: ...
    def ClearField(self, field_name: typing.Literal["cluster", b"cluster", "cluster_device_filters", b"cluster_device_filters", "default_session_config", b"default_session_config", "job_name", b"job_name", "port", b"port", "protocol", b"protocol", "replica", b"replica", "task_index", b"task_index"]) -> None: ...

global___ServerDef = ServerDef

@@ -392,6 +392,34 @@ class MomentumParameters(google.protobuf.message.Message):

global___MomentumParameters = MomentumParameters

@typing.final
class LionParameters(google.protobuf.message.Message):
    """https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Lion
    momenta(new) = beta2 * momenta(old) + (1 - beta2) * grad
    momenta_t = beta1 * momenta(old) + (1 - beta1) * grad
    var(new) = var(old) - lr * sign(momenta_t)
    Algorithm described in https://arxiv.org/abs/2302.06675.
    """

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    BETA1_FIELD_NUMBER: builtins.int
    BETA2_FIELD_NUMBER: builtins.int
    USE_NON_LAZY_LION_FIELD_NUMBER: builtins.int
    beta1: builtins.float
    beta2: builtins.float
    use_non_lazy_lion: builtins.bool
    def __init__(
        self,
        *,
        beta1: builtins.float | None = ...,
        beta2: builtins.float | None = ...,
        use_non_lazy_lion: builtins.bool | None = ...,
    ) -> None: ...
    def ClearField(self, field_name: typing.Literal["beta1", b"beta1", "beta2", b"beta2", "use_non_lazy_lion", b"use_non_lazy_lion"]) -> None: ...

global___LionParameters = LionParameters
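
For reference, the update rule quoted in the docstring above can be checked numerically. A minimal NumPy sketch of one Lion step; the function name and default hyperparameters are illustrative, not part of the stubs:

import numpy as np

def lion_update(var, momenta, grad, lr=1e-4, beta1=0.9, beta2=0.99):
    # momenta_t = beta1 * momenta(old) + (1 - beta1) * grad
    update_direction = beta1 * momenta + (1 - beta1) * grad
    # momenta(new) = beta2 * momenta(old) + (1 - beta2) * grad
    new_momenta = beta2 * momenta + (1 - beta2) * grad
    # var(new) = var(old) - lr * sign(momenta_t)
    new_var = var - lr * np.sign(update_direction)
    return new_var, new_momenta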

@typing.final
class RmsPropParameters(google.protobuf.message.Message):
    """https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop

@@ -898,6 +926,7 @@ class OptimizationParameters(google.protobuf.message.Message):

    FTRL_FIELD_NUMBER: builtins.int
    ADAM_FIELD_NUMBER: builtins.int
    MOMENTUM_FIELD_NUMBER: builtins.int
    LION_FIELD_NUMBER: builtins.int
    RMS_PROP_FIELD_NUMBER: builtins.int
    CENTERED_RMS_PROP_FIELD_NUMBER: builtins.int
    MDL_ADAGRAD_LIGHT_FIELD_NUMBER: builtins.int

@@ -975,6 +1004,8 @@ class OptimizationParameters(google.protobuf.message.Message):

    @property
    def momentum(self) -> global___MomentumParameters: ...
    @property
    def lion(self) -> global___LionParameters: ...
    @property
    def rms_prop(self) -> global___RmsPropParameters: ...
    @property
    def centered_rms_prop(self) -> global___CenteredRmsPropParameters: ...

@@ -1013,6 +1044,7 @@ class OptimizationParameters(google.protobuf.message.Message):

        ftrl: global___FtrlParameters | None = ...,
        adam: global___AdamParameters | None = ...,
        momentum: global___MomentumParameters | None = ...,
        lion: global___LionParameters | None = ...,
        rms_prop: global___RmsPropParameters | None = ...,
        centered_rms_prop: global___CenteredRmsPropParameters | None = ...,
        mdl_adagrad_light: global___MdlAdagradLightParameters | None = ...,

@@ -1024,9 +1056,9 @@ class OptimizationParameters(google.protobuf.message.Message):

        user_defined_program: global___UserDefinedProgramParameters | None = ...,
        assign: global___AssignParameters | None = ...,
    ) -> None: ...
    def HasField(self, field_name: typing.Literal["adadelta", b"adadelta", "adagrad", b"adagrad", "adagrad_momentum", b"adagrad_momentum", "adam", b"adam", "assign", b"assign", "bounded_adagrad", b"bounded_adagrad", "centered_rms_prop", b"centered_rms_prop", "clipping_limits", b"clipping_limits", "frequency_estimator", b"frequency_estimator", "ftrl", b"ftrl", "gradient_clipping_limits", b"gradient_clipping_limits", "hot_id_replication_configuration", b"hot_id_replication_configuration", "learning_rate", b"learning_rate", "mdl_adagrad_light", b"mdl_adagrad_light", "momentum", b"momentum", "online_yogi", b"online_yogi", "parameters", b"parameters", "proximal_adagrad", b"proximal_adagrad", "proximal_yogi", b"proximal_yogi", "rms_prop", b"rms_prop", "simulated_quantization", b"simulated_quantization", "stochastic_gradient_descent", b"stochastic_gradient_descent", "user_defined_program", b"user_defined_program"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing.Literal["adadelta", b"adadelta", "adagrad", b"adagrad", "adagrad_momentum", b"adagrad_momentum", "adam", b"adam", "assign", b"assign", "bounded_adagrad", b"bounded_adagrad", "centered_rms_prop", b"centered_rms_prop", "clipping_limits", b"clipping_limits", "frequency_estimator", b"frequency_estimator", "ftrl", b"ftrl", "gradient_accumulation_status", b"gradient_accumulation_status", "gradient_clipping_limits", b"gradient_clipping_limits", "hot_id_replication_configuration", b"hot_id_replication_configuration", "learning_rate", b"learning_rate", "low_dimensional_packing_status", b"low_dimensional_packing_status", "mdl_adagrad_light", b"mdl_adagrad_light", "momentum", b"momentum", "multiply_weight_decay_factor_by_learning_rate", b"multiply_weight_decay_factor_by_learning_rate", "online_yogi", b"online_yogi", "parameters", b"parameters", "proximal_adagrad", b"proximal_adagrad", "proximal_yogi", b"proximal_yogi", "rms_prop", b"rms_prop", "simulated_quantization", b"simulated_quantization", "stochastic_gradient_descent", b"stochastic_gradient_descent", "user_defined_program", b"user_defined_program", "weight_decay_factor", b"weight_decay_factor"]) -> None: ...
    def WhichOneof(self, oneof_group: typing.Literal["parameters", b"parameters"]) -> typing.Literal["adagrad", "adagrad_momentum", "bounded_adagrad", "stochastic_gradient_descent", "ftrl", "adam", "momentum", "rms_prop", "centered_rms_prop", "mdl_adagrad_light", "adadelta", "proximal_adagrad", "online_yogi", "proximal_yogi", "frequency_estimator", "user_defined_program", "assign"] | None: ...
    def HasField(self, field_name: typing.Literal["adadelta", b"adadelta", "adagrad", b"adagrad", "adagrad_momentum", b"adagrad_momentum", "adam", b"adam", "assign", b"assign", "bounded_adagrad", b"bounded_adagrad", "centered_rms_prop", b"centered_rms_prop", "clipping_limits", b"clipping_limits", "frequency_estimator", b"frequency_estimator", "ftrl", b"ftrl", "gradient_clipping_limits", b"gradient_clipping_limits", "hot_id_replication_configuration", b"hot_id_replication_configuration", "learning_rate", b"learning_rate", "lion", b"lion", "mdl_adagrad_light", b"mdl_adagrad_light", "momentum", b"momentum", "online_yogi", b"online_yogi", "parameters", b"parameters", "proximal_adagrad", b"proximal_adagrad", "proximal_yogi", b"proximal_yogi", "rms_prop", b"rms_prop", "simulated_quantization", b"simulated_quantization", "stochastic_gradient_descent", b"stochastic_gradient_descent", "user_defined_program", b"user_defined_program"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing.Literal["adadelta", b"adadelta", "adagrad", b"adagrad", "adagrad_momentum", b"adagrad_momentum", "adam", b"adam", "assign", b"assign", "bounded_adagrad", b"bounded_adagrad", "centered_rms_prop", b"centered_rms_prop", "clipping_limits", b"clipping_limits", "frequency_estimator", b"frequency_estimator", "ftrl", b"ftrl", "gradient_accumulation_status", b"gradient_accumulation_status", "gradient_clipping_limits", b"gradient_clipping_limits", "hot_id_replication_configuration", b"hot_id_replication_configuration", "learning_rate", b"learning_rate", "lion", b"lion", "low_dimensional_packing_status", b"low_dimensional_packing_status", "mdl_adagrad_light", b"mdl_adagrad_light", "momentum", b"momentum", "multiply_weight_decay_factor_by_learning_rate", b"multiply_weight_decay_factor_by_learning_rate", "online_yogi", b"online_yogi", "parameters", b"parameters", "proximal_adagrad", b"proximal_adagrad", "proximal_yogi", b"proximal_yogi", "rms_prop", b"rms_prop", "simulated_quantization", b"simulated_quantization", "stochastic_gradient_descent", b"stochastic_gradient_descent", "user_defined_program", b"user_defined_program", "weight_decay_factor", b"weight_decay_factor"]) -> None: ...
    def WhichOneof(self, oneof_group: typing.Literal["parameters", b"parameters"]) -> typing.Literal["adagrad", "adagrad_momentum", "bounded_adagrad", "stochastic_gradient_descent", "ftrl", "adam", "momentum", "lion", "rms_prop", "centered_rms_prop", "mdl_adagrad_light", "adadelta", "proximal_adagrad", "online_yogi", "proximal_yogi", "frequency_estimator", "user_defined_program", "assign"] | None: ...

global___OptimizationParameters = OptimizationParameters

@@ -62,13 +62,17 @@ class TPUHardwareFeature(google.protobuf.message.Message):

    """

    EMBEDDING_FEATURE_FIELD_NUMBER: builtins.int
    NUM_EMBEDDING_DEVICES_PER_CHIP_FIELD_NUMBER: builtins.int
    embedding_feature: global___TPUHardwareFeature.EmbeddingFeature.ValueType
    num_embedding_devices_per_chip: builtins.int
    """Number of embedding accelerator devices per chip."""
    def __init__(
        self,
        *,
        embedding_feature: global___TPUHardwareFeature.EmbeddingFeature.ValueType | None = ...,
        num_embedding_devices_per_chip: builtins.int | None = ...,
    ) -> None: ...
    def ClearField(self, field_name: typing.Literal["embedding_feature", b"embedding_feature"]) -> None: ...
    def ClearField(self, field_name: typing.Literal["embedding_feature", b"embedding_feature", "num_embedding_devices_per_chip", b"num_embedding_devices_per_chip"]) -> None: ...

global___TPUHardwareFeature = TPUHardwareFeature

stubs/tensorflow/tensorflow/distribute/coordinator.pyi (new file, +5)
@@ -0,0 +1,5 @@
from _typeshed import Incomplete

from .experimental.coordinator import RemoteValue as RemoteValue

def __getattr__(name: str) -> Incomplete: ...

@@ -96,7 +96,9 @@ class RaggedFeature(NamedTuple):

    dtype: DTypeLike
    value_key: str | None = ...
    partitions: tuple[RowSplits | RowLengths | RowStarts | RowLimits | ValueRowIds | UniformRowLength, ...] = ...  # type: ignore[name-defined]
    partitions: tuple[  # type: ignore[name-defined]
        RowSplits | RowLengths | RowStarts | RowLimits | ValueRowIds | UniformRowLength, ...
    ] = ...
    row_splits_dtype: DTypeLike = ...
    validate: bool = ...

@@ -11,9 +11,7 @@ _Activation: TypeAlias = str | None | Callable[[Tensor], Tensor] | dict[str, Any

# Ints are not allowed.
_ActivationInput: TypeAlias = Tensor | FloatDataSequence | FloatArray | np.number[Any] | float

def deserialize(
    name: str, custom_objects: dict[str, Callable[..., Any]] | None = None, use_legacy_format: bool = False
) -> Callable[..., Any]: ...
def deserialize(config: dict[str, Any], custom_objects: dict[str, Callable[..., Any]] | None = None) -> Callable[..., Any]: ...
def elu(x: _ActivationInput, alpha: FloatTensorCompatible | FloatDataSequence = 1.0) -> Tensor: ...
def exponential(x: _ActivationInput) -> Tensor: ...
def gelu(x: _ActivationInput, approximate: bool = False) -> Tensor: ...

@@ -23,12 +21,12 @@ def linear(x: _ActivationInput) -> Tensor: ...

def mish(x: _ActivationInput) -> Tensor: ...
def relu(
    x: _ActivationInput,
    alpha: FloatTensorCompatible = 0.0,
    negative_slope: FloatTensorCompatible = 0.0,
    max_value: FloatTensorCompatible | FloatDataSequence | None = None,
    threshold: FloatTensorCompatible | FloatDataSequence = 0.0,
) -> Tensor: ...
def selu(x: _ActivationInput) -> Tensor: ...
def serialize(activation: Callable[..., Any], use_legacy_format: bool = False) -> str | dict[str, Any]: ...
def serialize(activation: Callable[..., Any]) -> str | dict[str, Any]: ...
def sigmoid(x: _ActivationInput) -> Tensor: ...
def softmax(x: Tensor, axis: Integer = -1) -> Tensor: ...
def softplus(x: _ActivationInput) -> Tensor: ...

@@ -1,4 +1,4 @@
from collections.abc import Callable, Iterable, Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal
from typing_extensions import TypeAlias

@@ -6,37 +6,33 @@ import tensorflow as tf

from requests.api import _HeadersMapping
from tensorflow.keras import Model
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.saved_model import SaveOptions
from tensorflow.train import CheckpointOptions

_Logs: TypeAlias = Mapping[str, Any] | None | Any

class Callback:
    model: Model[Any, Any]
    params: dict[str, Any]
    def set_model(self, model: Model[Any, Any]) -> None: ...
    def set_params(self, params: dict[str, Any]) -> None: ...
    def set_model(self, model: Model[Any, Any]) -> None: ...
    @property
    def model(self) -> Model[Any, Any]: ...
    def on_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
    def on_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
    def on_epoch_begin(self, epoch: int, logs: _Logs = None) -> None: ...
    def on_epoch_end(self, epoch: int, logs: _Logs = None) -> None: ...
    def on_predict_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
    def on_predict_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
    def on_predict_begin(self, logs: _Logs = None) -> None: ...
    def on_predict_end(self, logs: _Logs = None) -> None: ...
    def on_test_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
    def on_test_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
    def on_test_begin(self, logs: _Logs = None) -> None: ...
    def on_test_end(self, logs: _Logs = None) -> None: ...
    def on_train_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
    def on_train_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
    def on_test_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
    def on_test_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
    def on_predict_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
    def on_predict_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
    def on_train_begin(self, logs: _Logs = None) -> None: ...
    def on_train_end(self, logs: _Logs = None) -> None: ...
    def on_test_begin(self, logs: _Logs = None) -> None: ...
    def on_test_end(self, logs: _Logs = None) -> None: ...
    def on_predict_begin(self, logs: _Logs = None) -> None: ...
    def on_predict_end(self, logs: _Logs = None) -> None: ...

# A CallbackList has the exact same API as a Callback, but does not actually subclass it.
class CallbackList:
    model: Model[Any, Any]
    params: dict[str, Any]
class CallbackList(Callback):
    def __init__(
        self,
        callbacks: Sequence[Callback] | None = None,

@@ -46,32 +42,28 @@ class CallbackList:

        model: Model[Any, Any] | None = None,
        **params: Any,
    ) -> None: ...
    def set_model(self, model: Model[Any, Any]) -> None: ...
    def append(self, callback: Callback) -> None: ...
    def set_params(self, params: dict[str, Any]) -> None: ...
    def on_batch_begin(self, batch: int, logs: _Logs | None = None) -> None: ...
    def on_batch_end(self, batch: int, logs: _Logs | None = None) -> None: ...
    def on_epoch_begin(self, epoch: int, logs: _Logs | None = None) -> None: ...
    def on_epoch_end(self, epoch: int, logs: _Logs | None = None) -> None: ...
    def on_predict_batch_begin(self, batch: int, logs: _Logs | None = None) -> None: ...
    def on_predict_batch_end(self, batch: int, logs: _Logs | None = None) -> None: ...
    def on_predict_begin(self, logs: _Logs | None = None) -> None: ...
    def on_predict_end(self, logs: _Logs | None = None) -> None: ...
    def on_test_batch_begin(self, batch: int, logs: _Logs | None = None) -> None: ...
    def on_test_batch_end(self, batch: int, logs: _Logs | None = None) -> None: ...
    def on_test_begin(self, logs: _Logs | None = None) -> None: ...
    def on_test_end(self, logs: _Logs | None = None) -> None: ...
    def on_train_batch_begin(self, batch: int, logs: _Logs | None = None) -> None: ...
    def on_train_batch_end(self, batch: int, logs: _Logs | None = None) -> None: ...
    def on_train_begin(self, logs: _Logs | None = None) -> None: ...
    def on_train_end(self, logs: _Logs | None = None) -> None: ...
    def set_model(self, model: Model[Any, Any]) -> None: ...
    def on_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
    def on_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
    def on_epoch_begin(self, epoch: int, logs: _Logs = None) -> None: ...
    def on_epoch_end(self, epoch: int, logs: _Logs = None) -> None: ...
    def on_train_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
    def on_train_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
    def on_test_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
    def on_test_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
    def on_predict_batch_begin(self, batch: int, logs: _Logs = None) -> None: ...
    def on_predict_batch_end(self, batch: int, logs: _Logs = None) -> None: ...
    def on_train_begin(self, logs: _Logs = None) -> None: ...
    def on_train_end(self, logs: _Logs = None) -> None: ...
    def on_test_begin(self, logs: _Logs = None) -> None: ...
    def on_test_end(self, logs: _Logs = None) -> None: ...
    def on_predict_begin(self, logs: _Logs = None) -> None: ...
    def on_predict_end(self, logs: _Logs = None) -> None: ...
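
For orientation, a short usage sketch of driving a CallbackList by hand; the callback body and the add_history keyword are illustrative assumptions, not verified against this stub:

import tensorflow as tf

class LogEpochs(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print(f"epoch {epoch}: {logs or {}}")

# CallbackList fans each hook out to every wrapped callback
callbacks = tf.keras.callbacks.CallbackList([LogEpochs()])
callbacks.on_epoch_begin(0)
callbacks.on_epoch_end(0, logs={"loss": 0.5})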

class BackupAndRestore(Callback):
    def __init__(
        self, backup_dir: str, save_freq: str = "epoch", delete_checkpoint: bool = True, save_before_preemption: bool = False
    ) -> None: ...

class BaseLogger(Callback):
    def __init__(self, stateful_metrics: Iterable[str] | None = None) -> None: ...
    def __init__(self, backup_dir: str, save_freq: str = "epoch", delete_checkpoint: bool = True) -> None: ...

class CSVLogger(Callback):
    def __init__(self, filename: str, separator: str = ",", append: bool = False) -> None: ...

@@ -98,10 +90,10 @@ class LambdaCallback(Callback):

        self,
        on_epoch_begin: Callable[[int, _Logs], object] | None = None,
        on_epoch_end: Callable[[int, _Logs], object] | None = None,
        on_batch_begin: Callable[[int, _Logs], object] | None = None,
        on_batch_end: Callable[[int, _Logs], object] | None = None,
        on_train_begin: Callable[[_Logs], object] | None = None,
        on_train_end: Callable[[_Logs], object] | None = None,
        on_train_batch_begin: Callable[[int, _Logs], object] | None = None,
        on_train_batch_end: Callable[[int, _Logs], object] | None = None,
        **kwargs: Any,
    ) -> None: ...

@@ -115,7 +107,6 @@ class LearningRateScheduler(Callback):

class ModelCheckpoint(Callback):
    monitor_op: Any
    filepath: str
    _options: CheckpointOptions | SaveOptions | None
    def __init__(
        self,
        filepath: str,

@@ -125,16 +116,13 @@ class ModelCheckpoint(Callback):

        save_weights_only: bool = False,
        mode: Literal["auto", "min", "max"] = "auto",
        save_freq: str | int = "epoch",
        options: CheckpointOptions | SaveOptions | None = None,
        initial_value_threshold: float | None = None,
    ) -> None: ...
    def _save_model(self, epoch: int, batch: int | None, logs: _Logs) -> None: ...

class ProgbarLogger(Callback):
    use_steps: bool
    def __init__(
        self, count_mode: Literal["steps", "samples"] = "samples", stateful_metrics: Iterable[str] | None = None
    ) -> None: ...
    def __init__(self) -> None: ...

class ReduceLROnPlateau(Callback):
    def __init__(

@@ -146,7 +134,7 @@ class ReduceLROnPlateau(Callback):

        mode: Literal["auto", "min", "max"] = "auto",
        min_delta: float = 1e-4,
        cooldown: int = 0,
        min_lr: float = 0,
        min_lr: float = 0.0,
        **kwargs,
    ) -> None: ...
    def in_cooldown(self) -> bool: ...

@@ -176,7 +164,6 @@ class TensorBoard(Callback):

        profile_batch: int | tuple[int, int] = 0,
        embeddings_freq: int = 0,
        embeddings_metadata: dict[str, None] | None = None,
        **kwargs: Any,
    ) -> None: ...
    def _write_keras_model_train_graph(self) -> None: ...
    def _stop_trace(self, batch: int | None = None) -> None: ...

@@ -13,7 +13,7 @@ class Initializer:

    def from_config(cls, config: dict[str, Any]) -> Self: ...

class Constant(Initializer):
    def __init__(self, value: TensorCompatible = 0) -> None: ...
    def __init__(self, value: TensorCompatible = 0.0) -> None: ...

class GlorotNormal(Initializer):
    def __init__(self, seed: int | None = None) -> None: ...

@@ -1,17 +1,15 @@
from _typeshed import Incomplete
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Any, Generic, Literal, TypeVar, overload
from collections.abc import Callable, Iterable, Sequence
from typing import Any, Generic, Literal, TypeVar, overload, type_check_only
from typing_extensions import Self, TypeAlias

import tensorflow as tf
from tensorflow import Tensor, Variable, VariableAggregation, VariableSynchronization
from tensorflow._aliases import AnyArray, DTypeLike, TensorCompatible, TensorLike
from tensorflow import Tensor, Variable
from tensorflow._aliases import AnyArray, DataSequence, DTypeLike, Float, TensorCompatible, TensorLike
from tensorflow.keras.activations import _Activation
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import _Initializer
from tensorflow.keras.layers.preprocessing import IntegerLookup as IntegerLookup, StringLookup as StringLookup
from tensorflow.keras.regularizers import Regularizer, _Regularizer
from tensorflow.python.feature_column.feature_column_v2 import DenseColumn, SequenceDenseColumn

_InputT = TypeVar("_InputT", contravariant=True)
_OutputT = TypeVar("_OutputT", covariant=True)

@@ -51,7 +49,16 @@ class Layer(tf.Module, Generic[_InputT, _OutputT]):

    @trainable.setter
    def trainable(self, value: bool) -> None: ...
    def __init__(
        self, trainable: bool = True, name: str | None = None, dtype: DTypeLike | None = None, dynamic: bool = False
        self,
        *,
        activity_regularizer: _Regularizer = None,
        trainable: bool = True,
        dtype: DTypeLike | None = None,
        autocast: bool = True,
        name: str | None = None,
        # **kwargs
        input_dim: int | None = None,
        input_shape: Any = None,
    ) -> None: ...

    # *args/**kwargs are allowed, but have obscure footguns, and the tensorflow documentation discourages their usage.

@@ -69,18 +76,17 @@ class Layer(tf.Module, Generic[_InputT, _OutputT]):

    def compute_output_shape(self, input_shape: Any, /) -> Any: ...
    def add_weight(
        self,
        name: str | None = None,
        shape: Iterable[int | None] | None = None,
        dtype: DTypeLike | None = None,
        initializer: _Initializer | None = None,
        dtype: DTypeLike | None = None,
        trainable: bool = True,
        autocast: bool = True,
        regularizer: _Regularizer = None,
        trainable: bool | None = None,
        constraint: _Constraint = None,
        use_resource: bool | None = None,
        synchronization: VariableSynchronization = ...,
        aggregation: VariableAggregation = ...,
        aggregation: Literal["mean", "sum", "only_first_replica"] = "mean",
        name: str | None = None,
    ) -> tf.Variable: ...
    def add_loss(self, losses: tf.Tensor | Sequence[tf.Tensor] | Callable[[], tf.Tensor]) -> None: ...
    def add_loss(self, loss: tf.Tensor | Sequence[tf.Tensor] | Callable[[], tf.Tensor]) -> None: ...
    def count_params(self) -> int: ...
    @property
    def trainable_variables(self) -> list[Variable]: ...
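
The keyword-only Layer constructor and the reordered add_weight signature above can be exercised with a small subclass; a minimal sketch assuming Keras 3 conventions in tensorflow 2.16 (class name and sizes are illustrative):

import tensorflow as tf

class ScaleLayer(tf.keras.layers.Layer):
    def __init__(self, units: int, **kwargs):
        super().__init__(**kwargs)  # Layer args are keyword-only: name=, dtype=, trainable=, ...
        self.units = units

    def build(self, input_shape):
        # shape/initializer come first; name is now a trailing keyword
        self.scale = self.add_weight(
            shape=(self.units,), initializer="ones", trainable=True, name="scale"
        )

    def call(self, inputs):
        return inputs * self.scale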

@@ -112,6 +118,89 @@ _LayerDtype: TypeAlias = DTypeLike | dict[str, Any] | Any

_Constraint: TypeAlias = str | dict[str, Any] | Constraint | None

# IndexLookup is not exported by Keras
@type_check_only
class _IndexLookup(Layer[tf.Tensor, tf.Tensor]):
    def __init__(
        self,
        max_tokens: int | None,
        num_oov_indices: int,
        mask_token: str | None,
        oov_token: str,
        vocabulary_dtype: Literal["int64", "string"],
        vocabulary: str | None | TensorCompatible = None,
        idf_weights: TensorCompatible | None = None,
        invert: bool = False,
        output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int",
        sparse: bool = False,
        pad_to_max_tokens: bool = False,
        name: str | None = None,
        *,
        # **kwargs
        vocabulary_size: int | None = None,
        has_input_vocabulary: bool = ...,
        trainable: bool | None = None,
        dtype: _LayerDtype | None = None,
        # **kwargs passed to Layer
        activity_regularizer: _Regularizer = None,
        autocast: bool = True,
    ) -> None: ...
    def compute_output_signature(self, input_spec) -> tf.TensorSpec: ...
    def get_vocabulary(self, include_special_tokens: bool = True) -> list[Incomplete]: ...
    def vocabulary_size(self) -> int: ...

class StringLookup(_IndexLookup):
    def __init__(
        self,
        max_tokens: int | None = None,
        num_oov_indices: int = 1,
        mask_token: str | None = None,
        oov_token: str = "[UNK]",
        vocabulary: str | None | TensorCompatible = None,
        idf_weights: TensorCompatible | None = None,
        invert: bool = False,
        output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int",
        pad_to_max_tokens: bool = False,
        sparse: bool = False,
        encoding: str = "utf-8",
        name: str | None = None,
        *,
        # **kwargs passed to IndexLookup
        vocabulary_size: int | None = None,
        has_input_vocabulary: bool = ...,
        trainable: bool | None = None,
        dtype: _LayerDtype | None = None,
        activity_regularizer: _Regularizer = None,
        autocast: bool = True,
    ) -> None: ...
    def adapt(self, data: tf.data.Dataset[TensorLike] | AnyArray | DataSequence, steps: Float | None = None) -> None: ...

class IntegerLookup(_IndexLookup):
    def __init__(
        self,
        max_tokens: int | None = None,
        num_oov_indices: int = 1,
        mask_token: int | None = None,
        oov_token: int = -1,
        vocabulary: str | None | TensorCompatible = None,
        vocabulary_dtype: Literal["int64"] = "int64",
        idf_weights: TensorCompatible | None = None,
        invert: bool = False,
        output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int",
        sparse: bool = False,
        pad_to_max_tokens: bool = False,
        name: str | None = None,
        *,
        # **kwargs passed to IndexLookup
        vocabulary_size: int | None = None,
        has_input_vocabulary: bool = ...,
        trainable: bool | None = None,
        dtype: _LayerDtype | None = None,
        activity_regularizer: _Regularizer = None,
        autocast: bool = True,
    ) -> None: ...
    def adapt(self, data: tf.data.Dataset[TensorLike] | AnyArray | DataSequence, steps: Float | None = None) -> None: ...
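
For reference, the lookup layers declared above are typically adapted to data before use; a minimal sketch (the vocabulary and inputs are illustrative):

import tensorflow as tf

lookup = tf.keras.layers.StringLookup(oov_token="[UNK]")
lookup.adapt(tf.constant(["cat", "dog", "cat", "bird"]))  # builds the vocabulary
ids = lookup(tf.constant(["dog", "fish"]))  # "fish" maps to the OOV index
print(lookup.get_vocabulary(), ids)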

# A Layer's compute_output_shape commonly has `instance` as the first argument name instead of `self`.
# This is an artifact of the actual implementation, which commonly uses a decorator to define it.
# Layer.build sometimes has the same quirk. For both, the argument is marked positional-only.

@@ -128,9 +217,12 @@ class Dense(Layer[tf.Tensor, tf.Tensor]):

        activity_regularizer: _Regularizer = None,
        kernel_constraint: _Constraint = None,
        bias_constraint: _Constraint = None,
        lora_rank: int | None = None,
        *,
        # **kwargs passed to Layer
        trainable: bool = True,
        dtype: _LayerDtype | None = None,
        dynamic: bool = False,
        autocast: bool = True,
        name: str | None = None,
    ) -> None: ...

@@ -150,9 +242,13 @@ class BatchNormalization(Layer[tf.Tensor, tf.Tensor]):

        gamma_regularizer: _Regularizer = None,
        beta_constraint: _Constraint = None,
        gamma_constraint: _Constraint = None,
        synchronized: bool = False,
        *,
        # **kwargs passed to Layer
        activity_regularizer: _Regularizer = None,
        trainable: bool = True,
        dtype: _LayerDtype | None = None,
        dynamic: bool = False,
        autocast: bool = True,
        name: str | None = None,
    ) -> None: ...

@@ -162,9 +258,12 @@ class ReLU(Layer[tf.Tensor, tf.Tensor]):

        max_value: float | None = None,
        negative_slope: float | None = 0.0,
        threshold: float | None = 0.0,
        *,
        # **kwargs passed to Layer
        activity_regularizer: _Regularizer = None,
        trainable: bool = True,
        dtype: _LayerDtype | None = None,
        dynamic: bool = False,
        autocast: bool = True,
        name: str | None = None,
    ) -> None: ...

@@ -174,9 +273,12 @@ class Dropout(Layer[tf.Tensor, tf.Tensor]):

        rate: float,
        noise_shape: TensorCompatible | Sequence[int | None] | None = None,
        seed: int | None = None,
        *,
        # **kwargs passed to Layer
        activity_regularizer: _Regularizer = None,
        trainable: bool = True,
        dtype: _LayerDtype | None = None,
        dynamic: bool = False,
        autocast: bool = True,
        name: str | None = None,
    ) -> None: ...

@@ -189,10 +291,14 @@ class Embedding(Layer[tf.Tensor, tf.Tensor]):

        embeddings_regularizer: _Regularizer = None,
        embeddings_constraint: _Constraint = None,
        mask_zero: bool = False,
        lora_rank: int | None = None,
        *,
        input_length: int | None = None,
        # **kwargs passed to Layer
        activity_regularizer: _Regularizer = None,
        trainable: bool = True,
        dtype: _LayerDtype | None = None,
        dynamic: bool = False,
        autocast: bool = True,
        name: str | None = None,
    ) -> None: ...

@@ -215,15 +321,26 @@ class Conv2D(Layer[tf.Tensor, tf.Tensor]):

        activity_regularizer: _Regularizer = None,
        kernel_constraint: _Constraint = None,
        bias_constraint: _Constraint = None,
        *,
        # **kwargs passed to Layer
        trainable: bool = True,
        dtype: _LayerDtype | None = None,
        dynamic: bool = False,
        autocast: bool = True,
        name: str | None = None,
    ) -> None: ...

Convolution2D = Conv2D

class Identity(Layer[tf.Tensor, tf.Tensor]):
    def __init__(
        self, trainable: bool = True, dtype: _LayerDtype = None, dynamic: bool = False, name: str | None = None
        self,
        *,
        # **kwargs passed to Layer
        activity_regularizer: _Regularizer = None,
        trainable: bool = True,
        dtype: _LayerDtype | None = None,
        autocast: bool = True,
        name: str | None = None,
    ) -> None: ...

class LayerNormalization(Layer[tf.Tensor, tf.Tensor]):

@@ -233,25 +350,19 @@ class LayerNormalization(Layer[tf.Tensor, tf.Tensor]):

        epsilon: float = 0.001,
        center: bool = True,
        scale: bool = True,
        rms_scaling: bool = False,
        beta_initializer: _Initializer = "zeros",
        gamma_initializer: _Initializer = "ones",
        beta_regularizer: _Regularizer = None,
        gamma_regularizer: _Regularizer = None,
        beta_constraint: _Constraint = None,
        gamma_constraint: _Constraint = None,
        *,
        # **kwargs passed to Layer
        activity_regularizer: _Regularizer = None,
        trainable: bool = True,
        dtype: _LayerDtype | None = None,
        dynamic: bool = False,
        name: str | None = None,
    ) -> None: ...

class DenseFeatures(Layer[Mapping[str, TensorLike], tf.Tensor]):
    def __init__(
        self,
        feature_columns: Sequence[DenseColumn | SequenceDenseColumn],
        trainable: bool = True,
        dtype: _LayerDtype = None,
        dynamic: bool = False,
        autocast: bool = True,
        name: str | None = None,
    ) -> None: ...

@@ -265,16 +376,18 @@ class MultiHeadAttention(Layer[Any, tf.Tensor]):

        use_bias: bool = True,
        output_shape: tuple[int, ...] | None = None,
        attention_axes: tuple[int, ...] | None = None,
        kernel_initialize: _Initializer = "glorot_uniform",
        kernel_initializer: _Initializer = "glorot_uniform",
        bias_initializer: _Initializer = "zeros",
        kernel_regularizer: Regularizer | None = None,
        bias_regularizer: _Regularizer | None = None,
        activity_regularizer: _Regularizer | None = None,
        kernel_constraint: _Constraint | None = None,
        bias_constraint: _Constraint | None = None,
        *,
        # **kwargs passed to Layer
        trainable: bool = True,
        dtype: _LayerDtype | None = None,
        dynamic: bool = False,
        autocast: bool = True,
        name: str | None = None,
    ) -> None: ...
    # @override

@@ -317,9 +430,12 @@ class GaussianDropout(Layer[tf.Tensor, tf.Tensor]):

        self,
        rate: float,
        seed: int | None = None,
        *,
        # **kwargs passed to Layer
        activity_regularizer: _Regularizer = None,
        trainable: bool = True,
        dtype: _LayerDtype = None,
        dynamic: bool = False,
        dtype: _LayerDtype | None = None,
        autocast: bool = True,
        name: str | None = None,
    ) -> None: ...

@@ -1,27 +0,0 @@
import abc
from typing import overload

import tensorflow as tf
from tensorflow._aliases import AnyArray, DataSequence, Float, Integer, TensorCompatible, TensorLike
from tensorflow.keras.layers import Layer

class PreprocessingLayer(Layer[TensorLike, TensorLike], metaclass=abc.ABCMeta):
    @property
    def is_adapted(self) -> bool: ...
    @overload  # type: ignore
    def __call__(self, inputs: tf.Tensor, *, training: bool = False, mask: TensorCompatible | None = None) -> tf.Tensor: ...
    @overload
    def __call__(
        self, inputs: tf.SparseTensor, *, training: bool = False, mask: TensorCompatible | None = None
    ) -> tf.SparseTensor: ...
    @overload
    def __call__(
        self, inputs: tf.RaggedTensor, *, training: bool = False, mask: TensorCompatible | None = None
    ) -> tf.RaggedTensor: ...
    def adapt(
        self,
        data: tf.data.Dataset[TensorLike] | AnyArray | DataSequence,
        batch_size: Integer | None = None,
        steps: Float | None = None,
    ) -> None: ...
    def compile(self, run_eagerly: bool | None = None, steps_per_execution: Integer | None = None) -> None: ...

@@ -1,36 +0,0 @@
from typing import Literal

from tensorflow._aliases import TensorCompatible
from tensorflow.keras.layers.preprocessing.index_lookup import _IndexLookup

class StringLookup(_IndexLookup):
    def __init__(
        self,
        max_tokens: int | None = None,
        num_oov_indices: int = 1,
        mask_token: str | None = None,
        oov_token: str = "[UNK]",
        vocabulary: str | None | TensorCompatible = None,
        idf_weights: TensorCompatible | None = None,
        encoding: str = "utf-8",
        invert: bool = False,
        output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int",
        sparse: bool = False,
        pad_to_max_tokens: bool = False,
    ) -> None: ...

class IntegerLookup(_IndexLookup):
    def __init__(
        self,
        max_tokens: int | None = None,
        num_oov_indices: int = 1,
        mask_token: int | None = None,
        oov_token: int = -1,
        vocabulary: str | None | TensorCompatible = None,
        vocabulary_dtype: Literal["int64", "int32"] = "int64",
        idf_weights: TensorCompatible | None = None,
        invert: bool = False,
        output_mode: Literal["int", "count", "multi_hot", "one_hot", "tf_idf"] = "int",
        sparse: bool = False,
        pad_to_max_tokens: bool = False,
    ) -> None: ...

@@ -1,9 +0,0 @@
from _typeshed import Incomplete

import tensorflow as tf
from tensorflow.keras.layers.experimental.preprocessing import PreprocessingLayer

class _IndexLookup(PreprocessingLayer):
    def compute_output_signature(self, input_spec) -> tf.TensorSpec: ...
    def get_vocabulary(self, include_special_tokens: bool = True) -> list[Incomplete]: ...
    def vocabulary_size(self) -> int: ...

@@ -14,7 +14,9 @@ from tensorflow.keras.metrics import (

class Loss(ABC):
    reduction: _ReductionValues
    name: str | None
    def __init__(self, reduction: _ReductionValues = "auto", name: str | None = None) -> None: ...
    def __init__(
        self, name: str | None = None, reduction: _ReductionValues = "sum_over_batch_size", dtype: Incomplete | None = None
    ) -> None: ...
    @abstractmethod
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    @classmethod

@@ -30,7 +32,7 @@ class BinaryCrossentropy(Loss):

        from_logits: bool = False,
        label_smoothing: float = 0.0,
        axis: int = -1,
        reduction: _ReductionValues = ...,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "binary_crossentropy",
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

@@ -44,7 +46,7 @@ class BinaryFocalCrossentropy(Loss):

        from_logits: bool = False,
        label_smoothing: float = 0.0,
        axis: int = -1,
        reduction: _ReductionValues = ...,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "binary_focal_crossentropy",
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

@@ -55,53 +57,61 @@ class CategoricalCrossentropy(Loss):

        from_logits: bool = False,
        label_smoothing: float = 0.0,
        axis: int = -1,
        reduction: _ReductionValues = ...,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str | None = "categorical_crossentropy",
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class CategoricalHinge(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "categorical_hinge") -> None: ...
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "categorical_hinge") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class CosineSimilarity(Loss):
    def __init__(self, axis: int = -1, reduction: _ReductionValues = ..., name: str | None = "cosine_similarity") -> None: ...
    def __init__(
        self, axis: int = -1, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "cosine_similarity"
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class Hinge(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "hinge") -> None: ...
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "hinge") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class Huber(Loss):
    def __init__(self, delta: float = 1.0, reduction: _ReductionValues = ..., name: str | None = "huber_loss") -> None: ...
    def __init__(
        self, delta: float = 1.0, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "huber_loss"
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class KLDivergence(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "kl_divergence") -> None: ...
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "kl_divergence") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class LogCosh(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "log_cosh") -> None: ...
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "log_cosh") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class MeanAbsoluteError(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_absolute_error") -> None: ...
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_absolute_error") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class MeanAbsolutePercentageError(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_absolute_percentage_error") -> None: ...
    def __init__(
        self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_absolute_percentage_error"
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class MeanSquaredError(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_squared_error") -> None: ...
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_squared_error") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class MeanSquaredLogarithmicError(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_squared_logarithmic_error") -> None: ...
    def __init__(
        self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "mean_squared_logarithmic_error"
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class Poisson(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "poisson") -> None: ...
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "poisson") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class SparseCategoricalCrossentropy(Loss):

@@ -109,13 +119,13 @@ class SparseCategoricalCrossentropy(Loss):

        self,
        from_logits: bool = False,
        ignore_class: int | None = None,
        reduction: _ReductionValues = ...,
        reduction: _ReductionValues = "sum_over_batch_size",
        name: str = "sparse_categorical_crossentropy",
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class SquaredHinge(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "squared_hinge") -> None: ...
    def __init__(self, reduction: _ReductionValues = "sum_over_batch_size", name: str | None = "squared_hinge") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class Reduction:

@@ -133,10 +143,8 @@ _ReductionValues: TypeAlias = Literal["auto", "none", "sum", "sum_over_batch_siz

def categorical_hinge(y_true: TensorCompatible, y_pred: TensorCompatible) -> Tensor: ...
def huber(y_true: TensorCompatible, y_pred: TensorCompatible, delta: float = 1.0) -> Tensor: ...
def log_cosh(y_true: TensorCompatible, y_pred: TensorCompatible) -> Tensor: ...
def deserialize(
    name: str | dict[str, Any], custom_objects: dict[str, Any] | None = None, use_legacy_format: bool = False
) -> Loss: ...
def serialize(loss: KerasSerializable, use_legacy_format: bool = False) -> dict[str, Any]: ...
def deserialize(name: str | dict[str, Any], custom_objects: dict[str, Any] | None = None) -> Loss: ...
def serialize(loss: KerasSerializable) -> dict[str, Any]: ...

_FuncT = TypeVar("_FuncT", bound=Callable[..., Any])

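The base-class change above (name now first, reduction defaulting to "sum_over_batch_size" instead of "auto") is what subclasses see; a minimal sketch assuming Keras 3 semantics (the class name is illustrative):

import tensorflow as tf

class MeanQuarticError(tf.keras.losses.Loss):
    def __init__(self, name="mean_quartic_error", reduction="sum_over_batch_size"):
        super().__init__(name=name, reduction=reduction)

    def call(self, y_true, y_pred):
        # return the per-sample loss; the configured reduction is applied by the base class
        return tf.reduce_mean(tf.pow(y_true - y_pred, 4), axis=-1)
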
@@ -12,9 +12,8 @@ from tensorflow.keras.initializers import _Initializer

_Output: TypeAlias = Tensor | dict[str, Tensor]

class Metric(tf.keras.layers.Layer[tf.Tensor, tf.Tensor], metaclass=ABCMeta):
    def __init__(self, name: str | None = None, dtype: DTypeLike | None = None) -> None: ...
    def __init__(self, dtype: DTypeLike | None = None, name: str | None = None) -> None: ...
    def __new__(cls, *args: Any, **kwargs: Any) -> Self: ...
    def merge_state(self, metrics: Iterable[Self]) -> list[Operation]: ...
    def reset_state(self) -> None: ...
    @abstractmethod
    def update_state(

@@ -110,7 +109,7 @@ class SparseTopKCategoricalAccuracy(MeanMetricWrapper):

        self, k: int = 5, name: str | None = "sparse_top_k_categorical_accuracy", dtype: DTypeLike | None = None
    ) -> None: ...

def serialize(metric: KerasSerializable, use_legacy_format: bool = False) -> dict[str, Any]: ...
def serialize(metric: KerasSerializable) -> dict[str, Any]: ...
def binary_crossentropy(
    y_true: TensorCompatible, y_pred: TensorCompatible, from_logits: bool = False, label_smoothing: float = 0.0, axis: int = -1
) -> Tensor: ...

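For orientation, the abstract update_state above is what custom metrics implement; a minimal sketch (the class name and logic are illustrative, assuming the usual add_weight/assign pattern):

import tensorflow as tf

class PositiveCount(tf.keras.metrics.Metric):
    def __init__(self, name="positive_count", dtype=None):
        super().__init__(name=name, dtype=dtype)
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # accumulate the number of positive predictions seen so far
        self.count.assign_add(tf.reduce_sum(tf.cast(y_pred > 0, self.count.dtype)))

    def result(self):
        return self.count
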
@@ -2,11 +2,10 @@ from _typeshed import Incomplete

from collections.abc import Callable, Container, Iterator
from pathlib import Path
from typing import Any, Literal
from typing_extensions import Self, TypeAlias
from typing_extensions import Self, TypeAlias, deprecated

import numpy as np
import numpy.typing as npt
import tensorflow
import tensorflow as tf
from tensorflow import Variable
from tensorflow._aliases import ContainerGeneric, ShapeLike, TensorCompatible

@@ -16,18 +15,22 @@ from tensorflow.keras.optimizers import Optimizer

_Loss: TypeAlias = str | tf.keras.losses.Loss | Callable[[TensorCompatible, TensorCompatible], tf.Tensor]
_Metric: TypeAlias = str | tf.keras.metrics.Metric | Callable[[TensorCompatible, TensorCompatible], tf.Tensor] | None

class Model(Layer[_InputT, _OutputT], tf.Module):
# Missing keras.src.backend.tensorflow.trainer.TensorFlowTrainer as a base class, which is not exposed by tensorflow
class Model(Layer[_InputT, _OutputT]):
    _train_counter: tf.Variable
    _test_counter: tf.Variable
    optimizer: Optimizer | None
    loss: tf.keras.losses.Loss | dict[str, tf.keras.losses.Loss]
    # This is actually TensorFlowTrainer.loss
    @deprecated("Instead, use `model.compute_loss(x, y, y_pred, sample_weight)`.")
    def loss(
        self, y: TensorCompatible | None, y_pred: TensorCompatible | None, sample_weight: Incomplete | None = None
    ) -> tf.Tensor | None: ...
    stop_training: bool

    def __new__(cls, *args: Any, **kwargs: Any) -> Model[_InputT, _OutputT]: ...
    def __init__(self, *args: Any, **kwargs: Any) -> None: ...
    def __setattr__(self, name: str, value: Any) -> None: ...
    def __reduce__(self): ...
    def __deepcopy__(self, memo): ...
    def build(self, input_shape: ShapeLike) -> None: ...
    def __call__(self, inputs: _InputT, *, training: bool = False, mask: TensorCompatible | None = None) -> _OutputT: ...
    def call(self, inputs: _InputT, training: bool | None = None, mask: TensorCompatible | None = None) -> _OutputT: ...

@@ -36,14 +39,13 @@ class Model(Layer[_InputT, _OutputT], tf.Module):

        self,
        optimizer: Optimizer | str = "rmsprop",
        loss: ContainerGeneric[_Loss] | None = None,
        metrics: ContainerGeneric[_Metric] | None = None,
        loss_weights: ContainerGeneric[float] | None = None,
        metrics: ContainerGeneric[_Metric] | None = None,
        weighted_metrics: ContainerGeneric[_Metric] | None = None,
        run_eagerly: bool | None = None,
        steps_per_execution: int | Literal["auto"] | None = None,
        jit_compile: bool | None = None,
        pss_evaluation_shards: int | Literal["auto"] = 0,
        **kwargs: Any,
        run_eagerly: bool = False,
        steps_per_execution: int | Literal["auto"] = 1,
        jit_compile: bool | Literal["auto"] = "auto",
        auto_scale_loss: bool | None = True,
    ) -> None: ...
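
The compile() signature above moves loss_weights ahead of metrics and tightens the defaults (run_eagerly=False, steps_per_execution=1, jit_compile="auto"); a minimal usage sketch (the model itself is illustrative):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(
    optimizer="rmsprop",   # default shown explicitly
    loss="mse",
    loss_weights=None,     # now precedes metrics in the signature
    metrics=["mae"],
    jit_compile="auto",    # new default in 2.16
)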

    @property
    def metrics(self) -> list[Incomplete]: ...

@@ -54,10 +56,6 @@ class Model(Layer[_InputT, _OutputT], tf.Module):

    @property
    def run_eagerly(self) -> bool: ...
    @property
    def autotune_steps_per_execution(self) -> bool: ...
    @property
    def steps_per_execution(self) -> int | None: ...  # Returns None for a non-compiled model.
    @property
    def jit_compile(self) -> bool: ...
    @property
    def distribute_reduction_method(self) -> Incomplete | Literal["auto"]: ...

@@ -70,7 +68,7 @@ class Model(Layer[_InputT, _OutputT], tf.Module):

        sample_weight: Incomplete | None = None,
    ) -> tf.Tensor | None: ...
    def compute_metrics(
        self, x: TensorCompatible, y: TensorCompatible, y_pred: TensorCompatible, sample_weight
        self, x: TensorCompatible, y: TensorCompatible, y_pred: TensorCompatible, sample_weight: Incomplete | None = None
    ) -> dict[str, float]: ...
    def get_metrics_result(self) -> dict[str, float]: ...
    def make_train_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], dict[str, float]]: ...

@@ -92,9 +90,6 @@ class Model(Layer[_InputT, _OutputT], tf.Module):

        validation_steps: int | None = None,
        validation_batch_size: int | None = None,
        validation_freq: int | Container[int] = 1,
        max_queue_size: int = 10,
        workers: int = 1,
        use_multiprocessing: bool = False,
    ) -> tf.keras.callbacks.History: ...
    def test_step(self, data: TensorCompatible) -> dict[str, float]: ...
    def make_test_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], dict[str, float]]: ...

@@ -107,9 +102,6 @@ class Model(Layer[_InputT, _OutputT], tf.Module):

        sample_weight: npt.NDArray[np.float_] | None = None,
        steps: int | None = None,
        callbacks: list[tf.keras.callbacks.Callback] | None = None,
        max_queue_size: int = 10,
        workers: int = 1,
        use_multiprocessing: bool = False,
        return_dict: bool = False,
        **kwargs: Any,
    ) -> float | list[float]: ...

@@ -122,9 +114,6 @@ class Model(Layer[_InputT, _OutputT], tf.Module):

        verbose: Literal["auto", 0, 1, 2] = "auto",
        steps: int | None = None,
        callbacks: list[tf.keras.callbacks.Callback] | None = None,
        max_queue_size: int = 10,
        workers: int = 1,
        use_multiprocessing: bool = False,
    ) -> _OutputT: ...
    def reset_metrics(self) -> None: ...
    def train_on_batch(

@@ -133,7 +122,6 @@ class Model(Layer[_InputT, _OutputT], tf.Module):

        y: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete] | None = None,
        sample_weight: npt.NDArray[np.float_] | None = None,
        class_weight: dict[int, float] | None = None,
        reset_metrics: bool = True,
        return_dict: bool = False,
    ) -> float | list[float]: ...
    def test_on_batch(

@@ -141,77 +129,22 @@ class Model(Layer[_InputT, _OutputT], tf.Module):

        x: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete],
        y: TensorCompatible | dict[str, TensorCompatible] | tf.data.Dataset[Incomplete] | None = None,
        sample_weight: npt.NDArray[np.float_] | None = None,
        reset_metrics: bool = True,
        return_dict: bool = False,
    ) -> float | list[float]: ...
    def predict_on_batch(self, x: Iterator[_InputT]) -> npt.NDArray[Incomplete]: ...
    def fit_generator(
        self,
        generator: Iterator[Incomplete],
        steps_per_epoch: int | None = None,
        epochs: int = 1,
        verbose: Literal["auto", 0, 1, 2] = 1,
        callbacks: list[tf.keras.callbacks.Callback] | None = None,
        validation_data: TensorCompatible | tf.data.Dataset[Any] | None = None,
        validation_steps: int | None = None,
        validation_freq: int | Container[int] = 1,
        class_weight: dict[int, float] | None = None,
        max_queue_size: int = 10,
        workers: int = 1,
        use_multiprocessing: bool = False,
        shuffle: bool = True,
        initial_epoch: int = 0,
    ) -> tf.keras.callbacks.History: ...
    def evaluate_generator(
        self,
        generator: Iterator[Incomplete],
        steps: int | None = None,
        callbacks: list[tf.keras.callbacks.Callback] | None = None,
        max_queue_size: int = 10,
        workers: int = 1,
        use_multiprocessing: bool = False,
        verbose: Literal["auto", 0, 1, 2] = 0,
    ) -> float | list[float]: ...
    def predict_generator(
        self,
        generator: Iterator[Incomplete],
        steps: int | None = None,
        callbacks: list[tf.keras.callbacks.Callback] | None = None,
        max_queue_size: int = 10,
        workers: int = 1,
        use_multiprocessing: bool = False,
        verbose: Literal["auto", 0, 1, 2] = 0,
    ) -> _OutputT: ...
    @property
    def trainable_weights(self) -> list[Variable]: ...
    @property
    def non_trainable_weights(self) -> list[Variable]: ...
    def get_weights(self): ...
    def save(
        self, filepath: str | Path, overwrite: bool = True, save_format: Literal["keras", "tf", "h5"] | None = None, **kwargs: Any
    ) -> None: ...
    def save_weights(
        self,
        filepath: str | Path,
        overwrite: bool = True,
        save_format: Literal["tf", "h5"] | None = None,
        options: tf.train.CheckpointOptions | None = None,
    ) -> None: ...
    def load_weights(
        self,
        filepath: str | Path,
        skip_mismatch: bool = False,
        by_name: bool = False,
        options: None | tensorflow.train.CheckpointOptions = None,
    ) -> None: ...
    def save(self, filepath: str | Path, overwrite: bool = True) -> None: ...
    def save_weights(self, filepath: str | Path, overwrite: bool = True) -> None: ...
    # kwargs are from keras.saving.saving_api.load_weights
    def load_weights(self, filepath: str | Path, skip_mismatch: bool = False, *, by_name: bool = False) -> None: ...
|
||||
def get_config(self) -> dict[str, Any]: ...
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any], custom_objects: Incomplete | None = None) -> Self: ...
|
||||
def to_json(self, **kwargs: Any) -> str: ...
|
||||
def to_yaml(self, **kwargs: Any) -> str: ...
|
||||
def reset_states(self) -> None: ...
|
||||
@property
|
||||
def state_updates(self) -> list[Incomplete]: ...
|
||||
@property
|
||||
def weights(self) -> list[Variable]: ...
|
||||
def summary(
|
||||
@@ -226,10 +159,8 @@ class Model(Layer[_InputT, _OutputT], tf.Module):
|
||||
@property
|
||||
def layers(self) -> list[Layer[Incomplete, Incomplete]]: ...
|
||||
def get_layer(self, name: str | None = None, index: int | None = None) -> Layer[Incomplete, Incomplete]: ...
|
||||
def get_weight_paths(self) -> dict[str, tf.Variable]: ...
|
||||
def get_compile_config(self) -> dict[str, Any]: ...
|
||||
def compile_from_config(self, config: dict[str, Any]) -> Self: ...
|
||||
def export(self, filepath: str | Path) -> None: ...
|
||||
def save_spec(self, dynamic_batch: bool = True) -> tuple[tuple[tf.TensorSpec, ...], dict[str, tf.TensorSpec]] | None: ...
|
||||
def export(self, filepath: str | Path, format: str = "tf_saved_model") -> None: ...
|
||||
|
||||
def __getattr__(name: str) -> Incomplete: ...
|
||||
|
||||
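The hunks above track the Keras 3 surface of Model: the *_generator methods and the data-loading knobs (max_queue_size, workers, use_multiprocessing) are gone, and save()/save_weights()/load_weights() lose their save_format and options parameters. A hedged sketch of how the surviving signatures read, assuming a trivially compiled model (all names below are illustrative, not part of the diff):

# Hedged sketch against the 2.16 stubs above; not from the diff itself.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# evaluate() keeps return_dict but no longer accepts workers/use_multiprocessing:
loss = model.evaluate([[1.0]], [[2.0]], verbose=0, return_dict=False)

# save_weights() dropped save_format; Keras 3 infers the format from the extension:
model.save_weights("model.weights.h5")
model.load_weights("model.weights.h5", skip_mismatch=False)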
@@ -1,17 +1,13 @@
from _typeshed import Incomplete
from abc import abstractmethod
from collections.abc import Callable, Iterable
from typing import Any
from typing_extensions import Self, TypeAlias
from typing_extensions import TypeAlias

import tensorflow as tf
from tensorflow._aliases import Gradients
from tensorflow.keras.optimizers import schedules as schedules
from tensorflow.python.trackable.base import Trackable

_Initializer: TypeAlias = str | Callable[[], tf.Tensor] | dict[str, Any]
_Shape: TypeAlias = tf.TensorShape | Iterable[int | None]
_Dtype: TypeAlias = tf.DType | str | None
_LearningRate: TypeAlias = float | tf.Tensor | schedules.LearningRateSchedule | Callable[[], float | tf.Tensor]
_GradientAggregator: TypeAlias = Callable[[list[tuple[Gradients, tf.Variable]]], list[tuple[Gradients, tf.Variable]]] | None
_GradientTransformer: TypeAlias = (
@@ -33,49 +29,6 @@ class Optimizer(Trackable):
gradient_transformers: _GradientTransformer = None,
**kwargs: Any,
) -> None: ...
def _create_all_weights(self, var_list: Iterable[tf.Variable]) -> None: ...
@property
def iterations(self) -> tf.Variable: ...
@iterations.setter
def iterations(self, variable: tf.Variable) -> None: ...
def add_slot(
self, var: tf.Variable, slot_name: str, initializer: _Initializer = "zeros", shape: tf.TensorShape | None = None
) -> tf.Variable: ...
def add_weight(
self,
name: str,
shape: _Shape,
dtype: _Dtype = None,
initializer: _Initializer = "zeros",
trainable: None | bool = None,
synchronization: tf.VariableSynchronization = ...,
aggregation: tf.VariableAggregation = ...,
) -> tf.Variable: ...
def apply_gradients(
self,
grads_and_vars: Iterable[tuple[Gradients, tf.Variable]],
name: str | None = None,
experimental_aggregate_gradients: bool = True,
) -> tf.Operation | None: ...
@classmethod
def from_config(cls, config: dict[str, Any], custom_objects: dict[str, type] | None = None) -> Self: ...
# Missing ABC is intentional as class is not abstract at runtime.
@abstractmethod
def get_config(self) -> dict[str, Any]: ...
def get_slot(self, var: tf.Variable, slot_name: str) -> tf.Variable: ...
def get_slot_names(self) -> list[str]: ...
def get_gradients(self, loss: tf.Tensor, params: list[tf.Variable]) -> list[Gradients]: ...
def minimize(
self,
loss: tf.Tensor | Callable[[], tf.Tensor],
var_list: list[tf.Variable] | tuple[tf.Variable, ...] | Callable[[], list[tf.Variable] | tuple[tf.Variable, ...]],
grad_loss: tf.Tensor | None = None,
name: str | None = None,
tape: tf.GradientTape | None = None,
) -> tf.Operation: ...
def variables(self) -> list[tf.Variable]: ...
@property
def weights(self) -> list[tf.Variable]: ...

class Adam(Optimizer):
def __init__(
@@ -88,7 +41,6 @@ class Adam(Optimizer):
name: str = "Adam",
**kwargs: Any,
) -> None: ...
def get_config(self) -> dict[str, Any]: ...

class Adagrad(Optimizer):
_initial_accumulator_value: float
@@ -101,12 +53,10 @@ class Adagrad(Optimizer):
name: str = "Adagrad",
**kwargs: Any,
) -> None: ...
def get_config(self) -> dict[str, Any]: ...

class SGD(Optimizer):
def __init__(
self, learning_rate: _LearningRate = 0.01, momentum: float = 0.0, nesterov: bool = False, name: str = "SGD", **kwargs: Any
) -> None: ...
def get_config(self) -> dict[str, Any]: ...

def __getattr__(name: str) -> Incomplete: ...

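This hunk strips the legacy v2 optimizer machinery (slots, add_weight, apply_gradients, minimize) from the Optimizer stub; only the constructors and get_config survive on Adam, Adagrad, and SGD. A hedged sketch of that remaining surface:

# Hedged sketch; assumes only what the trimmed stub above still declares.
import tensorflow as tf

opt = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
config = opt.get_config()  # dict[str, Any] per the stub
print(config["learning_rate"])  # key name assumed from Keras conventions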
@@ -19,7 +19,7 @@ class PiecewiseConstantDecay(LearningRateSchedule):
self,
boundaries: Sequence[tf.Tensor] | Sequence[float],
values: Sequence[float] | Sequence[tf.Tensor],
name: str | None = None,
name: str = "PiecewiseConstant",
) -> None: ...
def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
def get_config(self) -> dict[str, Any]: ...
@@ -33,7 +33,7 @@ class InverseTimeDecay(LearningRateSchedule):
decay_steps: int,
decay_rate: float,
staircase: bool = False,
name: str | None = None,
name: str = "InverseTimeDecay",
) -> None: ...
def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
def get_config(self) -> dict[str, Any]: ...
@@ -48,7 +48,7 @@ class PolynomialDecay(LearningRateSchedule):
end_learning_rate: float | tf.Tensor = 0.0001,
power: float = 1.0,
cycle: bool = False,
name: str | None = None,
name: str = "PolynomialDecay",
) -> None: ...
def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
def get_config(self) -> dict[str, Any]: ...
@@ -61,7 +61,7 @@ class CosineDecay(LearningRateSchedule):
initial_learning_rate: float | tf.Tensor,
decay_steps: int,
alpha: float | tf.Tensor = 0.0,
name: str | None = None,
name: str = "CosineDecay",
warmup_target: int | tf.Tensor | None = None, # float32 or float64 Tensor
warmup_steps: int | tf.Tensor = 0, # int32 or int64 Tensor
) -> None: ...
@@ -78,7 +78,7 @@ class CosineDecayRestarts(LearningRateSchedule):
t_mul: float | tf.Tensor = 2.0,
m_mul: float | tf.Tensor = 1.0,
alpha: float | tf.Tensor = 0.0,
name: str | None = None,
name: str = "SGDRDecay",
) -> None: ...
def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
def get_config(self) -> dict[str, Any]: ...
@@ -92,14 +92,12 @@ class ExponentialDecay(LearningRateSchedule):
decay_steps: int | tf.Tensor,
decay_rate: float | tf.Tensor,
staircase: bool = False,
name: str | None = None,
name: str = "ExponentialDecay",
) -> None: ...
def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
def get_config(self) -> dict[str, Any]: ...
@classmethod
def from_config(cls, config: dict[str, Any]) -> Self: ...

def deserialize(
config: dict[str, Any], custom_objects: dict[str, type] | None = None, use_legacy_format: bool = False
) -> LearningRateSchedule: ...
def serialize(learning_rate_schedule: LearningRateSchedule, use_legacy_format: bool = False) -> dict[str, Any]: ...
def deserialize(config: dict[str, Any], custom_objects: dict[str, type] | None = None) -> LearningRateSchedule: ...
def serialize(learning_rate_schedule: LearningRateSchedule) -> dict[str, Any]: ...

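Two patterns repeat through these schedule hunks: every name parameter gains a concrete string default instead of None, and serialize()/deserialize() lose the use_legacy_format flag. A hedged round-trip sketch:

# Hedged sketch of the 2.16 serialize/deserialize round trip.
import tensorflow as tf
from tensorflow.keras.optimizers import schedules

lr = schedules.ExponentialDecay(initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.96)
cfg = schedules.serialize(lr)  # no use_legacy_format argument anymore
lr_again = schedules.deserialize(cfg)
current = lr_again(100)  # float | tf.Tensor per __call__ in the stub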
@@ -1,3 +1,4 @@
from _typeshed import Incomplete
from collections.abc import Callable
from typing import Any, overload
from typing_extensions import Self, TypeAlias
@@ -18,4 +19,4 @@ def get(identifier: None) -> None: ...
def get(identifier: str | dict[str, Any] | Regularizer) -> Regularizer: ...
@overload
def get(identifier: Callable[[Tensor], Tensor]) -> Callable[[Tensor], Tensor]: ...
def __getattr__(name: str) -> Any: ...
def __getattr__(name: str) -> Incomplete: ...

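Only the fallback __getattr__ changes here, from Any to Incomplete; the get() overloads are untouched context. A quick sketch of how those overloads resolve:

from tensorflow.keras import regularizers

reg = regularizers.get("l2")  # str identifier -> Regularizer
none = regularizers.get(None)  # None -> None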
@@ -19,6 +19,8 @@ def matmul(
a_is_sparse: _bool = False,
b_is_sparse: _bool = False,
output_type: DTypeLike | None = None,
grad_a: _bool = False,
grad_b: _bool = False,
name: str | None = None,
) -> Tensor: ...
@overload
@@ -32,6 +34,8 @@ def matmul(
a_is_sparse: _bool = False,
b_is_sparse: _bool = False,
output_type: DTypeLike | None = None,
grad_a: _bool = False,
grad_b: _bool = False,
name: str | None = None,
) -> RaggedTensor: ...
def set_diag(

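Both matmul overloads gain grad_a/grad_b flags with False defaults. A hedged call sketch; the flags' semantics are not documented in the stub, so they are simply left at their defaults here:

import tensorflow as tf

a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
b = tf.constant([[5.0, 6.0], [7.0, 8.0]])
# Resolves to the Tensor overload; grad_a/grad_b are new in the 2.16 stubs.
c = tf.linalg.matmul(a, b, transpose_b=True, grad_a=False, grad_b=False)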
@@ -1,3 +1,4 @@
from _typeshed import Incomplete
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any, Generic, Literal, TypeVar
@@ -6,7 +7,7 @@ from typing_extensions import ParamSpec, TypeAlias
import tensorflow as tf
from tensorflow.python.training.tracking.autotrackable import AutoTrackable
from tensorflow.saved_model.experimental import VariablePolicy
from tensorflow.types.experimental import ConcreteFunction, GenericFunction
from tensorflow.types.experimental import ConcreteFunction, PolymorphicFunction

_P = ParamSpec("_P")
_R = TypeVar("_R", covariant=True)
@@ -39,15 +40,17 @@ class SaveOptions:
"namespace_whitelist",
"save_debug_info",
"function_aliases",
"experimental_debug_stripper",
"experimental_io_device",
"experimental_variable_policy",
"experimental_custom_gradients",
"experimental_image_format",
"experimental_skip_saver",
"experimental_sharding_callback",
)
namespace_whitelist: list[str]
save_debug_info: bool
function_aliases: dict[str, tf.types.experimental.GenericFunction[..., object]]
function_aliases: dict[str, PolymorphicFunction[..., object]]
experimental_io_device: str
experimental_variable_policy: VariablePolicy
experimental_custom_gradients: bool
@@ -57,12 +60,14 @@ class SaveOptions:
self,
namespace_whitelist: list[str] | None = None,
save_debug_info: bool = False,
function_aliases: Mapping[str, tf.types.experimental.GenericFunction[..., object]] | None = None,
function_aliases: Mapping[str, PolymorphicFunction[..., object]] | None = None,
experimental_debug_stripper: bool = False,
experimental_io_device: str | None = None,
experimental_variable_policy: str | VariablePolicy | None = None,
experimental_custom_gradients: bool = True,
experimental_image_format: bool = False,
experimental_skip_saver: bool = False,
experimental_sharding_callback: Incomplete | None = None,
) -> None: ...

def contains_saved_model(export_dir: str | Path) -> bool: ...
@@ -80,7 +85,7 @@ def load(
export_dir: str, tags: str | Sequence[str] | None = None, options: LoadOptions | None = None
) -> _LoadedModel[..., Any]: ...

_TF_Function: TypeAlias = ConcreteFunction[..., object] | GenericFunction[..., object]
_TF_Function: TypeAlias = ConcreteFunction[..., object] | PolymorphicFunction[..., object]

def save(
obj: tf.Module,

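SaveOptions gains experimental_debug_stripper, experimental_image_format, experimental_skip_saver, and experimental_sharding_callback, and function_aliases now uses the renamed PolymorphicFunction type. A hedged sketch of saving with the new options (the path is illustrative):

import tensorflow as tf

options = tf.saved_model.SaveOptions(
    save_debug_info=False,
    experimental_debug_stripper=False,  # new in these stubs
    experimental_image_format=False,  # new in these stubs
)
tf.saved_model.save(tf.Module(), "/tmp/my_module", options=options)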
@@ -7,6 +7,7 @@ from typing_extensions import Self

import tensorflow as tf
from tensorflow._aliases import FloatArray, IntArray
from tensorflow.core.framework.graph_pb2 import GraphDef
from tensorflow.experimental.dtensor import Mesh

class SummaryWriter(metaclass=abc.ABCMeta):
@@ -36,7 +37,7 @@ def create_file_writer(
) -> SummaryWriter: ...
def create_noop_writer() -> SummaryWriter: ...
def flush(writer: SummaryWriter | None = None, name: str | None = None) -> tf.Operation: ...
def graph(graph_data: tf.Graph | tf.compat.v1.GraphDef) -> bool: ...
def graph(graph_data: tf.Graph | GraphDef) -> bool: ...
def histogram(
name: str, data: tf.Tensor, step: int | None = None, buckets: int | None = None, description: str | None = None
) -> bool: ...
@@ -54,7 +55,7 @@ def should_record_summaries() -> bool: ...
def text(name: str, data: str | tf.Tensor, step: int | tf.Tensor | None = None, description: str | None = None) -> bool: ...
def trace_export(name: str, step: int | tf.Tensor | None = None, profiler_outdir: str | None = None) -> None: ...
def trace_off() -> None: ...
def trace_on(graph: bool = True, profiler: bool = False) -> None: ...
def trace_on(graph: bool = True, profiler: bool = False, profiler_outdir: str | None = None) -> None: ...
def write(
tag: str, tensor: tf.Tensor, step: int | tf.Tensor | None = None, metadata: Incomplete | None = None, name: str | None = None
) -> bool: ...

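trace_on() now accepts the profiler_outdir that trace_export() already took. A hedged tracing sketch around an illustrative tf.function:

import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/tb-logs")

@tf.function
def step_fn(x):
    return x + 1

tf.summary.trace_on(graph=True, profiler=False, profiler_outdir=None)
step_fn(tf.constant(1))
with writer.as_default():
    tf.summary.trace_export(name="step_fn_trace", step=0)
tf.summary.trace_off()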
@@ -30,6 +30,8 @@ class CheckpointOptions:
experimental_enable_async_checkpoint: bool = False,
experimental_write_callbacks: None | list[Callable[[str], object] | Callable[[], object]] = None,
enable_async: bool = False,
experimental_skip_slot_variables: bool = False,
experimental_sharding_callback: Incomplete | None = None,
) -> None: ...

_T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str])

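CheckpointOptions picks up experimental_skip_slot_variables and experimental_sharding_callback. A hedged sketch of threading the new options through a checkpoint save (paths illustrative):

import tensorflow as tf

opts = tf.train.CheckpointOptions(
    enable_async=False,
    experimental_skip_slot_variables=False,  # new in these stubs
)
ckpt = tf.train.Checkpoint(v=tf.Variable(1.0))
ckpt.save("/tmp/ckpt/model", options=opts)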
@@ -1,253 +0,0 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
This file defines protos that store the results of autotuning various
operations.

They are in proto format because we want to log them structured. They offer
tremendous statistical, testing, and debugging value.
"""

import builtins
import collections.abc
import sys
import typing

import google.protobuf.any_pb2
import google.protobuf.descriptor
import google.protobuf.duration_pb2
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import tensorflow.tsl.protobuf.dnn_pb2

if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

@typing.final
class CudnnVersion(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

MAJOR_FIELD_NUMBER: builtins.int
MINOR_FIELD_NUMBER: builtins.int
PATCH_FIELD_NUMBER: builtins.int
major: builtins.int
minor: builtins.int
patch: builtins.int
def __init__(
self,
*,
major: builtins.int | None = ...,
minor: builtins.int | None = ...,
patch: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["major", b"major", "minor", b"minor", "patch", b"patch"]) -> None: ...

global___CudnnVersion = CudnnVersion

@typing.final
class ComputeCapability(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

MAJOR_FIELD_NUMBER: builtins.int
MINOR_FIELD_NUMBER: builtins.int
major: builtins.int
minor: builtins.int
def __init__(
self,
*,
major: builtins.int | None = ...,
minor: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["major", b"major", "minor", b"minor"]) -> None: ...

global___ComputeCapability = ComputeCapability

@typing.final
class AutotuneResult(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

class _FailureKind:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType

class _FailureKindEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[AutotuneResult._FailureKind.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
UNKNOWN: AutotuneResult._FailureKind.ValueType # 0
REDZONE_MODIFIED: AutotuneResult._FailureKind.ValueType # 1
"""Algorithm wrote memory outside its output buffers."""
WRONG_RESULT: AutotuneResult._FailureKind.ValueType # 2
"""Algorithm gave a different result from a reference algorithm."""
DISQUALIFIED: AutotuneResult._FailureKind.ValueType # 3
"""Algorithm was rejected for failing to run or for known bugs."""

class FailureKind(_FailureKind, metaclass=_FailureKindEnumTypeWrapper): ...
UNKNOWN: AutotuneResult.FailureKind.ValueType # 0
REDZONE_MODIFIED: AutotuneResult.FailureKind.ValueType # 1
"""Algorithm wrote memory outside its output buffers."""
WRONG_RESULT: AutotuneResult.FailureKind.ValueType # 2
"""Algorithm gave a different result from a reference algorithm."""
DISQUALIFIED: AutotuneResult.FailureKind.ValueType # 3
"""Algorithm was rejected for failing to run or for known bugs."""

@typing.final
class FailureResult(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

KIND_FIELD_NUMBER: builtins.int
MSG_FIELD_NUMBER: builtins.int
REFERENCE_CONV_FIELD_NUMBER: builtins.int
REFERENCE_GEMM_FIELD_NUMBER: builtins.int
REFERENCE_CUDA_CONV_PLAN_FIELD_NUMBER: builtins.int
REFERENCE_ALGORITHM_FIELD_NUMBER: builtins.int
BUFFER_ADDRESS_FIELD_NUMBER: builtins.int
kind: global___AutotuneResult.FailureKind.ValueType
msg: builtins.str
buffer_address: builtins.int
@property
def reference_conv(self) -> global___AutotuneResult.ConvKey: ...
@property
def reference_gemm(self) -> global___AutotuneResult.GemmKey: ...
@property
def reference_cuda_conv_plan(self) -> global___AutotuneResult.CudaConvPlanKey: ...
@property
def reference_algorithm(self) -> tensorflow.tsl.protobuf.dnn_pb2.AlgorithmProto: ...
def __init__(
self,
*,
kind: global___AutotuneResult.FailureKind.ValueType | None = ...,
msg: builtins.str | None = ...,
reference_conv: global___AutotuneResult.ConvKey | None = ...,
reference_gemm: global___AutotuneResult.GemmKey | None = ...,
reference_cuda_conv_plan: global___AutotuneResult.CudaConvPlanKey | None = ...,
reference_algorithm: tensorflow.tsl.protobuf.dnn_pb2.AlgorithmProto | None = ...,
buffer_address: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["key", b"key", "reference_algorithm", b"reference_algorithm", "reference_conv", b"reference_conv", "reference_cuda_conv_plan", b"reference_cuda_conv_plan", "reference_gemm", b"reference_gemm"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["buffer_address", b"buffer_address", "key", b"key", "kind", b"kind", "msg", b"msg", "reference_algorithm", b"reference_algorithm", "reference_conv", b"reference_conv", "reference_cuda_conv_plan", b"reference_cuda_conv_plan", "reference_gemm", b"reference_gemm"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["key", b"key"]) -> typing.Literal["reference_conv", "reference_gemm", "reference_cuda_conv_plan", "reference_algorithm"] | None: ...

@typing.final
class ConvKey(google.protobuf.message.Message):
"""Legacy and unused in new data; superseded by AlgorithmProto."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

ALGORITHM_FIELD_NUMBER: builtins.int
TENSOR_OPS_ENABLED_FIELD_NUMBER: builtins.int
algorithm: builtins.int
tensor_ops_enabled: builtins.bool
def __init__(
self,
*,
algorithm: builtins.int | None = ...,
tensor_ops_enabled: builtins.bool | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["algorithm", b"algorithm", "tensor_ops_enabled", b"tensor_ops_enabled"]) -> None: ...

@typing.final
class GemmKey(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

ALGORITHM_FIELD_NUMBER: builtins.int
algorithm: builtins.int
def __init__(
self,
*,
algorithm: builtins.int | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["algorithm", b"algorithm"]) -> None: ...

@typing.final
class CudaConvPlanKey(google.protobuf.message.Message):
"""Legacy and unused in new data; superseded by AlgorithmProto."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

EXEC_PLAN_ID_FIELD_NUMBER: builtins.int
exec_plan_id: builtins.str
def __init__(
self,
*,
exec_plan_id: builtins.str | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["exec_plan_id", b"exec_plan_id"]) -> None: ...

SCRATCH_BYTES_FIELD_NUMBER: builtins.int
RUN_TIME_FIELD_NUMBER: builtins.int
FAILURE_FIELD_NUMBER: builtins.int
CONV_FIELD_NUMBER: builtins.int
GEMM_FIELD_NUMBER: builtins.int
CUDA_CONV_PLAN_FIELD_NUMBER: builtins.int
ALGORITHM_FIELD_NUMBER: builtins.int
scratch_bytes: builtins.int
@property
def run_time(self) -> google.protobuf.duration_pb2.Duration: ...
@property
def failure(self) -> global___AutotuneResult.FailureResult: ...
@property
def conv(self) -> global___AutotuneResult.ConvKey: ...
@property
def gemm(self) -> global___AutotuneResult.GemmKey: ...
@property
def cuda_conv_plan(self) -> global___AutotuneResult.CudaConvPlanKey: ...
@property
def algorithm(self) -> tensorflow.tsl.protobuf.dnn_pb2.AlgorithmProto: ...
def __init__(
self,
*,
scratch_bytes: builtins.int | None = ...,
run_time: google.protobuf.duration_pb2.Duration | None = ...,
failure: global___AutotuneResult.FailureResult | None = ...,
conv: global___AutotuneResult.ConvKey | None = ...,
gemm: global___AutotuneResult.GemmKey | None = ...,
cuda_conv_plan: global___AutotuneResult.CudaConvPlanKey | None = ...,
algorithm: tensorflow.tsl.protobuf.dnn_pb2.AlgorithmProto | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["algorithm", b"algorithm", "conv", b"conv", "cuda_conv_plan", b"cuda_conv_plan", "failure", b"failure", "gemm", b"gemm", "key", b"key", "run_time", b"run_time"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["algorithm", b"algorithm", "conv", b"conv", "cuda_conv_plan", b"cuda_conv_plan", "failure", b"failure", "gemm", b"gemm", "key", b"key", "run_time", b"run_time", "scratch_bytes", b"scratch_bytes"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["key", b"key"]) -> typing.Literal["conv", "gemm", "cuda_conv_plan", "algorithm"] | None: ...

global___AutotuneResult = AutotuneResult

@typing.final
class AutotuningLog(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

INSTR_FIELD_NUMBER: builtins.int
RESULTS_FIELD_NUMBER: builtins.int
CUDNN_VERSION_FIELD_NUMBER: builtins.int
COMPUTE_CAPABILITY_FIELD_NUMBER: builtins.int
DEVICE_PCI_BUS_ID_FIELD_NUMBER: builtins.int
BLAS_VERSION_FIELD_NUMBER: builtins.int
device_pci_bus_id: builtins.str
"""stream_executor::DeviceDescription::pci_bus_id."""
blas_version: builtins.str
@property
def instr(self) -> google.protobuf.any_pb2.Any: ...
@property
def results(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___AutotuneResult]:
"""Records all auto-tuning results per algorithm."""

@property
def cudnn_version(self) -> global___CudnnVersion: ...
@property
def compute_capability(self) -> global___ComputeCapability: ...
def __init__(
self,
*,
instr: google.protobuf.any_pb2.Any | None = ...,
results: collections.abc.Iterable[global___AutotuneResult] | None = ...,
cudnn_version: global___CudnnVersion | None = ...,
compute_capability: global___ComputeCapability | None = ...,
device_pci_bus_id: builtins.str | None = ...,
blas_version: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["compute_capability", b"compute_capability", "cudnn_version", b"cudnn_version", "instr", b"instr"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["blas_version", b"blas_version", "compute_capability", b"compute_capability", "cudnn_version", b"cudnn_version", "device_pci_bus_id", b"device_pci_bus_id", "instr", b"instr", "results", b"results"]) -> None: ...

global___AutotuningLog = AutotuningLog
@@ -52,6 +52,8 @@ class CoordinationServiceConfig(google.protobuf.message.Message):
SHUTDOWN_BARRIER_TIMEOUT_IN_MS_FIELD_NUMBER: builtins.int
AGENT_DESTRUCTION_WITHOUT_SHUTDOWN_FIELD_NUMBER: builtins.int
RECOVERABLE_JOBS_FIELD_NUMBER: builtins.int
ALLOW_NEW_INCARNATION_TO_RECONNECT_FIELD_NUMBER: builtins.int
FORCE_DISABLE_FIELD_NUMBER: builtins.int
service_type: builtins.str
"""Type of coordination service implementation to enable.
For example, setting the service type as "standalone" starts a service
@@ -82,6 +84,17 @@ class CoordinationServiceConfig(google.protobuf.message.Message):
find out about the disconnected agent via stale heartbeats. Used for
testing.
"""
allow_new_incarnation_to_reconnect: builtins.bool
"""If a task restarts with a new incarnation, we may allow it to reconnect
silently. This is useful when we know that a task can immediately resume
work upon re-connecting to the service.
"""
force_disable: builtins.bool
"""Disables coordination service.
Some libraries enable coordination service by default even if the user did
not specify any config. This field allows users to explicitly disable
coordination service under all situations.
"""
@property
def coordinated_job_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___CoordinatedJob]: ...
@property
@@ -104,7 +117,9 @@ class CoordinationServiceConfig(google.protobuf.message.Message):
shutdown_barrier_timeout_in_ms: builtins.int | None = ...,
agent_destruction_without_shutdown: builtins.bool | None = ...,
recoverable_jobs: collections.abc.Iterable[builtins.str] | None = ...,
allow_new_incarnation_to_reconnect: builtins.bool | None = ...,
force_disable: builtins.bool | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["agent_destruction_without_shutdown", b"agent_destruction_without_shutdown", "cluster_register_timeout_in_ms", b"cluster_register_timeout_in_ms", "coordinated_job_list", b"coordinated_job_list", "enable_health_check", b"enable_health_check", "heartbeat_timeout_in_ms", b"heartbeat_timeout_in_ms", "recoverable_jobs", b"recoverable_jobs", "service_leader", b"service_leader", "service_type", b"service_type", "shutdown_barrier_timeout_in_ms", b"shutdown_barrier_timeout_in_ms"]) -> None: ...
def ClearField(self, field_name: typing.Literal["agent_destruction_without_shutdown", b"agent_destruction_without_shutdown", "allow_new_incarnation_to_reconnect", b"allow_new_incarnation_to_reconnect", "cluster_register_timeout_in_ms", b"cluster_register_timeout_in_ms", "coordinated_job_list", b"coordinated_job_list", "enable_health_check", b"enable_health_check", "force_disable", b"force_disable", "heartbeat_timeout_in_ms", b"heartbeat_timeout_in_ms", "recoverable_jobs", b"recoverable_jobs", "service_leader", b"service_leader", "service_type", b"service_type", "shutdown_barrier_timeout_in_ms", b"shutdown_barrier_timeout_in_ms"]) -> None: ...

global___CoordinationServiceConfig = CoordinationServiceConfig

@@ -37,6 +37,9 @@ class _DataTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumT
kBF16: _DataType.ValueType # 7
kF8E5M2: _DataType.ValueType # 8
kF8E4M3FN: _DataType.ValueType # 9
kF8E5M2FNUZ: _DataType.ValueType # 10
kF8E4M3FNUZ: _DataType.ValueType # 11
kInt64: _DataType.ValueType # 12

class DataType(_DataType, metaclass=_DataTypeEnumTypeWrapper):
"""Specifies the data type used by an operation."""
@@ -51,6 +54,9 @@ kComplexDouble: DataType.ValueType # 6
kBF16: DataType.ValueType # 7
kF8E5M2: DataType.ValueType # 8
kF8E4M3FN: DataType.ValueType # 9
kF8E5M2FNUZ: DataType.ValueType # 10
kF8E4M3FNUZ: DataType.ValueType # 11
kInt64: DataType.ValueType # 12
global___DataType = DataType

class _DataLayout:
@@ -132,6 +138,10 @@ class _FilterLayoutEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._E
"""cuDNN's NCHW_VECT_C layout with 4-elem vectors"""
kOutputInputYX32: _FilterLayout.ValueType # 5
"""cuDNN's NCHW_VECT_C layout with 32-elem vectors"""
kOutputInputYX32_CudnnReordered: _FilterLayout.ValueType # 6
"""cuDNN-specific filter reordering (using `cudnnReorderFilterAndBias`)
When the filter is reordered, so is the bias (if present).
"""
kInputYXOutput: _FilterLayout.ValueType # 3
kYXInputOutput: _FilterLayout.ValueType # 4

@@ -153,6 +163,10 @@ kOutputInputYX4: FilterLayout.ValueType # 2
"""cuDNN's NCHW_VECT_C layout with 4-elem vectors"""
kOutputInputYX32: FilterLayout.ValueType # 5
"""cuDNN's NCHW_VECT_C layout with 32-elem vectors"""
kOutputInputYX32_CudnnReordered: FilterLayout.ValueType # 6
"""cuDNN-specific filter reordering (using `cudnnReorderFilterAndBias`)
When the filter is reordered, so is the bias (if present).
"""
kInputYXOutput: FilterLayout.ValueType # 3
kYXInputOutput: FilterLayout.ValueType # 4
global___FilterLayout = FilterLayout
@@ -241,6 +255,7 @@ class _ConvolutionKindEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper
BACKWARD_FILTER: _ConvolutionKind.ValueType # 2
BACKWARD_DATA: _ConvolutionKind.ValueType # 3
FORWARD_BIAS_ACTIVATION: _ConvolutionKind.ValueType # 4
FORWARD_GRAPH: _ConvolutionKind.ValueType # 5

class ConvolutionKind(_ConvolutionKind, metaclass=_ConvolutionKindEnumTypeWrapper): ...

@@ -249,8 +264,27 @@ FORWARD: ConvolutionKind.ValueType # 1
BACKWARD_FILTER: ConvolutionKind.ValueType # 2
BACKWARD_DATA: ConvolutionKind.ValueType # 3
FORWARD_BIAS_ACTIVATION: ConvolutionKind.ValueType # 4
FORWARD_GRAPH: ConvolutionKind.ValueType # 5
global___ConvolutionKind = ConvolutionKind

class _FusedMHAKind:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType

class _FusedMHAKindEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_FusedMHAKind.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
BMM1_OUTPUT_UNKNOWN: _FusedMHAKind.ValueType # 0
BMM1_OUTPUT_INPUT_TYPE: _FusedMHAKind.ValueType # 1
BMM1_OUTPUT_FLOAT: _FusedMHAKind.ValueType # 2

class FusedMHAKind(_FusedMHAKind, metaclass=_FusedMHAKindEnumTypeWrapper):
"""FusedMHAKind kind"""

BMM1_OUTPUT_UNKNOWN: FusedMHAKind.ValueType # 0
BMM1_OUTPUT_INPUT_TYPE: FusedMHAKind.ValueType # 1
BMM1_OUTPUT_FLOAT: FusedMHAKind.ValueType # 2
global___FusedMHAKind = FusedMHAKind

@typing.final
class TensorDescriptorProto(google.protobuf.message.Message):
"""Generic tensor representation."""

stubs/tensorflow/tensorflow/tsl/protobuf/status_pb2.pyi (new file)
@@ -0,0 +1,37 @@
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import typing
|
||||
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.message
|
||||
import tensorflow.tsl.protobuf.error_codes_pb2
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
|
||||
class StatusProto(google.protobuf.message.Message):
|
||||
"""Wire-format for Status.
|
||||
Next tag: 3
|
||||
"""
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
CODE_FIELD_NUMBER: builtins.int
|
||||
MESSAGE_FIELD_NUMBER: builtins.int
|
||||
code: tensorflow.tsl.protobuf.error_codes_pb2.Code.ValueType
|
||||
"""Status code as defined in tensorflow/tsl/protobuf/error_codes.proto."""
|
||||
message: builtins.str
|
||||
"""Detail error message."""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
code: tensorflow.tsl.protobuf.error_codes_pb2.Code.ValueType | None = ...,
|
||||
message: builtins.str | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["code", b"code", "message", b"message"]) -> None: ...
|
||||
|
||||
global___StatusProto = StatusProto
|
||||
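This stub file is new in 2.16. A hedged sketch of the generated message API it describes, assuming the module is importable at runtime under the same path the stub uses:

from tensorflow.tsl.protobuf import status_pb2  # import path assumed from the stub

status = status_pb2.StatusProto(code=0, message="ok")
status.ClearField("message")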
stubs/tensorflow/tensorflow/types/__init__.pyi (new file)
@@ -0,0 +1 @@
from tensorflow.types import experimental as experimental
@@ -15,7 +15,7 @@ class Callable(Generic[_P, _R], metaclass=abc.ABCMeta):
class ConcreteFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...

class GenericFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
class PolymorphicFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
@overload
@abc.abstractmethod
def get_concrete_function(self, *args: _P.args, **kwargs: _P.kwargs) -> ConcreteFunction[_P, _R]: ...
@@ -26,4 +26,6 @@ class GenericFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
) -> ConcreteFunction[_P, _R]: ...
def experimental_get_compiler_ir(self, *args, **kwargs): ...

GenericFunction = PolymorphicFunction

def __getattr__(name: str) -> Incomplete: ...
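The rename is the whole story here: GenericFunction becomes PolymorphicFunction, with the old name kept as an alias. Under these stubs a tf.function is typed as a PolymorphicFunction, so the familiar flow still checks:

import tensorflow as tf

@tf.function
def double(x):
    return x * 2

# get_concrete_function() comes from PolymorphicFunction in the stub above.
concrete = double.get_concrete_function(tf.TensorSpec([None], tf.float32))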