mirror of
https://github.com/davidhalter/typeshed.git
synced 2025-12-20 10:51:15 +08:00
tensorflow: Add members from tensorflow.keras.metrics (#11329)
Partially taken from: https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/keras/metrics.pyi
This commit is contained in:
@@ -1,9 +1,120 @@
|
||||
from tensorflow import Tensor
|
||||
from tensorflow._aliases import TensorCompatible
|
||||
from _typeshed import Incomplete
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import Any, Literal
|
||||
from typing_extensions import Self, TypeAlias, override
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow import Operation, Tensor
|
||||
from tensorflow._aliases import DTypeLike, KerasSerializable, TensorCompatible
|
||||
from tensorflow.keras.initializers import _Initializer
|
||||
|
||||
_Output: TypeAlias = Tensor | dict[str, Tensor]
|
||||
|
||||
class Metric(tf.keras.layers.Layer[tf.Tensor, tf.Tensor], metaclass=ABCMeta):
|
||||
def __init__(self, name: str | None = None, dtype: DTypeLike | None = None) -> None: ...
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> Self: ...
|
||||
def merge_state(self, metrics: Iterable[Self]) -> list[Operation]: ...
|
||||
def reset_state(self) -> None: ...
|
||||
@abstractmethod
|
||||
def update_state(
|
||||
self, y_true: TensorCompatible, y_pred: TensorCompatible, sample_weight: TensorCompatible | None = None
|
||||
) -> Operation | None: ...
|
||||
@abstractmethod
|
||||
def result(self) -> _Output: ...
|
||||
# Metric inherits from keras.Layer, but its add_weight method is incompatible with the one from "Layer".
|
||||
@override
|
||||
def add_weight( # type: ignore
|
||||
self,
|
||||
name: str,
|
||||
shape: Iterable[int | None] | None = (),
|
||||
aggregation: tf.VariableAggregation = ...,
|
||||
synchronization: tf.VariableSynchronization = ...,
|
||||
initializer: _Initializer | None = None,
|
||||
dtype: DTypeLike | None = None,
|
||||
) -> None: ...
|
||||
|
||||
class AUC(Metric):
|
||||
_from_logits: bool
|
||||
_num_labels: int
|
||||
num_labels: int | None
|
||||
def __init__(
|
||||
self,
|
||||
num_thresholds: int = 200,
|
||||
curve: Literal["ROC", "PR"] = "ROC",
|
||||
summation_method: Literal["interpolation", "minoring", "majoring"] = "interpolation",
|
||||
name: str | None = None,
|
||||
dtype: DTypeLike | None = None,
|
||||
thresholds: Sequence[float] | None = None,
|
||||
multi_label: bool = False,
|
||||
num_labels: int | None = None,
|
||||
label_weights: TensorCompatible | None = None,
|
||||
from_logits: bool = False,
|
||||
) -> None: ...
|
||||
def update_state(
|
||||
self, y_true: TensorCompatible, y_pred: TensorCompatible, sample_weight: TensorCompatible | None = None
|
||||
) -> Operation: ...
|
||||
def result(self) -> tf.Tensor: ...
|
||||
|
||||
class Precision(Metric):
|
||||
def __init__(
|
||||
self,
|
||||
thresholds: float | Sequence[float] | None = None,
|
||||
top_k: int | None = None,
|
||||
class_id: int | None = None,
|
||||
name: str | None = None,
|
||||
dtype: DTypeLike | None = None,
|
||||
) -> None: ...
|
||||
def update_state(
|
||||
self, y_true: TensorCompatible, y_pred: TensorCompatible, sample_weight: TensorCompatible | None = None
|
||||
) -> Operation: ...
|
||||
def result(self) -> tf.Tensor: ...
|
||||
|
||||
class Recall(Metric):
|
||||
def __init__(
|
||||
self,
|
||||
thresholds: float | Sequence[float] | None = None,
|
||||
top_k: int | None = None,
|
||||
class_id: int | None = None,
|
||||
name: str | None = None,
|
||||
dtype: DTypeLike | None = None,
|
||||
) -> None: ...
|
||||
def update_state(
|
||||
self, y_true: TensorCompatible, y_pred: TensorCompatible, sample_weight: TensorCompatible | None = None
|
||||
) -> Operation: ...
|
||||
def result(self) -> tf.Tensor: ...
|
||||
|
||||
class MeanMetricWrapper(Metric):
|
||||
def __init__(
|
||||
self, fn: Callable[[tf.Tensor, tf.Tensor], tf.Tensor], name: str | None = None, dtype: DTypeLike | None = None
|
||||
) -> None: ...
|
||||
def update_state(
|
||||
self, y_true: TensorCompatible, y_pred: TensorCompatible, sample_weight: TensorCompatible | None = None
|
||||
) -> Operation: ...
|
||||
def result(self) -> tf.Tensor: ...
|
||||
|
||||
class BinaryAccuracy(MeanMetricWrapper):
|
||||
def __init__(self, name: str | None = "binary_accuracy", dtype: DTypeLike | None = None, threshold: float = 0.5) -> None: ...
|
||||
|
||||
class Accuracy(MeanMetricWrapper):
|
||||
def __init__(self, name: str | None = "accuracy", dtype: DTypeLike | None = None) -> None: ...
|
||||
|
||||
class CategoricalAccuracy(MeanMetricWrapper):
|
||||
def __init__(self, name: str | None = "categorical_accuracy", dtype: DTypeLike | None = None) -> None: ...
|
||||
|
||||
class TopKCategoricalAccuracy(MeanMetricWrapper):
|
||||
def __init__(self, k: int = 5, name: str | None = "top_k_categorical_accuracy", dtype: DTypeLike | None = None) -> None: ...
|
||||
|
||||
class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
|
||||
def __init__(
|
||||
self, k: int = 5, name: str | None = "sparse_top_k_categorical_accuracy", dtype: DTypeLike | None = None
|
||||
) -> None: ...
|
||||
|
||||
def serialize(metric: KerasSerializable, use_legacy_format: bool = False) -> dict[str, Any]: ...
|
||||
def binary_crossentropy(
|
||||
y_true: TensorCompatible, y_pred: TensorCompatible, from_logits: bool = False, label_smoothing: float = 0.0, axis: int = -1
|
||||
) -> Tensor: ...
|
||||
def categorical_crossentropy(
|
||||
y_true: TensorCompatible, y_pred: TensorCompatible, from_logits: bool = False, label_smoothing: float = 0.0, axis: int = -1
|
||||
) -> Tensor: ...
|
||||
def __getattr__(name: str) -> Incomplete: ...
|
||||
|
||||
Reference in New Issue
Block a user