seaborn: complete and fix axisgrid module (#11096)

This commit is contained in:
Ali Hamdan
2023-12-04 12:30:26 +01:00
committed by GitHub
parent fe48f37bf7
commit 0f7241844e
4 changed files with 244 additions and 49 deletions

View File

@@ -3,3 +3,5 @@ seaborn.external.docscrape.ClassDoc.__init__ # stubtest doesn't like ABC class
seaborn.external.docscrape.NumpyDocString.__str__ # weird signature
seaborn(\.regression)?\.lmplot # the `data` argument is required but it defaults to `None` at runtime
seaborn.axisgrid.Grid.tight_layout # the method doesn't really take pos args but runtime has *args

View File

@@ -1,18 +1,23 @@
from _typeshed import Incomplete
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from datetime import date, datetime, timedelta
from typing import Any
from typing import Any, Protocol
from typing_extensions import TypeAlias
from matplotlib.colors import Colormap, Normalize
from numpy import ndarray
from pandas import Index, Series, Timedelta, Timestamp
from pandas import DataFrame, Index, Series, Timedelta, Timestamp
class _SupportsDataFrame(Protocol):
# `__dataframe__` should return pandas.core.interchange.dataframe_protocol.DataFrame
# but this class needs to be defined as a Protocol, not as an ABC.
def __dataframe__(self, nan_as_null: bool = ..., allow_copy: bool = ...) -> Incomplete: ...
ColumnName: TypeAlias = str | bytes | date | datetime | timedelta | bool | complex | Timestamp | Timedelta
Vector: TypeAlias = Series[Any] | Index[Any] | ndarray[Any, Any]
VariableSpec: TypeAlias = ColumnName | Vector | None
VariableSpecList: TypeAlias = list[VariableSpec] | Index[Any] | None
DataSource: TypeAlias = Incomplete
DataSource: TypeAlias = DataFrame | _SupportsDataFrame | Mapping[ColumnName, Incomplete] | None
OrderSpec: TypeAlias = Iterable[str] | None
NormSpec: TypeAlias = tuple[float | None, float | None] | Normalize | None
PaletteSpec: TypeAlias = str | list[Incomplete] | dict[Incomplete, Incomplete] | Colormap | None

View File

@@ -1,28 +1,96 @@
import os
from _typeshed import Incomplete
from collections.abc import Callable, Generator, Iterable, Mapping
from typing import Any, TypeVar
from typing_extensions import Concatenate, Literal, ParamSpec, Self
from typing import IO, Any, TypeVar
from typing_extensions import Concatenate, Literal, ParamSpec, Self, TypeAlias
import numpy as np
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from matplotlib.backend_bases import MouseEvent, RendererBase
from matplotlib.colors import Colormap, Normalize
from matplotlib.figure import Figure
from matplotlib.font_manager import FontProperties
from matplotlib.gridspec import SubplotSpec
from matplotlib.legend import Legend
from matplotlib.patches import Patch
from matplotlib.path import Path as mpl_Path
from matplotlib.patheffects import AbstractPathEffect
from matplotlib.scale import ScaleBase
from matplotlib.text import Text
from matplotlib.typing import ColorType
from numpy.typing import NDArray
from matplotlib.transforms import Bbox, BboxBase, Transform, TransformedPath
from matplotlib.typing import ColorType, LineStyleType, MarkerType
from numpy.typing import ArrayLike, NDArray
from pandas import DataFrame, Series
from ._core.typing import ColumnName, DataSource, _SupportsDataFrame
from .palettes import _RGBColorPalette
from .utils import _Palette
from .utils import _DataSourceWideForm, _Palette, _Vector
__all__ = ["FacetGrid", "PairGrid", "JointGrid", "pairplot", "jointplot"]
_P = ParamSpec("_P")
_R = TypeVar("_R")
_LiteralFont: TypeAlias = Literal["xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large"]
class _BaseGrid:
def set(self, **kwargs: Incomplete) -> Self: ... # **kwargs are passed to `matplotlib.axes.Axes.set`
def set(
self,
*,
# Keywords follow `matplotlib.axes.Axes.set`. Each keyword <KW> corresponds to a `set_<KW>` method
adjustable: Literal["box", "datalim"] = ...,
agg_filter: Callable[[ArrayLike, float], tuple[NDArray[np.floating[Any]], float, float]] | None = ...,
alpha: float | None = ...,
anchor: str | tuple[float, float] = ...,
animated: bool = ...,
aspect: float | Literal["auto", "equal"] = ...,
autoscale_on: bool = ...,
autoscalex_on: bool = ...,
autoscaley_on: bool = ...,
axes_locator: Callable[[Axes, RendererBase], Bbox] = ...,
axisbelow: bool | Literal["line"] = ...,
box_aspect: float | None = ...,
clip_box: BboxBase | None = ...,
clip_on: bool = ...,
clip_path: Patch | mpl_Path | TransformedPath | None = ...,
facecolor: ColorType | None = ...,
frame_on: bool = ...,
gid: str | None = ...,
in_layout: bool = ...,
label: object = ...,
mouseover: bool = ...,
navigate: bool = ...,
path_effects: list[AbstractPathEffect] = ...,
picker: bool | float | Callable[[Artist, MouseEvent], tuple[bool, dict[Any, Any]]] | None = ...,
position: Bbox | tuple[float, float, float, float] = ...,
prop_cycle: Incomplete = ..., # TODO: use cycler.Cycler when cycler gets typed
rasterization_zorder: float | None = ...,
rasterized: bool = ...,
sketch_params: float | None = ...,
snap: bool | None = ...,
subplotspec: SubplotSpec = ...,
title: str = ...,
transform: Transform | None = ...,
url: str | None = ...,
visible: bool = ...,
xbound: float | None | tuple[float | None, float | None] = ...,
xlabel: str = ...,
xlim: float | None | tuple[float | None, float | None] = ...,
xmargin: float = ...,
xscale: str | ScaleBase = ...,
xticklabels: Iterable[str | Text] = ...,
xticks: ArrayLike = ...,
ybound: float | None | tuple[float | None, float | None] = ...,
ylabel: str = ...,
ylim: float | None | tuple[float | None, float | None] = ...,
ymargin: float = ...,
yscale: str | ScaleBase = ...,
yticklabels: Iterable[str | Text] = ...,
yticks: ArrayLike = ...,
zorder: float = ...,
**kwargs: Any,
) -> Self: ...
@property
def fig(self) -> Figure: ...
@property
@@ -30,27 +98,110 @@ class _BaseGrid:
def apply(self, func: Callable[Concatenate[Self, _P], object], *args: _P.args, **kwargs: _P.kwargs) -> Self: ...
def pipe(self, func: Callable[Concatenate[Self, _P], _R], *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
def savefig(
self, *args: Incomplete, **kwargs: Incomplete
) -> None: ... # *args and **kwargs are passed to `matplotlib.figure.Figure.savefig`
self,
# Signature follows `matplotlib.figure.Figure.savefig`
fname: str | os.PathLike[Any] | IO[Any],
*,
transparent: bool | None = None,
dpi: float | Literal["figure"] | None = 96,
facecolor: ColorType | Literal["auto"] | None = "auto",
edgecolor: ColorType | Literal["auto"] | None = "auto",
orientation: Literal["landscape", "portrait"] = "portrait",
format: str | None = None,
bbox_inches: Literal["tight"] | Bbox | None = "tight",
pad_inches: float | Literal["layout"] | None = None,
backend: str | None = None,
**kwargs: Any,
) -> None: ...
class Grid(_BaseGrid):
def __init__(self) -> None: ...
def tight_layout(
self, *args: Incomplete, **kwargs: Incomplete
) -> Self: ... # *args and **kwargs are passed to `matplotlib.figure.Figure.tight_layout`
self,
*,
# Keywords follow `matplotlib.figure.Figure.tight_layout`
pad: float = 1.08,
h_pad: float | None = None,
w_pad: float | None = None,
rect: tuple[float, float, float, float] | None = None,
) -> Self: ...
def add_legend(
self,
legend_data: Mapping[Any, Artist] | None = None, # cannot use precise key type because of invariant Mapping keys
# Cannot use precise key type with union for legend_data because of invariant Mapping keys
legend_data: Mapping[Any, Artist] | None = None,
title: str | None = None,
label_order: list[str] | None = None,
adjust_subtitles: bool = False,
**kwargs: Incomplete, # **kwargs are passed to `matplotlib.figure.Figure.legend`
*,
# Keywords follow `matplotlib.legend.Legend`
loc: str | int | tuple[float, float] | None = None,
numpoints: int | None = None,
markerscale: float | None = None,
markerfirst: bool = True,
reverse: bool = False,
scatterpoints: int | None = None,
scatteryoffsets: Iterable[float] | None = None,
prop: FontProperties | dict[str, Any] | None = None,
fontsize: int | _LiteralFont | None = None,
labelcolor: str | Iterable[str] | None = None,
borderpad: float | None = None,
labelspacing: float | None = None,
handlelength: float | None = None,
handleheight: float | None = None,
handletextpad: float | None = None,
borderaxespad: float | None = None,
columnspacing: float | None = None,
ncols: int = 1,
mode: Literal["expand"] | None = None,
fancybox: bool | None = None,
shadow: bool | dict[str, float] | None = None,
title_fontsize: int | _LiteralFont | None = None,
framealpha: float | None = None,
edgecolor: ColorType | None = None,
facecolor: ColorType | None = None,
bbox_to_anchor: BboxBase | tuple[float, float] | tuple[float, float, float, float] | None = None,
bbox_transform: Transform | None = None,
frameon: bool | None = None,
handler_map: None = None,
title_fontproperties: FontProperties | None = None,
alignment: Literal["center", "left", "right"] = "center",
ncol: int = 1,
draggable: bool = False,
) -> Self: ...
@property
def legend(self) -> Legend | None: ...
def tick_params(
self, axis: Literal["x", "y", "both"] = "both", **kwargs: Incomplete
) -> Self: ... # **kwargs are passed to `matplotlib.axes.Axes.tick_params`
self,
axis: Literal["x", "y", "both"] = "both",
*,
# Keywords follow `matplotlib.axes.Axes.tick_params`
which: Literal["major", "minor", "both"] = "major",
reset: bool = False,
direction: Literal["in", "out", "inout"] = ...,
length: float = ...,
width: float = ...,
color: ColorType = ...,
pad: float = ...,
labelsize: float | str = ...,
labelcolor: ColorType = ...,
labelfontfamily: str = ...,
colors: ColorType = ...,
zorder: float = ...,
bottom: bool = ...,
top: bool = ...,
left: bool = ...,
right: bool = ...,
labelbottom: bool = ...,
labeltop: bool = ...,
labelleft: bool = ...,
labelright: bool = ...,
labelrotation: float = ...,
grid_color: ColorType = ...,
grid_alpha: float = ...,
grid_linewidth: float = ...,
grid_linestyle: str = ...,
**kwargs: Any,
) -> Self: ...
class FacetGrid(Grid):
data: DataFrame
@@ -60,7 +211,7 @@ class FacetGrid(Grid):
hue_kws: dict[str, Any]
def __init__(
self,
data: DataFrame,
data: DataFrame | _SupportsDataFrame,
*,
row: str | None = None,
col: str | None = None,
@@ -88,10 +239,10 @@ class FacetGrid(Grid):
def map(self, func: Callable[..., object], *args: str, **kwargs: Any) -> Self: ...
def map_dataframe(self, func: Callable[..., object], *args: str, **kwargs: Any) -> Self: ...
def facet_axis(self, row_i: int, col_j: int, modify_state: bool = True) -> Axes: ...
# `despine` should be kept roughly in line with `seaborn.utils.despine`
def despine(
self,
*,
fig: Figure | None = None,
ax: Axes | None = None,
top: bool = True,
right: bool = True,
@@ -111,7 +262,13 @@ class FacetGrid(Grid):
self, template: str | None = None, row_template: str | None = None, col_template: str | None = None, **kwargs: Any
) -> Self: ...
def refline(
self, *, x: float | None = None, y: float | None = None, color: ColorType = ".5", linestyle: str = "--", **line_kws: Any
self,
*,
x: float | None = None,
y: float | None = None,
color: ColorType = ".5",
linestyle: LineStyleType = "--",
**line_kws: Any,
) -> Self: ...
@property
def axes(self) -> NDArray[Incomplete]: ... # array of `Axes`
@@ -127,15 +284,15 @@ class PairGrid(Grid):
axes: NDArray[Incomplete] # two-dimensional array of `Axes`
data: DataFrame
diag_sharey: bool
diag_vars: NDArray[Incomplete] | None # array of `str`
diag_axes: NDArray[Incomplete] | None # array of `Axes`
diag_vars: list[str] | None
diag_axes: list[Axes] | None
hue_names: list[str]
hue_vals: Series[Incomplete]
hue_vals: Series[Any]
hue_kws: dict[str, Any]
palette: _RGBColorPalette
def __init__(
self,
data: DataFrame,
data: DataFrame | _SupportsDataFrame,
*,
hue: str | None = None,
vars: Iterable[str] | None = None,
@@ -162,25 +319,25 @@ class JointGrid(_BaseGrid):
ax_joint: Axes
ax_marg_x: Axes
ax_marg_y: Axes
x: Series[Incomplete]
y: Series[Incomplete]
hue: Series[Incomplete]
x: Series[Any]
y: Series[Any]
hue: Series[Any]
def __init__(
self,
data: Incomplete | None = None,
data: DataSource | _DataSourceWideForm | None = None,
*,
x: Incomplete | None = None,
y: Incomplete | None = None,
hue: Incomplete | None = None,
x: ColumnName | _Vector | None = None,
y: ColumnName | _Vector | None = None,
hue: ColumnName | _Vector | None = None,
height: float = 6,
ratio: float = 5,
space: float = 0.2,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[str] | None = None,
hue_norm: Incomplete | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
dropna: bool = False,
xlim: Incomplete | None = None,
ylim: Incomplete | None = None,
xlim: float | tuple[float, float] | None = None,
ylim: float | tuple[float, float] | None = None,
marginal_ticks: bool = False,
) -> None: ...
def plot(self, joint_func: Callable[..., object], marginal_func: Callable[..., object], **kwargs: Any) -> Self: ...
@@ -194,7 +351,7 @@ class JointGrid(_BaseGrid):
joint: bool = True,
marginal: bool = True,
color: ColorType = ".5",
linestyle: str = "--",
linestyle: LineStyleType = "--",
**line_kws: Any,
) -> Self: ...
def set_axis_labels(self, xlabel: str = "", ylabel: str = "", **kwargs: Any) -> Self: ...
@@ -210,7 +367,7 @@ def pairplot(
y_vars: Iterable[str] | str | None = None,
kind: Literal["scatter", "kde", "hist", "reg"] = "scatter",
diag_kind: Literal["auto", "hist", "kde"] | None = "auto",
markers: Incomplete | None = None,
markers: MarkerType | list[MarkerType] | None = None,
height: float = 2.5,
aspect: float = 1,
corner: bool = False,
@@ -221,22 +378,22 @@ def pairplot(
size: float | None = None, # deprecated
) -> PairGrid: ...
def jointplot(
data: Incomplete | None = None,
data: DataSource | _DataSourceWideForm | None = None,
*,
x: Incomplete | None = None,
y: Incomplete | None = None,
hue: Incomplete | None = None,
kind: str = "scatter", # ideally Literal["scatter", "kde", "hist", "hex", "reg", "resid"] but it is checked with startswith
x: ColumnName | _Vector | None = None,
y: ColumnName | _Vector | None = None,
hue: ColumnName | _Vector | None = None,
kind: Literal["scatter", "kde", "hist", "hex", "reg", "resid"] = "scatter",
height: float = 6,
ratio: float = 5,
space: float = 0.2,
dropna: bool = False,
xlim: Incomplete | None = None,
ylim: Incomplete | None = None,
xlim: float | tuple[float, float] | None = None,
ylim: float | tuple[float, float] | None = None,
color: ColorType | None = None,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[str] | None = None,
hue_norm: Incomplete | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
marginal_ticks: bool = False,
joint_kws: dict[str, Any] | None = None,
marginal_kws: dict[str, Any] | None = None,

View File

@@ -1,9 +1,11 @@
import datetime as dt
from _typeshed import Incomplete, SupportsGetItem
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Any, TypeVar, overload
from typing_extensions import Literal, SupportsIndex, TypeAlias
import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.legend import Legend
@@ -34,6 +36,35 @@ _Legend: TypeAlias = Literal["auto", "brief", "full"] | bool # noqa: Y047
_LogScale: TypeAlias = bool | float | tuple[bool | float, bool | float] # noqa: Y047
_Palette: TypeAlias = str | Sequence[ColorType] | dict[Incomplete, ColorType] # noqa: Y047
_Seed: TypeAlias = int | np.random.Generator | np.random.RandomState # noqa: Y047
_Scalar: TypeAlias = (
# numeric
float
| complex
| np.number[Any]
# categorical
| bool
| str
| bytes
| None
# dates
| dt.date
| dt.datetime
| dt.timedelta
| pd.Timestamp
| pd.Timedelta
)
_Vector: TypeAlias = Iterable[_Scalar]
_DataSourceWideForm: TypeAlias = ( # noqa: Y047
# Mapping of keys to "convertible to pd.Series" vectors
Mapping[Any, _Vector]
# Sequence of "convertible to pd.Series" vectors
| Sequence[_Vector]
# A "convertible to pd.DataFrame" table
| Mapping[Any, Mapping[_Scalar, _Scalar]]
| NDArray[Any]
# Flat "convertible to pd.Series" vector of scalars
| Sequence[_Scalar]
)
DATASET_SOURCE: str
DATASET_NAMES_URL: str
@@ -48,7 +79,7 @@ def axlabel(xlabel: str, ylabel: str, **kwargs: Any) -> None: ... # deprecated
def remove_na(vector: _VectorT) -> _VectorT: ...
def get_color_cycle() -> list[str]: ...
# Please modify `seaborn.axisgrid.FacetGrid.despine` when modifying despine here.
# `despine` should be kept roughly in line with `seaborn.axisgrid.FacetGrid.despine`
def despine(
fig: Figure | None = None,
ax: Axes | None = None,