Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions shapash/explainer/smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1876,6 +1876,8 @@ def confusion_matrix_plot(
self,
width: int = 700,
height: int = 500,
quantile: float = 0.95,
use_quantile_scale: bool = True,
file_name=None,
auto_open=False,
):
Expand All @@ -1895,6 +1897,10 @@ def confusion_matrix_plot(
The width of the generated figure, in inches.
height : int, optional, default=4
The height of the generated figure, in inches.
quantile : float, optional, default=0.95
Upper quantile used to cap the heatmap color scale.
use_quantile_scale : bool, optional, default=True
Enable quantile-based color scaling for better contrast on imbalanced confusion matrices.

Returns
-------
Expand All @@ -1918,6 +1924,8 @@ def confusion_matrix_plot(
colors_dict=self._style_dict,
width=width,
height=height,
quantile=quantile,
use_quantile_scale=use_quantile_scale,
file_name=file_name,
auto_open=auto_open,
)
Expand Down
21 changes: 20 additions & 1 deletion shapash/plots/plot_evaluation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,8 @@ def plot_confusion_matrix(
width: int = 700,
height: int = 500,
palette_name: str = "default",
quantile: float = 0.95,
use_quantile_scale: bool = True,
file_name=None,
auto_open=False,
) -> go.Figure:
Expand All @@ -523,6 +525,12 @@ def plot_confusion_matrix(
The height of the figure in pixels.
palette_name : str, optional
The color palette to use for the heatmap.
quantile : float, optional
Upper quantile used to cap the heatmap color scale when `use_quantile_scale=True`.
Values above this threshold keep their displayed count but are saturated for color rendering.
use_quantile_scale : bool, optional
If True: use a quantile-based upper bound for the heatmap colors to improve contrast on
imbalanced confusion matrices. If False: use the full value range.
file_name: string, optional
Specify the save path of html files. If None, no file will be saved.
auto_open: bool, optional
Expand Down Expand Up @@ -557,6 +565,14 @@ def plot_confusion_matrix(
y_labels = list(df_cm.index)
z = df_cm.loc[x_labels, y_labels].values

if use_quantile_scale:
if not 0 < quantile <= 1:
raise ValueError("quantile must be in the interval ]0, 1].")
# Clip extreme values for a more readable heatmap.
zmax = int(np.quantile(z, quantile))
else:
zmax = int(z.max())

title = "Confusion Matrix"
dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = style_dict["dict_xaxis"] | {"text": se_y_pred.name}
Expand Down Expand Up @@ -595,6 +611,8 @@ def plot_confusion_matrix(
x=x_labels,
y=y_labels,
colorscale=col_scale,
zmin=0,
zmax=zmax,
hovertext=hv_text,
hovertemplate="%{hovertext}<extra></extra>",
showscale=True,
Expand All @@ -612,7 +630,8 @@ def plot_confusion_matrix(
y=y_label,
text=str(z[i][j]),
showarrow=False,
font=dict(color="black" if z[i][j] < z.max() / 2 else "white"),
# Use clipped values to keep the text contrast readable.
font=dict(color="black" if min(z[i][j], zmax) < zmax / 2 else "white"),
)
)

Expand Down
25 changes: 25 additions & 0 deletions tests/unit_tests/report/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
plot_categorical_distribution,
plot_continuous_distribution,
)
from shapash.plots.plot_evaluation_metrics import plot_confusion_matrix

from plotly import graph_objects as go

Expand Down Expand Up @@ -227,3 +228,27 @@ def test_plot_categorical_distribution_6(self):
assert fig.data[0].type == "bar"
assert fig.data[1].type == "bar"
assert len(fig.data[0]['x']) == 8

def test_plot_confusion_matrix_with_quantile_scale(self):
y_true = ["c01"] * 2 + ["c02"] * 4 + ["c03"] * 103
y_pred = ["c01"] * 2 + ["c01"] + ["c02"] * 3 + ["c01"] + ["c03"] * 102

fig = plot_confusion_matrix(y_true=y_true, y_pred=y_pred, quantile=0.95, use_quantile_scale=True)

expected_matrix = np.array([[2, 0, 0], [1, 3, 0], [1, 0, 102]])
assert np.array_equal(np.array(fig.data[0].z), expected_matrix)
assert fig.data[0].zmin == 0
assert fig.data[0].zmax == int(np.quantile(expected_matrix, 0.95))
assert fig.layout.annotations[8].text == "102"

def test_plot_confusion_matrix_without_quantile_scale(self):
y_true = ["c01"] * 2 + ["c02"] * 4 + ["c03"] * 103
y_pred = ["c01"] * 2 + ["c01"] + ["c02"] * 3 + ["c01"] + ["c03"] * 102

fig = plot_confusion_matrix(y_true=y_true, y_pred=y_pred, use_quantile_scale=False)

expected_matrix = np.array([[2, 0, 0], [1, 3, 0], [1, 0, 102]])
assert np.array_equal(np.array(fig.data[0].z), expected_matrix)
assert fig.data[0].zmin == 0
assert fig.data[0].zmax == 102
assert fig.layout.annotations[8].text == "102"