diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index e06a0b7187d..9200557a757 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -515,7 +515,43 @@ def line( _ensure_plottable(xplt_val, yplt_val) - primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) + # When plotting 2D data with hue, group lines by hue value so that + # lines sharing the same hue are drawn with the same color. Without + # this, ax.plot draws each column as a separate line with a new color + # from the cycle, causing duplicate hue values to get different colors. + # The 2D array can be xplt_val or yplt_val depending on orientation. + _has_2d_data = xplt_val.ndim == 2 or yplt_val.ndim == 2 + if hueplt is not None and _has_2d_data: + import matplotlib as mpl + + hue_values = hueplt.to_numpy() + # Unique values in order of first appearance + unique_hues = list(dict.fromkeys(hue_values)) + + # Only assign per-hue colors if the caller didn't set one explicitly + user_set_color = "color" in kwargs or "c" in kwargs + if not user_set_color: + prop_cycle = mpl.rcParams["axes.prop_cycle"] + cycle_colors = [p["color"] for p in prop_cycle] + + primitive = [] + legend_handles = [] + for idx, hue_val in enumerate(unique_hues): + plot_kwargs = dict(kwargs) + if not user_set_color: + plot_kwargs["color"] = cycle_colors[idx % len(cycle_colors)] + # Find all columns belonging to this hue group + col_indices = [i for i, h in enumerate(hue_values) if h == hue_val] + for col_idx in col_indices: + x_vals = xplt_val[:, col_idx] if xplt_val.ndim == 2 else xplt_val + y_vals = yplt_val[:, col_idx] if yplt_val.ndim == 2 else yplt_val + lines = ax.plot(x_vals, y_vals, *args, **plot_kwargs) + primitive.extend(lines) + # Keep only the first line per group for the legend + if col_idx == col_indices[0]: + legend_handles.extend(lines) + else: + primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) if _labels: if xlabel is not None: @@ -528,7 +564,13 @@ def line( if darray.ndim == 2 and add_legend: assert hueplt is not None - ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) + hue_values = hueplt.to_numpy() + unique_hues = list(dict.fromkeys(hue_values)) + ax.legend( + handles=legend_handles, + labels=[str(h) for h in unique_hues], + title=hue_label, + ) if isinstance(xplt.dtype, np.dtype) and np.issubdtype(xplt.dtype, np.datetime64): # type: ignore[redundant-expr] _set_concise_date(ax, axis="x")