From aef6e229fe9bd7c0f0bff00c98c538fac16bc637 Mon Sep 17 00:00:00 2001 From: Ali Hamdan Date: Tue, 21 Nov 2023 18:51:25 +0200 Subject: [PATCH] seaborn: fix and complete `seaborn.regression` (#11043) --- stubs/seaborn/@tests/stubtest_allowlist.txt | 2 + stubs/seaborn/seaborn/regression.pyi | 102 +++++++++++++++----- 2 files changed, 81 insertions(+), 23 deletions(-) diff --git a/stubs/seaborn/@tests/stubtest_allowlist.txt b/stubs/seaborn/@tests/stubtest_allowlist.txt index 02bf781f5..c435d9f95 100644 --- a/stubs/seaborn/@tests/stubtest_allowlist.txt +++ b/stubs/seaborn/@tests/stubtest_allowlist.txt @@ -1,3 +1,5 @@ seaborn._core.scales.(Pipeline|TransFuncs) # aliases defined in `if TYPE_CHECKING` block seaborn.external.docscrape.ClassDoc.__init__ # stubtest doesn't like ABC class as default value seaborn.external.docscrape.NumpyDocString.__str__ # weird signature + +seaborn(\.regression)?\.lmplot # the `data` argument is required but it defaults to `None` at runtime diff --git a/stubs/seaborn/seaborn/regression.pyi b/stubs/seaborn/seaborn/regression.pyi index 8b8eb9f13..dc7586176 100644 --- a/stubs/seaborn/seaborn/regression.pyi +++ b/stubs/seaborn/seaborn/regression.pyi @@ -1,19 +1,22 @@ from _typeshed import Incomplete -from collections.abc import Iterable -from typing import Any -from typing_extensions import Literal +from collections.abc import Callable, Iterable +from typing import Any, overload +from typing_extensions import Literal, TypeAlias import pandas as pd from matplotlib.axes import Axes from matplotlib.typing import ColorType +from numpy.typing import NDArray from .axisgrid import FacetGrid from .utils import _Palette, _Seed __all__ = ["lmplot", "regplot", "residplot"] +_Vector: TypeAlias = list[Incomplete] | pd.Series[Incomplete] | pd.Index[Incomplete] | NDArray[Incomplete] + def lmplot( - data: Incomplete | None = None, + data: pd.DataFrame, *, x: str | None = None, y: str | None = None, @@ -25,15 +28,15 @@ def lmplot( height: float = 5, aspect: float = 1, markers: str = "o", - sharex: bool | Literal["col", "row"] | None = None, - sharey: bool | Literal["col", "row"] | None = None, + sharex: bool | Literal["col", "row"] | None = None, # deprecated + sharey: bool | Literal["col", "row"] | None = None, # deprecated hue_order: Iterable[str] | None = None, col_order: Iterable[str] | None = None, row_order: Iterable[str] | None = None, legend: bool = True, - legend_out: Incomplete | None = None, - x_estimator: Incomplete | None = None, - x_bins: Incomplete | None = None, + legend_out: bool | None = None, # deprecated + x_estimator: Callable[[Incomplete], Incomplete] | None = None, + x_bins: int | _Vector | None = None, x_ci: Literal["ci", "sd"] | int | None = "ci", scatter: bool = True, fit_reg: bool = True, @@ -55,27 +58,28 @@ def lmplot( line_kws: dict[str, Any] | None = None, facet_kws: dict[str, Any] | None = None, ) -> FacetGrid: ... +@overload def regplot( - data: pd.DataFrame | None = None, + data: None = None, *, - x: Incomplete | None = None, - y: Incomplete | None = None, - x_estimator: Incomplete | None = None, - x_bins: Incomplete | None = None, + x: _Vector | None = None, + y: _Vector | None = None, + x_estimator: Callable[[Incomplete], Incomplete] | None = None, + x_bins: int | _Vector | None = None, x_ci: Literal["ci", "sd"] | int | None = "ci", scatter: bool = True, fit_reg: bool = True, ci: int | None = 95, n_boot: int = 1000, - units: str | None = None, + units: _Vector | None = None, seed: _Seed | None = None, order: int = 1, logistic: bool = False, lowess: bool = False, robust: bool = False, logx: bool = False, - x_partial: str | None = None, - y_partial: str | None = None, + x_partial: _Vector | None = None, + y_partial: _Vector | None = None, truncate: bool = True, dropna: bool = True, x_jitter: float | None = None, @@ -87,13 +91,65 @@ def regplot( line_kws: dict[str, Any] | None = None, ax: Axes | None = None, ) -> Axes: ... -def residplot( - data: Incomplete | None = None, +@overload +def regplot( + data: pd.DataFrame, *, - x: Incomplete | None = None, - y: Incomplete | None = None, - x_partial: Incomplete | None = None, - y_partial: Incomplete | None = None, + x: str | _Vector | None = None, + y: str | _Vector | None = None, + x_estimator: Callable[[Incomplete], Incomplete] | None = None, + x_bins: int | _Vector | None = None, + x_ci: Literal["ci", "sd"] | int | None = "ci", + scatter: bool = True, + fit_reg: bool = True, + ci: int | None = 95, + n_boot: int = 1000, + units: str | _Vector | None = None, + seed: _Seed | None = None, + order: int = 1, + logistic: bool = False, + lowess: bool = False, + robust: bool = False, + logx: bool = False, + x_partial: str | _Vector | None = None, + y_partial: str | _Vector | None = None, + truncate: bool = True, + dropna: bool = True, + x_jitter: float | None = None, + y_jitter: float | None = None, + label: str | None = None, + color: ColorType | None = None, + marker: str = "o", + scatter_kws: dict[str, Any] | None = None, + line_kws: dict[str, Any] | None = None, + ax: Axes | None = None, +) -> Axes: ... +@overload +def residplot( + data: None = None, + *, + x: _Vector | None = None, + y: _Vector | None = None, + x_partial: _Vector | None = None, + y_partial: _Vector | None = None, + lowess: bool = False, + order: int = 1, + robust: bool = False, + dropna: bool = True, + label: str | None = None, + color: ColorType | None = None, + scatter_kws: dict[str, Any] | None = None, + line_kws: dict[str, Any] | None = None, + ax: Axes | None = None, +) -> Axes: ... +@overload +def residplot( + data: pd.DataFrame, + *, + x: str | _Vector | None = None, + y: str | _Vector | None = None, + x_partial: str | _Vector | None = None, + y_partial: str | _Vector | None = None, lowess: bool = False, order: int = 1, robust: bool = False,