diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 41679520930..b285562eec1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -17,6 +17,9 @@ New Features - Added ``inherit='all_coords'`` option to :py:meth:`DataTree.to_dataset` to inherit all parent coordinates, not just indexed ones (:issue:`10812`, :pull:`11230`). By `Alfonso Ladino `_. +- Support ``col_wrap='auto'`` in plots that will wrap the grid to be as square + as possible (:pull:`11266`). + By `Michael Niklas `_. Breaking Changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py index 9db4ae4e3f7..2b4c28a9027 100644 --- a/xarray/plot/accessor.py +++ b/xarray/plot/accessor.py @@ -192,7 +192,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -232,7 +232,7 @@ def scatter( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -272,7 +272,7 @@ def scatter( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -311,7 +311,7 @@ def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -350,7 +350,7 @@ def imshow( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -389,7 +389,7 @@ def imshow( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -432,7 +432,7 @@ def contour( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -471,7 +471,7 @@ def contour( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -510,7 +510,7 @@ def contour( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -553,7 +553,7 @@ def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -592,7 +592,7 @@ def contourf( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -631,7 +631,7 @@ def contourf( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -674,7 +674,7 @@ def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -713,7 +713,7 @@ def pcolormesh( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -752,7 +752,7 @@ def pcolormesh( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -795,7 +795,7 @@ def surface( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -834,7 +834,7 @@ def surface( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -873,7 +873,7 @@ def surface( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -940,7 +940,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -980,7 +980,7 @@ def scatter( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -1020,7 +1020,7 @@ def scatter( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -1062,7 +1062,7 @@ def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -1098,7 +1098,7 @@ def quiver( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -1134,7 +1134,7 @@ def quiver( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -1174,7 +1174,7 @@ def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -1210,7 +1210,7 @@ def streamplot( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -1246,7 +1246,7 @@ def streamplot( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index e06a0b7187d..9fead9934d6 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -228,7 +228,7 @@ def plot( *, row: Hashable | None = None, col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, ax: Axes | None = None, hue: Hashable | None = None, subplot_kws: dict[str, Any] | None = None, @@ -255,8 +255,10 @@ def plot( If passed, make row faceted plots on this dimension name. col : Hashable or None, optional If passed, make column faceted plots on this dimension name. - col_wrap : int or None, optional - Use together with ``col`` to wrap faceted plots. + col_wrap : int, None or "auto", optional + "Wrap" the grid for the column variable after this number of columns, + adding rows if ``col_wrap`` is less than the number of facets. + If "auto" align the grid to the figsize or keep it as square as possible. ax : matplotlib axes object, optional Axes on which to plot. By default, use the current axes. Mutually exclusive with ``size``, ``figsize`` and facets. @@ -740,8 +742,10 @@ def _plot1d(plotfunc): If passed, make row faceted plots on this dimension name. col : Hashable, optional If passed, make column faceted plots on this dimension name. -col_wrap : int, optional - Use together with ``col`` to wrap faceted plots +col_wrap : int, None or "auto", optional + "Wrap" the grid for the column variable after this number of columns, + adding rows if ``col_wrap`` is less than the number of facets. + If "auto" align the grid to the figsize or keep it as square as possible. ax : matplotlib axes object, optional If None, uses the current axis. Not applicable when using facets. figsize : Iterable[float] or None, optional @@ -849,7 +853,7 @@ def newplotfunc( linewidth: Hashable | None = None, row: Hashable | None = None, col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, @@ -1130,7 +1134,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -1171,7 +1175,7 @@ def scatter( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -1212,7 +1216,7 @@ def scatter( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -1308,8 +1312,10 @@ def _plot2d(plotfunc): If passed, make row faceted plots on this dimension name. col : Hashable or None, optional If passed, make column faceted plots on this dimension name. -col_wrap : int, optional - Use together with ``col`` to wrap faceted plots. +col_wrap : int, None or "auto", optional + "Wrap" the grid for the column variable after this number of columns, + adding rows if ``col_wrap`` is less than the number of facets. + If "auto" align the grid to the figsize or keep it as square as possible. xincrease : None, True, or False, optional Should the values on the *x* axis be increasing from left to right? If ``None``, use the default for the Matplotlib function. @@ -1420,7 +1426,7 @@ def newplotfunc( ax: Axes | None = None, row: Hashable | None = None, col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -1675,7 +1681,7 @@ def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -1715,7 +1721,7 @@ def imshow( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -1755,7 +1761,7 @@ def imshow( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -1892,7 +1898,7 @@ def contour( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -1932,7 +1938,7 @@ def contour( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -1972,7 +1978,7 @@ def contour( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2025,7 +2031,7 @@ def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2065,7 +2071,7 @@ def contourf( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2105,7 +2111,7 @@ def contourf( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2158,7 +2164,7 @@ def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2198,7 +2204,7 @@ def pcolormesh( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2238,7 +2244,7 @@ def pcolormesh( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2342,7 +2348,7 @@ def surface( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2382,7 +2388,7 @@ def surface( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, @@ -2422,7 +2428,7 @@ def surface( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_colorbar: bool | None = None, diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 857e9170508..bc51d1eee80 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -4,7 +4,7 @@ import inspect import warnings from collections.abc import Callable, Hashable, Iterable -from typing import TYPE_CHECKING, Any, TypeVar, overload +from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload from xarray.plot import dataarray_plot from xarray.plot.facetgrid import _easy_facetgrid @@ -65,8 +65,10 @@ def _dsplot(plotfunc): If passed, make row faceted plots on this dimension name. col : Hashable or None, optional If passed, make column faceted plots on this dimension name. -col_wrap : int, optional - Use together with ``col`` to wrap faceted plots. +col_wrap : int, None or "auto", optional + "Wrap" the grid for the column variable after this number of columns, + adding rows if ``col_wrap`` is less than the number of facets. + If "auto" align the grid to the figsize or keep it as square as possible. ax : matplotlib axes object or None, optional If ``None``, use the current axes. Not applicable when using facets. figsize : Iterable[float] or None, optional @@ -169,7 +171,7 @@ def newplotfunc( hue_style: HueStyleOptions = None, row: Hashable | None = None, col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, @@ -336,7 +338,7 @@ def quiver( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -373,7 +375,7 @@ def quiver( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -410,7 +412,7 @@ def quiver( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -487,7 +489,7 @@ def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -524,7 +526,7 @@ def streamplot( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -561,7 +563,7 @@ def streamplot( ax: Axes | None = None, figsize: Iterable[float] | None = None, size: float | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, aspect: AspectOptions = None, @@ -767,7 +769,7 @@ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( ax: Axes | None = None, row: None = None, # no wrap -> primitive col: None = None, # no wrap -> primitive - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -808,7 +810,7 @@ def scatter( ax: Axes | None = None, row: Hashable | None = None, col: Hashable, # wrap -> FacetGrid - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -849,7 +851,7 @@ def scatter( ax: Axes | None = None, row: Hashable, # wrap -> FacetGrid col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, @@ -890,7 +892,7 @@ def scatter( ax: Axes | None = None, row: Hashable | None = None, col: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, xincrease: bool | None = True, yincrease: bool | None = True, add_legend: bool | None = None, diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 5da382c1177..2bdd11a9f59 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -56,6 +56,31 @@ def _nicetitle(coord, value, maxchar, template): return title +def _auto_grid( + nfacet: int, figsize: tuple[float, ...] | None, aspect: float +) -> tuple[int, int]: + + # Try to align the grid to the figsize. If figsize is unknown it gets + # computed from the grid, so lets keep it as square as possible + faspect = 1 if figsize is None else figsize[0] / figsize[1] + + # Only wrap if > 3 images + if nfacet <= 3: + return nfacet, 1 + + # Geometric ideal case + ncol = int(np.ceil(np.sqrt(nfacet * faspect / aspect))) + ncol = max(1, min(ncol, nfacet)) + nrow = int(np.ceil(nfacet / ncol)) + + # Reduce columns as long as we don't need more rows + # This eliminates empty slots in the last row if aspect < 1 + while ncol > 1 and (ncol - 1) * nrow >= nfacet: + ncol -= 1 + + return ncol, nrow + + T_FacetGrid = TypeVar("T_FacetGrid", bound="FacetGrid") @@ -114,7 +139,7 @@ class FacetGrid(Generic[T_DataArrayOrSet]): _row_var: Hashable | None _ncol: int _col_var: Hashable | None - _col_wrap: int | None + _col_wrap: int | Literal["auto"] | None row_labels: list[Annotation | None] col_labels: list[Annotation | None] _x_var: None @@ -129,7 +154,7 @@ def __init__( data: T_DataArrayOrSet, col: Hashable | None = None, row: Hashable | None = None, - col_wrap: int | None = None, + col_wrap: int | Literal["auto"] | None = None, sharex: bool = True, sharey: bool = True, figsize: Iterable[float] | None = None, @@ -142,15 +167,16 @@ def __init__( ---------- data : DataArray or Dataset DataArray or Dataset to be plotted. - row, col : str + row, col : hashable or None, optional Dimension names that define subsets of the data, which will be drawn on separate facets in the grid. - col_wrap : int, optional - "Wrap" the grid the for the column variable after this number of columns, + col_wrap : int, None or "auto", optional + "Wrap" the grid for the column variable after this number of columns, adding rows if ``col_wrap`` is less than the number of facets. - sharex : bool, optional + If "auto" align the grid to the figsize or keep it as square as possible. + sharex : bool, default: True If true, the facets will share *x* axes. - sharey : bool, optional + sharey : bool, default: True If true, the facets will share *y* axes. figsize : Iterable of float or None, optional A tuple (width, height) of the figure in inches. @@ -163,7 +189,6 @@ def __init__( subplot_kws : dict, optional Dictionary of keyword arguments for Matplotlib subplots (:py:func:`matplotlib.pyplot.subplots`). - """ import matplotlib.pyplot as plt @@ -195,18 +220,20 @@ def __init__( else: raise ValueError("Pass a coordinate name as an argument for row or col") + # exhaust generators + figsize = None if figsize is None else tuple(figsize) + # Compute grid shape if single_group: nfacet = len(data[single_group]) - if col: - # idea - could add heuristic for nice shapes like 3x4 - ncol = nfacet - if row: - ncol = 1 - if col_wrap is not None: - # Overrides previous settings + if col_wrap == "auto": + ncol, nrow = _auto_grid(nfacet, figsize, aspect) + elif col_wrap is None: + ncol = nfacet if col else 1 + nrow = int(np.ceil(nfacet / ncol)) + else: ncol = col_wrap - nrow = int(np.ceil(nfacet / ncol)) + nrow = int(np.ceil(nfacet / ncol)) # Set the subplot kwargs subplot_kws = {} if subplot_kws is None else subplot_kws diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 5980d449dbb..cf8d0e2999e 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1668,6 +1668,34 @@ def test_convenient_facetgrid_4d(self) -> None: for ax in g.axs.flat: assert ax.has_data() + @pytest.mark.parametrize( + ["n", "figsize", "aspect", "expected_shape"], + [ + pytest.param(1, None, 1, [1, 1], id="1"), + pytest.param(3, None, 1, [1, 3], id="3"), # <4 should not be wrapped + pytest.param(6, None, 1, [2, 3], id="6"), + pytest.param(8, None, 1, [3, 3], id="8"), + pytest.param(8, [10, 5], 1, [2, 4], id="8-figaspect=2"), + pytest.param(8, [5, 10], 1, [4, 2], id="8-figaspect=0.5"), + pytest.param(8, None, 4, [4, 2], id="8-aspect=4"), + pytest.param(8, None, 0.25, [2, 4], id="8-aspect=0.25"), + ], + ) + def test_facetgrid_col_wrap_auto( + self, + n: int, + figsize: None | tuple[int, int], + aspect: int, + expected_shape: tuple[int, int], + ) -> None: + a = easy_array((10, 15, n)) + d = DataArray(a, dims=["y", "x", "z"]) + g = self.plotfunc( + d, x="x", y="y", col="z", col_wrap="auto", figsize=figsize, aspect=aspect + ) + + assert_array_equal(g.axs.shape, expected_shape) + @pytest.mark.filterwarnings("ignore:This figure includes") def test_facetgrid_map_only_appends_mappables(self) -> None: a = easy_array((10, 15, 2, 3))