Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions xarray/plot/dataarray_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,43 @@ def line(

_ensure_plottable(xplt_val, yplt_val)

primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs)
# When plotting 2D data with hue, group lines by hue value so that
# lines sharing the same hue are drawn with the same color. Without
# this, ax.plot draws each column as a separate line with a new color
# from the cycle, causing duplicate hue values to get different colors.
# The 2D array can be xplt_val or yplt_val depending on orientation.
_has_2d_data = xplt_val.ndim == 2 or yplt_val.ndim == 2
if hueplt is not None and _has_2d_data:
import matplotlib as mpl

hue_values = hueplt.to_numpy()
# Unique values in order of first appearance
unique_hues = list(dict.fromkeys(hue_values))

# Only assign per-hue colors if the caller didn't set one explicitly
user_set_color = "color" in kwargs or "c" in kwargs
if not user_set_color:
prop_cycle = mpl.rcParams["axes.prop_cycle"]
cycle_colors = [p["color"] for p in prop_cycle]

primitive = []
legend_handles = []
for idx, hue_val in enumerate(unique_hues):
plot_kwargs = dict(kwargs)
if not user_set_color:
plot_kwargs["color"] = cycle_colors[idx % len(cycle_colors)]
# Find all columns belonging to this hue group
col_indices = [i for i, h in enumerate(hue_values) if h == hue_val]
for col_idx in col_indices:
x_vals = xplt_val[:, col_idx] if xplt_val.ndim == 2 else xplt_val
y_vals = yplt_val[:, col_idx] if yplt_val.ndim == 2 else yplt_val
lines = ax.plot(x_vals, y_vals, *args, **plot_kwargs)
primitive.extend(lines)
# Keep only the first line per group for the legend
if col_idx == col_indices[0]:
legend_handles.extend(lines)
else:
primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs)

if _labels:
if xlabel is not None:
Expand All @@ -528,7 +564,13 @@ def line(

if darray.ndim == 2 and add_legend:
assert hueplt is not None
ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)
hue_values = hueplt.to_numpy()
unique_hues = list(dict.fromkeys(hue_values))
ax.legend(
handles=legend_handles,
labels=[str(h) for h in unique_hues],
title=hue_label,
)

if isinstance(xplt.dtype, np.dtype) and np.issubdtype(xplt.dtype, np.datetime64): # type: ignore[redundant-expr]
_set_concise_date(ax, axis="x")
Expand Down
Loading