diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 0ce4f7e..fbf599e 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -8,6 +8,8 @@ - Improve compatibility with AutoGluon/custom wrappers by coercing pandas `DataFrame` outputs from `predict_proba`/`predict` to numpy arrays before indexing in classifier/regression helper paths. - Harden one-vs-all scorer handling so `make_one_vs_all_scorer` also accepts classifiers whose `predict_proba` returns a pandas `DataFrame`. - Fix ExplainerHub `add_dashboard_route` after first request by allowing dynamic dashboard registration/setup during route-triggered add, and add a regression test for issue #269. +- Fix issue #273: avoid crashes when sorting categorical values containing mixed types and NaNs across explainer setup, PDP/category ordering, and categorical plot paths. +- Add regression tests for mixed-type categorical sorting and mixed target-label sorting fallbacks in explainers and plots. ## Version 0.5.6: diff --git a/TODO.md b/TODO.md index e5ac218..93e0b18 100644 --- a/TODO.md +++ b/TODO.md @@ -6,10 +6,9 @@ - Rules: link an issue when possible; include size S/M/L; mark blockers. **Now** -- [M][Explainers][#273] categorical columns with NaNs: sorting and column preservation. +- [M/L][Components][#262] add filters for random transaction selection in whatif tab. **Next** -- [M/L][Components][#262] add filters for random transaction selection in whatif tab. - [S][Methods][#220] get_contrib_df accepts list/array input. - [M][Components][#176] FeatureInputComponent hide parameters. - [M][Explainers][#198] LightGBM string categorical handling across SHAP/plots. 
diff --git a/explainerdashboard/explainer_methods.py b/explainerdashboard/explainer_methods.py index 7bb574e..a332280 100644 --- a/explainerdashboard/explainer_methods.py +++ b/explainerdashboard/explainer_methods.py @@ -4,6 +4,7 @@ "safe_isinstance", "unwrap_calibrated_classifier", "align_categorical_dtypes", + "sorted_categorical_values", "guess_shap", "mape_score", "parse_cats", @@ -341,6 +342,32 @@ def align_categorical_dtypes( return aligned +def sorted_categorical_values(values): + """Sort categorical values safely when types are mixed. + + Drops NaN-like values, then orders the rest deterministically: + booleans, then numbers, then other values by type/value string. + """ + + def _is_na(value): + try: + return bool(pd.isna(value)) + except Exception: + return False + + def _sort_key(value): + if isinstance(value, np.generic): + value = value.item() + if isinstance(value, bool): + return (0, int(value)) + if isinstance(value, (int, float)) and not _is_na(value): + return (1, float(value)) + return (2, type(value).__name__, str(value)) + + clean_values = [value for value in values if not _is_na(value)] + return sorted(clean_values, key=_sort_key) + + def guess_shap(model): """guesses which SHAP explainer to use for a particular model, based on str(type(model)).
Returns 'tree' for tree based models such as @@ -1171,7 +1198,9 @@ def _model_input(data): if grid_values is None: if isinstance(feature, str): if not is_numeric_dtype(X_sample[feature]): - grid_values = sorted(X_sample[feature].unique().tolist()) + grid_values = sorted_categorical_values( + X_sample[feature].unique().tolist() + ) else: grid_values = get_grid_points( X_sample[feature], @@ -1752,7 +1781,7 @@ def node_pred_proba(node): return class_counts[pos_label] / total # If pos_label not found, try to map it to available class keys - available_classes = sorted(class_counts.keys()) + available_classes = sorted_categorical_values(class_counts.keys()) if len(available_classes) == 0: return 0.0 diff --git a/explainerdashboard/explainer_plots.py b/explainerdashboard/explainer_plots.py index b8ea594..e97cbc3 100644 --- a/explainerdashboard/explainer_plots.py +++ b/explainerdashboard/explainer_plots.py @@ -31,7 +31,11 @@ import plotly.graph_objects as go from plotly.subplots import make_subplots -from .explainer_methods import matching_cols, safe_isinstance +from .explainer_methods import ( + matching_cols, + safe_isinstance, + sorted_categorical_values, +) def plotly_prediction_piechart(predictions_df, showlegend=True, size=250): @@ -1040,7 +1044,7 @@ def plotly_shap_violin_plot( ), f"{X_col.name} is not categorical! Can only plot violin plots for categorical features!" 
if cats_order is None: - cats_order = sorted(X_col.unique().tolist()) + cats_order = sorted_categorical_values(X_col.unique().tolist()) n_cats = len(cats_order) @@ -1098,7 +1102,11 @@ def plotly_shap_violin_plot( for i, cat in enumerate(cats_order): col = 1 + i * 2 if points or X_color_col is not None else 1 + i - if cat.startswith(X_col.name + "_"): + if ( + isinstance(cat, str) + and isinstance(X_col.name, str) + and cat.startswith(X_col.name + "_") + ): cat_name = cat[len(X_col.name) + 1 :] else: cat_name = cat @@ -2347,7 +2355,7 @@ def plotly_residuals_vs_col( if not is_numeric_dtype(col): if cats_order is None: - cats_order = sorted(col.unique().tolist()) + cats_order = sorted_categorical_values(col.unique().tolist()) n_cats = len(cats_order) if points: @@ -2506,7 +2514,7 @@ def plotly_actual_vs_col( if not is_numeric_dtype(col): if cats_order is None: - cats_order = sorted(col.unique().tolist()) + cats_order = sorted_categorical_values(col.unique().tolist()) n_cats = len(cats_order) if points: @@ -2658,7 +2666,7 @@ def plotly_preds_vs_col( if not is_numeric_dtype(col): if cats_order is None: - cats_order = sorted(col.unique().tolist()) + cats_order = sorted_categorical_values(col.unique().tolist()) n_cats = len(cats_order) if points: diff --git a/explainerdashboard/explainers.py b/explainerdashboard/explainers.py index 5d3eb92..2c5691c 100644 --- a/explainerdashboard/explainers.py +++ b/explainerdashboard/explainers.py @@ -275,7 +275,7 @@ def __init__( col for col in self.regular_cols if not is_numeric_dtype(self.X[col]) ] self.categorical_dict = { - col: sorted(self.X[col].dropna().unique().tolist()) + col: sorted_categorical_values(self.X[col].unique().tolist()) for col in self.categorical_cols } # Include an explicit NaN option for categorical columns with missing values. 
@@ -1010,7 +1010,7 @@ def ordered_cats(self, col, topx=None, sort="alphabet", pos_label=None): values = self.categorical_dict[col] else: values = [v for v in X[col].unique().tolist() if not pd.isna(v)] - values = sorted(values) + values = sorted_categorical_values(values) if topx is None: return values else: @@ -2695,7 +2695,9 @@ def __init__( if hasattr(self.model, "classes_"): class_values = list(self.model.classes_) elif not self.y_missing: - class_values = sorted(pd.Series(self.y).dropna().unique().tolist()) + class_values = sorted_categorical_values( + pd.Series(self.y).dropna().unique().tolist() + ) if not self.y_missing: if class_values is None: diff --git a/tests/test_categorical_sorting.py b/tests/test_categorical_sorting.py new file mode 100644 index 0000000..e621295 --- /dev/null +++ b/tests/test_categorical_sorting.py @@ -0,0 +1,78 @@ +import numpy as np +import pandas as pd +import pytest + +from explainerdashboard import ClassifierExplainer +from explainerdashboard.explainer_plots import ( + plotly_actual_vs_col, + plotly_preds_vs_col, + plotly_residuals_vs_col, + plotly_shap_violin_plot, +) + + +class NoClassesModel: + """Classifier-like model without classes_ attribute.""" + + def predict_proba(self, X): + n = len(X) + p = np.full(n, 0.6) + return np.c_[1 - p, p] + + def predict(self, X): + return (self.predict_proba(X)[:, 1] > 0.5).astype(int) + + +@pytest.mark.parametrize( + "plot_fn,args", + [ + ( + plotly_shap_violin_plot, + ( + pd.Series(["a", 1, np.nan, "b"], name="mixed_cat", dtype=object), + np.array([0.1, -0.2, 0.0, 0.3]), + ), + ), + ( + plotly_residuals_vs_col, + ( + np.array([1.0, 2.0, 3.0, 4.0]), + np.array([1.1, 1.9, 2.8, 4.2]), + pd.Series(["a", 1, np.nan, "b"], name="mixed_cat", dtype=object), + ), + ), + ( + plotly_actual_vs_col, + ( + np.array([1.0, 2.0, 3.0, 4.0]), + np.array([1.1, 1.9, 2.8, 4.2]), + pd.Series(["a", 1, np.nan, "b"], name="mixed_cat", dtype=object), + ), + ), + ( + plotly_preds_vs_col, + ( + np.array([1.0, 2.0, 
3.0, 4.0]), + np.array([1.1, 1.9, 2.8, 4.2]), + pd.Series(["a", 1, np.nan, "b"], name="mixed_cat", dtype=object), + ), + ), + ], +) +def test_plot_functions_do_not_crash_on_mixed_categorical_values(plot_fn, args): + try: + plot_fn(*args) + except TypeError as exc: + pytest.fail(f"Mixed categorical sorting should not crash: {exc}") + + +def test_classifier_explainer_without_classes_handles_mixed_target_labels(): + X = pd.DataFrame({"a": [1, 2, 3, 4], "b": [0, 1, 0, 1]}) + y = pd.Series([0, "1", 0, "1"], dtype=object) + + try: + explainer = ClassifierExplainer(NoClassesModel(), X, y, shap="kernel") + except TypeError as exc: + pytest.fail(f"Mixed target-label sorting should not crash: {exc}") + + assert len(explainer.labels) == 2 diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e805d03..87c2f8a 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -22,6 +22,22 @@ def predict_proba(self, X): return self._model.predict_proba(X) +class MixedCategoricalModelWrapper: + def __init__(self, model) -> None: + self._model = model + + def _preprocessor(self, X): + return X.drop(["MixedCat"], axis=1) + + def predict(self, X): + X = self._preprocessor(X) + return self._model.predict(X) + + def predict_proba(self, X): + X = self._preprocessor(X) + return self._model.predict_proba(X) + + @pytest.fixture(scope="module") def categorical_nan_explainer(classifier_data): X_train, y_train, X_test, y_test = classifier_data @@ -73,3 +89,39 @@ def test_get_row_from_input_preserves_nan_unmerged(categorical_nan_explainer): row["Name"] = "NaN" input_row = explainer.get_row_from_input(row.values.tolist()) assert pd.isna(input_row.loc[0, "Name"]) + + +def test_mixed_type_categorical_with_nan_does_not_crash_sorting(classifier_data): + X_train, y_train, X_test, y_test = classifier_data + X_train = X_train.copy() + X_test = X_test.copy() + + X_train["MixedCat"] = pd.Series( + ["alpha"] * len(X_train), index=X_train.index, dtype=object + ) + X_test["MixedCat"] = 
pd.Series( + ["alpha"] * len(X_test), index=X_test.index, dtype=object + ) + + # Simulate mixed-type categorical values from upstream preprocessing artifacts. + X_train.loc[X_train.index[:8], "MixedCat"] = 1 + X_train.loc[X_train.index[8:16], "MixedCat"] = np.nan + X_test.loc[X_test.index[:4], "MixedCat"] = 0 + X_test.loc[X_test.index[4:8], "MixedCat"] = np.nan + + model = RandomForestClassifier(n_estimators=5, max_depth=2, random_state=42) + model.fit(X_train.drop(["MixedCat"], axis=1), y_train) + wrapper = MixedCategoricalModelWrapper(model) + + try: + ClassifierExplainer( + wrapper, + X_test, + y_test, + cats=["Sex", "Deck", "Embarked"], + shap="kernel", + ) + except TypeError as exc: + pytest.fail( + f"Categorical sorting should not crash on mixed values + NaN: {exc}" + )