diff --git a/shapash/explainer/smart_plotter.py b/shapash/explainer/smart_plotter.py index 69d555af..d68fdad8 100644 --- a/shapash/explainer/smart_plotter.py +++ b/shapash/explainer/smart_plotter.py @@ -1876,6 +1876,8 @@ def confusion_matrix_plot( self, width: int = 700, height: int = 500, + quantile: float = 0.95, + use_quantile_scale: bool = True, file_name=None, auto_open=False, ): @@ -1895,6 +1897,10 @@ def confusion_matrix_plot( The width of the generated figure, in inches. height : int, optional, default=4 The height of the generated figure, in inches. + quantile : float, optional, default=0.95 + Upper quantile used to cap the heatmap color scale. + use_quantile_scale : bool, optional, default=True + Enable quantile-based color scaling for better contrast on imbalanced confusion matrices. Returns ------- @@ -1918,6 +1924,8 @@ def confusion_matrix_plot( colors_dict=self._style_dict, width=width, height=height, + quantile=quantile, + use_quantile_scale=use_quantile_scale, file_name=file_name, auto_open=auto_open, ) diff --git a/shapash/plots/plot_evaluation_metrics.py b/shapash/plots/plot_evaluation_metrics.py index b0b62343..39746143 100644 --- a/shapash/plots/plot_evaluation_metrics.py +++ b/shapash/plots/plot_evaluation_metrics.py @@ -503,6 +503,8 @@ def plot_confusion_matrix( width: int = 700, height: int = 500, palette_name: str = "default", + quantile: float = 0.95, + use_quantile_scale: bool = True, file_name=None, auto_open=False, ) -> go.Figure: @@ -523,6 +525,12 @@ def plot_confusion_matrix( The height of the figure in pixels. palette_name : str, optional The color palette to use for the heatmap. + quantile : float, optional + Upper quantile used to cap the heatmap color scale when `use_quantile_scale=True`. + Values above this threshold keep their displayed count but are saturated for color rendering. 
+ use_quantile_scale : bool, optional + If True: use a quantile-based upper bound for the heatmap colors to improve contrast on + imbalanced confusion matrices. If False: use the full value range. file_name: string, optional Specify the save path of html files. If None, no file will be saved. auto_open: bool, optional @@ -557,6 +565,14 @@ def plot_confusion_matrix( y_labels = list(df_cm.index) z = df_cm.loc[x_labels, y_labels].values + if use_quantile_scale: + if not 0 < quantile <= 1: + raise ValueError("quantile must be in the interval ]0, 1].") + # Cap the color scale at the requested quantile; the counts in z are not modified. + zmax = int(np.quantile(z, quantile)) + else: + zmax = int(z.max()) + title = "Confusion Matrix" dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)} dict_xaxis = style_dict["dict_xaxis"] | {"text": se_y_pred.name} @@ -595,6 +611,8 @@ def plot_confusion_matrix( x=x_labels, y=y_labels, colorscale=col_scale, + zmin=0, + zmax=zmax, hovertext=hv_text, hovertemplate="%{hovertext}", showscale=True, @@ -612,7 +630,8 @@ def plot_confusion_matrix( y=y_label, text=str(z[i][j]), showarrow=False, - font=dict(color="black" if z[i][j] < z.max() / 2 else "white"), + # Compare the count, capped at zmax, with half the color range so the text contrasts with the cell color. 
+ font=dict(color="black" if min(z[i][j], zmax) < zmax / 2 else "white"), ) ) diff --git a/tests/unit_tests/report/test_plots.py b/tests/unit_tests/report/test_plots.py index 213579ed..9526deb4 100644 --- a/tests/unit_tests/report/test_plots.py +++ b/tests/unit_tests/report/test_plots.py @@ -9,6 +9,7 @@ plot_categorical_distribution, plot_continuous_distribution, ) +from shapash.plots.plot_evaluation_metrics import plot_confusion_matrix from plotly import graph_objects as go @@ -227,3 +228,27 @@ def test_plot_categorical_distribution_6(self): assert fig.data[0].type == "bar" assert fig.data[1].type == "bar" assert len(fig.data[0]['x']) == 8 + + def test_plot_confusion_matrix_with_quantile_scale(self): + y_true = ["c01"] * 2 + ["c02"] * 4 + ["c03"] * 103 + y_pred = ["c01"] * 2 + ["c01"] + ["c02"] * 3 + ["c01"] + ["c03"] * 102 + + fig = plot_confusion_matrix(y_true=y_true, y_pred=y_pred, quantile=0.95, use_quantile_scale=True) + + expected_matrix = np.array([[2, 0, 0], [1, 3, 0], [1, 0, 102]]) + assert np.array_equal(np.array(fig.data[0].z), expected_matrix) + assert fig.data[0].zmin == 0 + assert fig.data[0].zmax == int(np.quantile(expected_matrix, 0.95)) + assert fig.layout.annotations[8].text == "102" + + def test_plot_confusion_matrix_without_quantile_scale(self): + y_true = ["c01"] * 2 + ["c02"] * 4 + ["c03"] * 103 + y_pred = ["c01"] * 2 + ["c01"] + ["c02"] * 3 + ["c01"] + ["c03"] * 102 + + fig = plot_confusion_matrix(y_true=y_true, y_pred=y_pred, use_quantile_scale=False) + + expected_matrix = np.array([[2, 0, 0], [1, 3, 0], [1, 0, 102]]) + assert np.array_equal(np.array(fig.data[0].z), expected_matrix) + assert fig.data[0].zmin == 0 + assert fig.data[0].zmax == 102 + assert fig.layout.annotations[8].text == "102"