diff --git a/stubs/seaborn/seaborn/_stats/density.pyi b/stubs/seaborn/seaborn/_stats/density.pyi index 7e8d12e30..386424ccb 100644 --- a/stubs/seaborn/seaborn/_stats/density.pyi +++ b/stubs/seaborn/seaborn/_stats/density.pyi @@ -1,13 +1,12 @@ -from collections.abc import Callable from dataclasses import dataclass from seaborn._stats.base import Stat -from seaborn.external.kde import gaussian_kde +from seaborn.external.kde import _BwMethodType @dataclass class KDE(Stat): bw_adjust: float = 1 - bw_method: str | float | Callable[[gaussian_kde], float] = "scott" + bw_method: _BwMethodType = "scott" common_norm: bool | list[str] = True common_grid: bool | list[str] = True gridsize: int | None = 200 diff --git a/stubs/seaborn/seaborn/distributions.pyi b/stubs/seaborn/seaborn/distributions.pyi index 78c20b4d5..0f849f2c1 100644 --- a/stubs/seaborn/seaborn/distributions.pyi +++ b/stubs/seaborn/seaborn/distributions.pyi @@ -1,28 +1,37 @@ -from _typeshed import Incomplete from collections.abc import Iterable -from typing import Any -from typing_extensions import Literal, deprecated +from typing import Any, Protocol, TypeVar +from typing_extensions import Literal, TypeAlias, deprecated from matplotlib.axes import Axes -from matplotlib.colors import Colormap +from matplotlib.colors import Colormap, Normalize from matplotlib.typing import ColorType +from numpy.typing import ArrayLike +from ._core.typing import ColumnName, DataSource from .axisgrid import FacetGrid -from .utils import _LogScale, _Palette +from .external.kde import _BwMethodType +from .utils import _DataSourceWideForm, _LogScale, _Palette, _Vector __all__ = ["displot", "histplot", "kdeplot", "ecdfplot", "rugplot", "distplot"] +_T = TypeVar("_T") +_OneOrPair: TypeAlias = _T | tuple[_T, _T] + +class _Fit(Protocol): + def fit(self, a: ArrayLike) -> tuple[ArrayLike, ...]: ... + def pdf(self, x: ArrayLike, *params: ArrayLike) -> ArrayLike: ... + def histplot( - data: Incomplete | None = None, + data: DataSource | _DataSourceWideForm | None = None, *, - x: Incomplete | None = None, - y: Incomplete | None = None, - hue: Incomplete | None = None, - weights: Incomplete | None = None, + x: ColumnName | _Vector | None = None, + y: ColumnName | _Vector | None = None, + hue: ColumnName | _Vector | None = None, + weights: ColumnName | _Vector | None = None, stat: str = "count", - bins: Incomplete = "auto", + bins: _OneOrPair[str | int | ArrayLike] = "auto", binwidth: float | tuple[float, float] | None = None, - binrange: Incomplete | None = None, + binrange: _OneOrPair[tuple[float, float]] | None = None, discrete: bool | None = None, cumulative: bool = False, common_bins: bool = True, @@ -41,8 +50,8 @@ def histplot( cbar_ax: Axes | None = None, cbar_kws: dict[str, Any] | None = None, palette: _Palette | Colormap | None = None, - hue_order: Iterable[str] | None = None, - hue_norm: Incomplete | None = None, + hue_order: Iterable[ColumnName] | None = None, + hue_norm: tuple[float, float] | Normalize | None = None, color: ColorType | None = None, log_scale: _LogScale | None = None, legend: bool = True, @@ -50,30 +59,30 @@ def histplot( **kwargs: Any, ) -> Axes: ... def kdeplot( - data: Incomplete | None = None, + data: DataSource | _DataSourceWideForm | None = None, *, - x: Incomplete | None = None, - y: Incomplete | None = None, - hue: Incomplete | None = None, - weights: Incomplete | None = None, + x: ColumnName | _Vector | None = None, + y: ColumnName | _Vector | None = None, + hue: ColumnName | _Vector | None = None, + weights: ColumnName | _Vector | None = None, palette: _Palette | Colormap | None = None, - hue_order: Iterable[str] | None = None, - hue_norm: Incomplete | None = None, + hue_order: Iterable[ColumnName] | None = None, + hue_norm: tuple[float, float] | Normalize | None = None, color: ColorType | None = None, fill: bool | None = None, multiple: Literal["layer", "stack", "fill"] = "layer", common_norm: bool = True, common_grid: bool = False, cumulative: bool = False, - bw_method: str = "scott", + bw_method: _BwMethodType = "scott", bw_adjust: float = 1, warn_singular: bool = True, log_scale: _LogScale | None = None, - levels: int | Iterable[int] = 10, + levels: int | Iterable[float] = 10, thresh: float = 0.05, gridsize: int = 200, cut: float = 3, - clip: Incomplete | None = None, + clip: _OneOrPair[tuple[float | None, float | None]] | None = None, legend: bool = True, cbar: bool = False, cbar_ax: Axes | None = None, @@ -82,58 +91,58 @@ def kdeplot( **kwargs: Any, ) -> Axes: ... def ecdfplot( - data: Incomplete | None = None, + data: DataSource | _DataSourceWideForm | None = None, *, - x: Incomplete | None = None, - y: Incomplete | None = None, - hue: Incomplete | None = None, - weights: Incomplete | None = None, - stat: Literal["proportion", "count"] = "proportion", + x: ColumnName | _Vector | None = None, + y: ColumnName | _Vector | None = None, + hue: ColumnName | _Vector | None = None, + weights: ColumnName | _Vector | None = None, + stat: Literal["proportion", "percent", "count"] = "proportion", complementary: bool = False, palette: _Palette | Colormap | None = None, - hue_order: Iterable[str] | None = None, - hue_norm: Incomplete | None = None, + hue_order: Iterable[ColumnName] | None = None, + hue_norm: tuple[float, float] | Normalize | None = None, log_scale: _LogScale | None = None, legend: bool = True, ax: Axes | None = None, **kwargs: Any, ) -> Axes: ... def rugplot( - data: Incomplete | None = None, + data: DataSource | _DataSourceWideForm | None = None, *, - x: Incomplete | None = None, - y: Incomplete | None = None, - hue: Incomplete | None = None, + x: ColumnName | _Vector | None = None, + y: ColumnName | _Vector | None = None, + hue: ColumnName | _Vector | None = None, height: float = 0.025, expand_margins: bool = True, palette: _Palette | Colormap | None = None, - hue_order: Iterable[str] | None = None, - hue_norm: Incomplete | None = None, + hue_order: Iterable[ColumnName] | None = None, + hue_norm: tuple[float, float] | Normalize | None = None, legend: bool = True, ax: Axes | None = None, **kwargs: Any, ) -> Axes: ... def displot( - data: Incomplete | None = None, + data: DataSource | _DataSourceWideForm | None = None, *, - x: Incomplete | None = None, - y: Incomplete | None = None, - hue: Incomplete | None = None, - row: Incomplete | None = None, - col: Incomplete | None = None, - weights: Incomplete | None = None, + x: ColumnName | _Vector | None = None, + y: ColumnName | _Vector | None = None, + hue: ColumnName | _Vector | None = None, + row: ColumnName | _Vector | None = None, + col: ColumnName | _Vector | None = None, + weights: ColumnName | _Vector | None = None, kind: Literal["hist", "kde", "ecdf"] = "hist", rug: bool = False, rug_kws: dict[str, Any] | None = None, log_scale: _LogScale | None = None, legend: bool = True, palette: _Palette | Colormap | None = None, - hue_order: Iterable[str] | None = None, - hue_norm: Incomplete | None = None, + hue_order: Iterable[ColumnName] | None = None, + hue_norm: tuple[float, float] | Normalize | None = None, color: ColorType | None = None, col_wrap: int | None = None, - row_order: Iterable[str] | None = None, - col_order: Iterable[str] | None = None, + row_order: Iterable[ColumnName] | None = None, + col_order: Iterable[ColumnName] | None = None, height: float = 5, aspect: float = 1, facet_kws: dict[str, Any] | None = None, @@ -141,12 +150,12 @@ def displot( ) -> FacetGrid: ... @deprecated("Function `distplot` is deprecated and will be removed in seaborn v0.14.0") def distplot( - a: Incomplete | None = None, - bins: Incomplete | None = None, + a: ArrayLike | None = None, + bins: ArrayLike | None = None, hist: bool = True, kde: bool = True, rug: bool = False, - fit: Incomplete | None = None, + fit: _Fit | None = None, hist_kws: dict[str, Any] | None = None, kde_kws: dict[str, Any] | None = None, rug_kws: dict[str, Any] | None = None, @@ -155,7 +164,7 @@ def distplot( vertical: bool = False, norm_hist: bool = False, axlabel: str | Literal[False] | None = None, - label: Incomplete | None = None, + label: str | None = None, ax: Axes | None = None, - x: Incomplete | None = None, + x: ArrayLike | None = None, ) -> Axes: ... diff --git a/stubs/seaborn/seaborn/external/kde.pyi b/stubs/seaborn/seaborn/external/kde.pyi index 92d622779..81b757304 100644 --- a/stubs/seaborn/seaborn/external/kde.pyi +++ b/stubs/seaborn/seaborn/external/kde.pyi @@ -1,4 +1,5 @@ from collections.abc import Callable +from typing import Any, Protocol from typing_extensions import Literal, TypeAlias import numpy as np @@ -6,8 +7,25 @@ from numpy.typing import ArrayLike, NDArray __all__ = ["gaussian_kde"] -_Scalar: TypeAlias = np.generic | bool | int | float | complex | bytes | memoryview # see np.isscalar -_BwMethodType: TypeAlias = Literal["scott", "silverman"] | Callable[[gaussian_kde], object] | _Scalar | None +# define a "Gaussian KDE" protocol so that we can also pass `scipy.stats.gaussian_kde` to +# functions that expect it without adding a dependency on scipy +class _GaussianKDELike(Protocol): + dataset: NDArray[np.float64] + def __init__(self, dataset: ArrayLike, bw_method: Any | None = ..., weights: ArrayLike | None = ...) -> None: ... + def evaluate(self, points: ArrayLike) -> NDArray[Any]: ... + def __call__(self, points: ArrayLike) -> NDArray[Any]: ... + def scotts_factor(self) -> float: ... + def silverman_factor(self) -> float: ... + def covariance_factor(self) -> float: ... + def pdf(self, x: ArrayLike) -> NDArray[Any]: ... + def set_bandwidth(self, bw_method: Any | None = ...) -> None: ... + @property + def weights(self) -> NDArray[Any]: ... + @property + def neff(self) -> NDArray[Any]: ... + +_Scalar: TypeAlias = float | np.number[Any] +_BwMethodType: TypeAlias = Literal["scott", "silverman"] | Callable[[_GaussianKDELike], _Scalar] | _Scalar | None class gaussian_kde: dataset: NDArray[np.float64]