mirror of https://github.com/davidhalter/typeshed.git (synced 2025-12-08 13:04:46 +08:00)
Tensorflow losses (#10264)
Co-authored-by: Mehdi Drissi <mdrissi@snapchat.com>
stubs/tensorflow/tensorflow/__init__.pyi

@@ -26,6 +26,7 @@ from tensorflow.core.protobuf import struct_pb2
 # is necessary to avoid a crash in pytype.
 from tensorflow.dtypes import *
 from tensorflow.dtypes import DType as DType
+from tensorflow.keras import losses as losses

 # Most tf.math functions are exported as tf, but sadly not all are.
 from tensorflow.math import (
stubs/tensorflow/tensorflow/_aliases.pyi

@@ -3,7 +3,7 @@
 # equivalent.

 from collections.abc import Mapping, Sequence
-from typing import Any, TypeVar
+from typing import Any, Protocol, TypeVar
 from typing_extensions import TypeAlias

 import numpy
@@ -20,3 +20,11 @@ ContainerTensors: TypeAlias = ContainerGeneric[tf.Tensor]
 ContainerGradients: TypeAlias = ContainerGeneric[Gradients]

 AnyArray: TypeAlias = numpy.ndarray[Any, Any]
+
+class _KerasSerializable1(Protocol):
+    def get_config(self) -> dict[str, Any]: ...
+
+class _KerasSerializable2(Protocol):
+    __name__: str
+
+KerasSerializable: TypeAlias = _KerasSerializable1 | _KerasSerializable2
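The union covers the two shapes Keras serialization accepts: objects that expose get_config() and bare loss functions, which only carry __name__. A minimal sketch of values satisfying each branch (hypothetical names, not part of the stubs):

from typing import Any

class ScaledMSE:
    # Satisfies _KerasSerializable1: it has a get_config() returning a dict.
    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    def get_config(self) -> dict[str, Any]:
        return {"scale": self.scale}

def my_loss(y_true, y_pred):
    # A plain function carries __name__, so it is meant to match _KerasSerializable2.
    ...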
stubs/tensorflow/tensorflow/keras/__init__.pyi

@@ -5,6 +5,8 @@ from tensorflow.keras import (
     constraints as constraints,
     initializers as initializers,
     layers as layers,
+    losses as losses,
+    metrics as metrics,
     optimizers as optimizers,
     regularizers as regularizers,
 )
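With losses and metrics re-exported from the keras package, attribute access like the following should now resolve under the stubs (a usage sketch, not part of the PR):

import tensorflow as tf

mse = tf.keras.losses.MeanSquaredError()        # Loss class from the new losses stub
value = tf.keras.metrics.binary_crossentropy(   # function from the new metrics stub
    [[0.0, 1.0]], [[0.1, 0.9]]
)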
stubs/tensorflow/tensorflow/keras/losses.pyi (new file, 153 lines)
@@ -0,0 +1,153 @@
from _typeshed import Incomplete
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, TypeVar, overload
from typing_extensions import Final, Literal, Self, TypeAlias, TypeGuard

from tensorflow import Tensor, _TensorCompatible
from tensorflow._aliases import KerasSerializable
from tensorflow.keras.metrics import (
    binary_crossentropy as binary_crossentropy,
    categorical_crossentropy as categorical_crossentropy,
)

class Loss(ABC):
    reduction: _ReductionValues
    name: str | None
    def __init__(self, reduction: _ReductionValues = "auto", name: str | None = None) -> None: ...
    @abstractmethod
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    @classmethod
    def from_config(cls, config: dict[str, Any]) -> Self: ...
    def get_config(self) -> dict[str, Any]: ...
    def __call__(
        self, y_true: _TensorCompatible, y_pred: _TensorCompatible, sample_weight: _TensorCompatible | None = None
    ) -> Tensor: ...
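Loss is the abstract base for every concrete class below: a subclass implements call, while __call__ applies sample weighting and the configured reduction. A rough sketch of a custom subclass written against this typed interface (hypothetical, not from the PR):

import tensorflow as tf

class MeanCubedError(tf.keras.losses.Loss):
    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        # Return per-example values; reduction across the batch happens in __call__.
        return tf.reduce_mean(tf.abs(y_true - y_pred) ** 3, axis=-1)

loss = MeanCubedError(reduction="sum_over_batch_size", name="mean_cubed_error")
value = loss([[0.0, 1.0]], [[0.5, 0.5]], sample_weight=[1.0])  # scalar Tensor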

class BinaryCrossentropy(Loss):
    def __init__(
        self,
        from_logits: bool = False,
        label_smoothing: float = 0.0,
        axis: int = -1,
        reduction: _ReductionValues = ...,
        name: str | None = "binary_crossentropy",
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
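For context on the from_logits and label_smoothing defaults, a typical instantiation looks like this (usage sketch only, not part of the stubs):

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)
value = bce([[0.0], [1.0]], [[-2.3], [1.7]])  # raw logits are fine because from_logits=True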

class BinaryFocalCrossentropy(Loss):
    def __init__(
        self,
        apply_class_balancing: bool = False,
        alpha: float = 0.25,
        gamma: float = 2.0,
        from_logits: bool = False,
        label_smoothing: float = 0.0,
        axis: int = -1,
        reduction: _ReductionValues = ...,
        name: str | None = "binary_focal_crossentropy",
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class CategoricalCrossentropy(Loss):
    def __init__(
        self,
        from_logits: bool = False,
        label_smoothing: float = 0.0,
        axis: int = -1,
        reduction: _ReductionValues = ...,
        name: str | None = "categorical_crossentropy",
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class CategoricalHinge(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "categorical_hinge") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class CosineSimilarity(Loss):
    def __init__(self, axis: int = -1, reduction: _ReductionValues = ..., name: str | None = "cosine_similarity") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class Hinge(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "hinge") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class Huber(Loss):
    def __init__(self, delta: float = 1.0, reduction: _ReductionValues = ..., name: str | None = "huber_loss") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class KLDivergence(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "kl_divergence") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class LogCosh(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "log_cosh") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class MeanAbsoluteError(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_absolute_error") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class MeanAbsolutePercentageError(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_absolute_percentage_error") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class MeanSquaredError(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_squared_error") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class MeanSquaredLogarithmicError(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_squared_logarithmic_error") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class Poisson(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "poisson") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class SparseCategoricalCrossentropy(Loss):
    def __init__(
        self,
        from_logits: bool = False,
        ignore_class: int | None = None,
        reduction: _ReductionValues = ...,
        name: str = "sparse_categorical_crossentropy",
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class SquaredHinge(Loss):
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "squared_hinge") -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class Reduction:
    AUTO: Final = "auto"
    NONE: Final = "none"
    SUM: Final = "sum"
    SUM_OVER_BATCH_SIZE: Final = "sum_over_batch_size"
    @classmethod
    def all(cls) -> tuple[_ReductionValues, ...]: ...
    @classmethod
    def validate(cls, key: object) -> TypeGuard[_ReductionValues]: ...

_ReductionValues: TypeAlias = Literal["auto", "none", "sum", "sum_over_batch_size"]
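The Final constants on Reduction are exactly the strings captured by _ReductionValues, so either spelling type-checks against the reduction parameter of the constructors above (usage sketch, not from the PR):

import tensorflow as tf

summed = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.SUM)
per_example = tf.keras.losses.Huber(delta=1.5, reduction="none")  # keeps per-example values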

def categorical_hinge(y_true: _TensorCompatible, y_pred: _TensorCompatible) -> Tensor: ...
def huber(y_true: _TensorCompatible, y_pred: _TensorCompatible, delta: float = 1.0) -> Tensor: ...
def log_cosh(y_true: _TensorCompatible, y_pred: _TensorCompatible) -> Tensor: ...
def deserialize(
    name: str | dict[str, Any], custom_objects: dict[str, Any] | None = None, use_legacy_format: bool = False
) -> Loss: ...
def serialize(loss: KerasSerializable, use_legacy_format: bool = False) -> dict[str, Any]: ...
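serialize accepts anything KerasSerializable (see the _aliases change above) and deserialize restores a Loss, so a config round trip reads like this (sketch only; the exact config layout depends on the Keras version):

import tensorflow as tf

config = tf.keras.losses.serialize(tf.keras.losses.Huber(delta=2.0))  # dict[str, Any]
restored = tf.keras.losses.deserialize(config)                        # typed as Loss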

_FuncT = TypeVar("_FuncT", bound=Callable[..., Any])

@overload
def get(identifier: None) -> None: ...
@overload
def get(identifier: str | dict[str, Any]) -> Loss: ...
@overload
def get(identifier: _FuncT) -> _FuncT: ...
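The three overloads of get mirror the runtime contract: None passes through, a string or config dict is typed as a Loss, and a callable is returned unchanged. How a type checker sees it (hypothetical variable names):

import tensorflow as tf

a = tf.keras.losses.get(None)                   # inferred as None
b = tf.keras.losses.get("mean_squared_error")   # inferred as Loss
c = tf.keras.losses.get(tf.keras.losses.huber)  # inferred as the huber callable itself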

# This is complete with respect to the methods defined here, but many
# functions get re-exported here from tf.keras.metrics and aren't covered yet.
def __getattr__(name: str) -> Incomplete: ...

stubs/tensorflow/tensorflow/keras/metrics.pyi (new file, 8 lines)
@@ -0,0 +1,8 @@
from tensorflow import Tensor, _TensorCompatible

def binary_crossentropy(
    y_true: _TensorCompatible, y_pred: _TensorCompatible, from_logits: bool = False, label_smoothing: float = 0.0, axis: int = -1
) -> Tensor: ...
def categorical_crossentropy(
    y_true: _TensorCompatible, y_pred: _TensorCompatible, from_logits: bool = False, label_smoothing: float = 0.0, axis: int = -1
) -> Tensor: ...
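Both functions compute element-wise cross-entropy over the last axis and are what losses.pyi re-exports; a quick usage sketch under these signatures (not part of the stubs):

import tensorflow as tf

per_example = tf.keras.metrics.binary_crossentropy(
    [[0.0, 1.0]], [[0.2, 2.5]], from_logits=True
)  # Tensor of per-example values, one per row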