seaborn: Update and complete the relational module (#11214)

This commit is contained in:
Ali Hamdan
2024-01-14 13:51:31 +01:00
committed by GitHub
parent 8e8204b83f
commit 403d43b13c
9 changed files with 100 additions and 84 deletions

View File

@@ -2,6 +2,4 @@ seaborn._core.scales.(Pipeline|TransFuncs) # aliases defined in `if TYPE_CHECKI
seaborn.external.docscrape.ClassDoc.__init__ # stubtest doesn't like ABC class as default value
seaborn.external.docscrape.NumpyDocString.__str__ # weird signature
seaborn(\.regression)?\.lmplot # the `data` argument is required but it defaults to `None` at runtime
seaborn.axisgrid.Grid.tight_layout # the method doesn't really take pos args but runtime has *args

View File

@@ -1,4 +1,4 @@
version = "0.13.0"
version = "0.13.1"
# Requires a version of numpy and matplotlib with a `py.typed` file
requires = ["matplotlib>=3.8", "numpy>=1.20", "pandas-stubs"]
# matplotlib>=3.8 requires Python >=3.9

View File

@@ -98,8 +98,14 @@ class Plot:
def share(self, **shares: bool | str) -> Plot: ...
def limit(self, **limits: tuple[Any, Any]) -> Plot: ...
def label(self, *, title: str | None = None, legend: str | None = None, **variables: str | Callable[[str], str]) -> Plot: ...
def layout(self, *, size: tuple[float, float] | Default = ..., engine: str | None | Default = ...) -> Plot: ...
def theme(self, *args: dict[str, Any]) -> Plot: ...
def layout(
self,
*,
size: tuple[float, float] | Default = ...,
engine: str | None | Default = ...,
extent: tuple[float, float, float, float] | Default = ...,
) -> Plot: ...
def theme(self, __config: dict[str, Any]) -> Plot: ...
# Same signature as Plotter.save
def save(
self,

View File

@@ -1,14 +1,15 @@
from _typeshed import Incomplete
from collections.abc import Iterable, Mapping
from datetime import date, datetime, timedelta
from typing import Any, Protocol
from typing import Any, Protocol, type_check_only
from typing_extensions import TypeAlias
from matplotlib.colors import Colormap, Normalize
from numpy import ndarray
from pandas import DataFrame, Index, Series, Timedelta, Timestamp
class _SupportsDataFrame(Protocol):
@type_check_only
class SupportsDataFrame(Protocol):
# `__dataframe__` should return pandas.core.interchange.dataframe_protocol.DataFrame
# but this class needs to be defined as a Protocol, not as an ABC.
def __dataframe__(self, nan_as_null: bool = ..., allow_copy: bool = ...) -> Incomplete: ...
@@ -17,7 +18,7 @@ ColumnName: TypeAlias = str | bytes | date | datetime | timedelta | bool | compl
Vector: TypeAlias = Series[Any] | Index[Any] | ndarray[Any, Any]
VariableSpec: TypeAlias = ColumnName | Vector | None
VariableSpecList: TypeAlias = list[VariableSpec] | Index[Any] | None
DataSource: TypeAlias = DataFrame | _SupportsDataFrame | Mapping[ColumnName, Incomplete] | None
DataSource: TypeAlias = DataFrame | SupportsDataFrame | Mapping[Any, Incomplete] | None
OrderSpec: TypeAlias = Iterable[str] | None
NormSpec: TypeAlias = tuple[float | None, float | None] | Normalize | None
PaletteSpec: TypeAlias = str | list[Incomplete] | dict[Incomplete, Incomplete] | Colormap | None

View File

@@ -8,7 +8,7 @@ import numpy as np
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.backend_bases import MouseEvent, RendererBase
from matplotlib.colors import Colormap, Normalize
from matplotlib.colors import Colormap
from matplotlib.figure import Figure
from matplotlib.font_manager import FontProperties
from matplotlib.gridspec import SubplotSpec
@@ -23,7 +23,7 @@ from matplotlib.typing import ColorType, LineStyleType, MarkerType
from numpy.typing import ArrayLike, NDArray
from pandas import DataFrame, Series
from ._core.typing import ColumnName, DataSource, _SupportsDataFrame
from ._core.typing import ColumnName, DataSource, NormSpec, SupportsDataFrame
from .palettes import _RGBColorPalette
from .utils import _DataSourceWideForm, _Palette, _Vector
@@ -155,7 +155,7 @@ class Grid(_BaseGrid):
ncols: int = 1,
mode: Literal["expand"] | None = None,
fancybox: bool | None = None,
shadow: bool | dict[str, float] | None = None,
shadow: bool | dict[str, int] | dict[str, float] | None = None,
title_fontsize: int | _LiteralFont | None = None,
framealpha: float | None = None,
edgecolor: ColorType | None = None,
@@ -212,7 +212,7 @@ class FacetGrid(Grid):
hue_kws: dict[str, Any]
def __init__(
self,
data: DataFrame | _SupportsDataFrame,
data: DataFrame | SupportsDataFrame,
*,
row: str | None = None,
col: str | None = None,
@@ -293,7 +293,7 @@ class PairGrid(Grid):
palette: _RGBColorPalette
def __init__(
self,
data: DataFrame | _SupportsDataFrame,
data: DataFrame | SupportsDataFrame,
*,
hue: str | None = None,
vars: Iterable[str] | None = None,
@@ -335,7 +335,7 @@ class JointGrid(_BaseGrid):
space: float = 0.2,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
dropna: bool = False,
xlim: float | tuple[float, float] | None = None,
ylim: float | tuple[float, float] | None = None,
@@ -394,7 +394,7 @@ def jointplot(
color: ColorType | None = None,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
marginal_ticks: bool = False,
joint_kws: dict[str, Any] | None = None,
marginal_kws: dict[str, Any] | None = None,

View File

@@ -3,10 +3,9 @@ from collections.abc import Callable, Iterable
from typing import Any, Literal
from matplotlib.axes import Axes
from matplotlib.colors import Normalize
from matplotlib.typing import ColorType, LineStyleType, MarkerType
from ._core.typing import ColumnName, DataSource, Default
from ._core.typing import ColumnName, DataSource, Default, NormSpec
from .axisgrid import FacetGrid
from .external.kde import _BwMethodType
from .utils import _DataSourceWideForm, _ErrorBar, _Estimator, _Legend, _LogScale, _Palette, _Seed, _Vector
@@ -33,7 +32,7 @@ def boxplot(
linecolor: ColorType = "auto",
linewidth: float | None = None,
fliersize: float | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
native_scale: bool = False,
log_scale: _LogScale | None = None,
formatter: Callable[[Any], str] | None = None,
@@ -67,7 +66,7 @@ def violinplot(
bw_adjust: float = 1,
density_norm: Literal["area", "count", "width"] = "area",
common_norm: bool | None = False,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
formatter: Callable[[Any], str] | None = None,
log_scale: _LogScale | None = None,
native_scale: bool = False,
@@ -102,7 +101,7 @@ def boxenplot(
outlier_prop: float = 0.007,
trust_alpha: float = 0.05,
showfliers: bool = True,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
log_scale: _LogScale | None = None,
native_scale: bool = False,
formatter: Callable[[Any], str] | None = None,
@@ -130,7 +129,7 @@ def stripplot(
size: float = 5,
edgecolor: ColorType | Default = ...,
linewidth: float = 0,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
log_scale: _LogScale | None = None,
native_scale: bool = False,
formatter: Callable[[Any], str] | None = None,
@@ -153,7 +152,7 @@ def swarmplot(
size: float = 5,
edgecolor: ColorType | None = None,
linewidth: float = 0,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
log_scale: _LogScale | None = None,
native_scale: bool = False,
formatter: Callable[[Any], str] | None = None,
@@ -174,13 +173,14 @@ def barplot(
errorbar: _ErrorBar | None = ("ci", 95),
n_boot: int = 1000,
units: ColumnName | _Vector | None = None,
weights: ColumnName | _Vector | None = None,
seed: _Seed | None = None,
orient: Literal["v", "h", "x", "y"] | None = None,
color: ColorType | None = None,
palette: _Palette | None = None,
saturation: float = 0.75,
fill: bool = True,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
width: float = 0.8,
dodge: bool | Literal["auto"] = "auto",
gap: float = 0,
@@ -208,10 +208,11 @@ def pointplot(
errorbar: _ErrorBar | None = ("ci", 95),
n_boot: int = 1000,
units: ColumnName | _Vector | None = None,
weights: ColumnName | _Vector | None = None,
seed: _Seed | None = None,
color: ColorType | None = None,
palette: _Palette | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
markers: MarkerType | list[MarkerType] | Default = ...,
linestyles: LineStyleType | list[LineStyleType] | Default = ...,
dodge: bool = False,
@@ -242,7 +243,7 @@ def countplot(
palette: _Palette | None = None,
saturation: float = 0.75,
fill: bool = True,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
stat: Literal["count", "percent", "proportion", "probability"] = "count",
width: float = 0.8,
dodge: bool | Literal["auto"] = "auto",
@@ -267,6 +268,7 @@ def catplot(
errorbar: _ErrorBar | None = ("ci", 95),
n_boot: int = 1000,
units: ColumnName | _Vector | None = None,
weights: ColumnName | _Vector | None = None,
seed: _Seed | None = None,
order: Iterable[ColumnName] | None = None,
hue_order: Iterable[ColumnName] | None = None,
@@ -281,7 +283,7 @@ def catplot(
orient: Literal["v", "h", "x", "y"] | None = None,
color: ColorType | None = None,
palette: _Palette | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
legend: _Legend = "auto",
legend_out: bool = True,
sharex: bool = True,

View File

@@ -3,11 +3,11 @@ from typing import Any, Literal, Protocol, TypeVar
from typing_extensions import TypeAlias, deprecated
from matplotlib.axes import Axes
from matplotlib.colors import Colormap, Normalize
from matplotlib.colors import Colormap
from matplotlib.typing import ColorType
from numpy.typing import ArrayLike
from ._core.typing import ColumnName, DataSource
from ._core.typing import ColumnName, DataSource, NormSpec
from .axisgrid import FacetGrid
from .external.kde import _BwMethodType
from .utils import _DataSourceWideForm, _LogScale, _Palette, _Vector
@@ -51,7 +51,7 @@ def histplot(
cbar_kws: dict[str, Any] | None = None,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
color: ColorType | None = None,
log_scale: _LogScale | None = None,
legend: bool = True,
@@ -67,7 +67,7 @@ def kdeplot(
weights: ColumnName | _Vector | None = None,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
color: ColorType | None = None,
fill: bool | None = None,
multiple: Literal["layer", "stack", "fill"] = "layer",
@@ -101,7 +101,7 @@ def ecdfplot(
complementary: bool = False,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
log_scale: _LogScale | None = None,
legend: bool = True,
ax: Axes | None = None,
@@ -117,7 +117,7 @@ def rugplot(
expand_margins: bool = True,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
legend: bool = True,
ax: Axes | None = None,
**kwargs: Any,
@@ -138,7 +138,7 @@ def displot(
legend: bool = True,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
hue_norm: NormSpec = None,
color: ColorType | None = None,
col_wrap: int | None = None,
row_order: Iterable[ColumnName] | None = None,

View File

@@ -1,36 +1,42 @@
from _typeshed import Incomplete
from collections.abc import Iterable
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal
from typing_extensions import TypeAlias
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from matplotlib.typing import MarkerType
from ._core.typing import ColumnName, DataSource, NormSpec
from .axisgrid import FacetGrid
from .utils import _ErrorBar, _Estimator, _Legend, _Palette, _Seed
from .utils import _DataSourceWideForm, _ErrorBar, _Estimator, _Legend, _Palette, _Seed, _Vector
__all__ = ["relplot", "scatterplot", "lineplot"]
_Sizes: TypeAlias = list[float] | dict[str, float] | tuple[float, float]
_Sizes: TypeAlias = list[int] | list[float] | dict[str, int] | dict[str, float] | tuple[float, float]
_DashType: TypeAlias = tuple[None, None] | Sequence[float] # See matplotlib.lines.Line2D.set_dashes
# "dashes" and "markers" require dict but we use mapping to avoid long unions because dict is invariant in its value type
_Dashes: TypeAlias = bool | Sequence[_DashType] | Mapping[Any, _DashType]
_Markers: TypeAlias = bool | Sequence[MarkerType] | Mapping[Any, MarkerType]
def lineplot(
data: Incomplete | None = None,
data: DataSource | _DataSourceWideForm | None = None,
*,
x: Incomplete | None = None,
y: Incomplete | None = None,
hue: Incomplete | None = None,
size: Incomplete | None = None,
style: Incomplete | None = None,
units: Incomplete | None = None,
x: ColumnName | _Vector | None = None,
y: ColumnName | _Vector | None = None,
hue: ColumnName | _Vector | None = None,
size: ColumnName | _Vector | None = None,
style: ColumnName | _Vector | None = None,
units: ColumnName | _Vector | None = None,
weights: ColumnName | _Vector | None = None,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[Any] | None = None,
hue_norm: Incomplete | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: NormSpec = None,
sizes: _Sizes | None = None,
size_order: Iterable[Any] | None = None,
size_norm: Incomplete | None = None,
dashes: bool | list[Incomplete] | dict[str, Incomplete] = True,
markers: Incomplete | None = None,
style_order: Iterable[Any] | None = None,
size_order: Iterable[ColumnName] | None = None,
size_norm: NormSpec = None,
dashes: _Dashes | None = True,
markers: _Markers | None = None,
style_order: Iterable[ColumnName] | None = None,
estimator: _Estimator | None = "mean",
errorbar: _ErrorBar | None = ("ci", 95),
n_boot: int = 1000,
@@ -45,48 +51,49 @@ def lineplot(
**kwargs: Any,
) -> Axes: ...
def scatterplot(
data: Incomplete | None = None,
data: DataSource | _DataSourceWideForm | None = None,
*,
x: Incomplete | None = None,
y: Incomplete | None = None,
hue: Incomplete | None = None,
size: Incomplete | None = None,
style: Incomplete | None = None,
x: ColumnName | _Vector | None = None,
y: ColumnName | _Vector | None = None,
hue: ColumnName | _Vector | None = None,
size: ColumnName | _Vector | None = None,
style: ColumnName | _Vector | None = None,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[Any] | None = None,
hue_norm: Incomplete | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: NormSpec = None,
sizes: _Sizes | None = None,
size_order: Iterable[Any] | None = None,
size_norm: Incomplete | None = None,
markers: Incomplete = True,
style_order: Iterable[Any] | None = None,
size_order: Iterable[ColumnName] | None = None,
size_norm: NormSpec = None,
markers: _Markers | None = True,
style_order: Iterable[ColumnName] | None = None,
legend: _Legend = "auto",
ax: Axes | None = None,
**kwargs: Any,
) -> Axes: ...
def relplot(
data: Incomplete | None = None,
data: DataSource | _DataSourceWideForm | None = None,
*,
x: Incomplete | None = None,
y: Incomplete | None = None,
hue: Incomplete | None = None,
size: Incomplete | None = None,
style: Incomplete | None = None,
units: Incomplete | None = None,
row: Incomplete | None = None,
col: Incomplete | None = None,
x: ColumnName | _Vector | None = None,
y: ColumnName | _Vector | None = None,
hue: ColumnName | _Vector | None = None,
size: ColumnName | _Vector | None = None,
style: ColumnName | _Vector | None = None,
units: ColumnName | _Vector | None = None,
weights: ColumnName | _Vector | None = None,
row: ColumnName | _Vector | None = None,
col: ColumnName | _Vector | None = None,
col_wrap: int | None = None,
row_order: Iterable[Any] | None = None,
col_order: Iterable[Any] | None = None,
row_order: Iterable[ColumnName] | None = None,
col_order: Iterable[ColumnName] | None = None,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[Any] | None = None,
hue_norm: Incomplete | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: NormSpec = None,
sizes: _Sizes | None = None,
size_order: Iterable[Any] | None = None,
size_norm: Incomplete | None = None,
markers: Incomplete | None = None,
dashes: Incomplete | None = None,
style_order: Iterable[Any] | None = None,
size_order: Iterable[ColumnName] | None = None,
size_norm: NormSpec = None,
markers: _Markers | None = None,
dashes: _Dashes | None = None,
style_order: Iterable[ColumnName] | None = None,
legend: _Legend = "auto",
kind: Literal["scatter", "line"] = "scatter",
height: float = 5,

View File

@@ -14,7 +14,8 @@ from matplotlib.ticker import Locator
from matplotlib.typing import ColorType
from numpy.typing import ArrayLike, NDArray
from pandas import DataFrame
from seaborn.axisgrid import Grid
from .axisgrid import Grid
__all__ = [
"desaturate",
@@ -34,7 +35,8 @@ _ErrorBar: TypeAlias = str | tuple[str, float] | Callable[[Iterable[float]], tup
_Estimator: TypeAlias = str | Callable[..., Incomplete] # noqa: Y047
_Legend: TypeAlias = Literal["auto", "brief", "full"] | bool # noqa: Y047
_LogScale: TypeAlias = bool | float | tuple[bool | float, bool | float] # noqa: Y047
_Palette: TypeAlias = str | Sequence[ColorType] | dict[Incomplete, ColorType] # noqa: Y047
# `palette` requires dict but we use mapping to avoid a very long union because dict is invariant in its value type
_Palette: TypeAlias = str | Sequence[ColorType] | Mapping[Any, ColorType] # noqa: Y047
_Seed: TypeAlias = int | np.random.Generator | np.random.RandomState # noqa: Y047
_Scalar: TypeAlias = (
# numeric
@@ -60,7 +62,7 @@ _DataSourceWideForm: TypeAlias = ( # noqa: Y047
# Sequence of "convertible to pd.Series" vectors
| Sequence[_Vector]
# A "convertible to pd.DataFrame" table
| Mapping[Any, Mapping[_Scalar, _Scalar]]
| Mapping[Any, Mapping[Any, _Scalar]]
| NDArray[Any]
# Flat "convertible to pd.Series" vector of scalars
| Sequence[_Scalar]