Skip to content

Commit b0bc508

Browse files
authored
Fix mixed categorical sorting crashes (issue #273) (#336)
1 parent e3566aa commit b0bc508

7 files changed

Lines changed: 183 additions & 13 deletions

File tree

RELEASE_NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
- Improve compatibility with AutoGluon/custom wrappers by coercing pandas `DataFrame` outputs from `predict_proba`/`predict` to numpy arrays before indexing in classifier/regression helper paths.
99
- Harden one-vs-all scorer handling so `make_one_vs_all_scorer` also accepts classifiers whose `predict_proba` returns a pandas `DataFrame`.
1010
- Fix ExplainerHub `add_dashboard_route` after first request by allowing dynamic dashboard registration/setup during route-triggered add, and add a regression test for issue #269.
11+
- Fix issue #273: avoid crashes when sorting categorical values containing mixed types and NaNs across explainer setup, PDP/category ordering, and categorical plot paths.
12+
- Add regression tests for mixed-type categorical sorting and mixed target-label sorting fallbacks in explainers and plots.
1113

1214
## Version 0.5.6:
1315

TODO.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
- Rules: link an issue when possible; include size S/M/L; mark blockers.
77

88
**Now**
9-
- [M][Explainers][#273] categorical columns with NaNs: sorting and column preservation.
9+
- [M/L][Components][#262] add filters for random transaction selection in whatif tab.
1010

1111
**Next**
12-
- [M/L][Components][#262] add filters for random transaction selection in whatif tab.
1312
- [S][Methods][#220] get_contrib_df accepts list/array input.
1413
- [M][Components][#176] FeatureInputComponent hide parameters.
1514
- [M][Explainers][#198] LightGBM string categorical handling across SHAP/plots.

explainerdashboard/explainer_methods.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"safe_isinstance",
55
"unwrap_calibrated_classifier",
66
"align_categorical_dtypes",
7+
"sorted_categorical_values",
78
"guess_shap",
89
"mape_score",
910
"parse_cats",
@@ -341,6 +342,32 @@ def align_categorical_dtypes(
341342
return aligned
342343

343344

345+
def sorted_categorical_values(values):
    """Return the non-missing entries of ``values`` in a deterministic order.

    The builtin ``sorted`` raises ``TypeError`` when asked to order values of
    mixed types (for example ``"a"`` and ``1``). This helper sorts on a
    derived key instead, so the original objects are returned unchanged:

    1. booleans (``False`` before ``True``),
    2. ints and floats, ordered by numeric value,
    3. everything else, ordered by type name and then by ``str(value)``.

    Entries that ``pd.isna`` reports as missing (``None``, ``NaN``, ...) are
    dropped from the result.

    Args:
        values: Iterable of category values, possibly of mixed types.

    Returns:
        list: A new list holding the original non-missing objects, sorted.
    """

    def _is_na(value):
        # pd.isna returns an array for list-like input, and bool() of a
        # multi-element array raises; treat anything that cannot be reduced
        # to a single truth value as "not missing" rather than crashing.
        try:
            return bool(pd.isna(value))
        except Exception:
            return False

    def _sort_key(value):
        # Unwrap numpy scalars so they are grouped with the matching Python
        # type (np.bool_ with bool, np.int64 with int, ...).
        if isinstance(value, np.generic):
            value = value.item()
        # bool is a subclass of int, so it must be tested before the
        # numeric branch.
        if isinstance(value, bool):
            return (0, int(value))
        if isinstance(value, (int, float)) and not _is_na(value):
            # Compare ints and floats natively. Converting through float()
            # raises OverflowError for very large ints and maps distinct
            # ints above 2**53 onto the same key.
            return (1, value)
        return (2, type(value).__name__, str(value))

    return sorted(
        (value for value in values if not _is_na(value)),
        key=_sort_key,
    )
369+
370+
344371
def guess_shap(model):
345372
"""guesses which SHAP explainer to use for a particular model, based
346373
on str(type(model)). Returns 'tree' for tree based models such as
@@ -1171,7 +1198,9 @@ def _model_input(data):
11711198
if grid_values is None:
11721199
if isinstance(feature, str):
11731200
if not is_numeric_dtype(X_sample[feature]):
1174-
grid_values = sorted(X_sample[feature].unique().tolist())
1201+
grid_values = sorted_categorical_values(
1202+
X_sample[feature].unique().tolist()
1203+
)
11751204
else:
11761205
grid_values = get_grid_points(
11771206
X_sample[feature],
@@ -1752,7 +1781,7 @@ def node_pred_proba(node):
17521781
return class_counts[pos_label] / total
17531782

17541783
# If pos_label not found, try to map it to available class keys
1755-
available_classes = sorted(class_counts.keys())
1784+
available_classes = sorted_categorical_values(class_counts.keys())
17561785
if len(available_classes) == 0:
17571786
return 0.0
17581787

explainerdashboard/explainer_plots.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@
3131
import plotly.graph_objects as go
3232
from plotly.subplots import make_subplots
3333

34-
from .explainer_methods import matching_cols, safe_isinstance
34+
from .explainer_methods import (
35+
matching_cols,
36+
safe_isinstance,
37+
sorted_categorical_values,
38+
)
3539

3640

3741
def plotly_prediction_piechart(predictions_df, showlegend=True, size=250):
@@ -1040,7 +1044,7 @@ def plotly_shap_violin_plot(
10401044
), f"{X_col.name} is not categorical! Can only plot violin plots for categorical features!"
10411045

10421046
if cats_order is None:
1043-
cats_order = sorted(X_col.unique().tolist())
1047+
cats_order = sorted_categorical_values(X_col.unique().tolist())
10441048

10451049
n_cats = len(cats_order)
10461050

@@ -1098,7 +1102,11 @@ def plotly_shap_violin_plot(
10981102

10991103
for i, cat in enumerate(cats_order):
11001104
col = 1 + i * 2 if points or X_color_col is not None else 1 + i
1101-
if cat.startswith(X_col.name + "_"):
1105+
if (
1106+
isinstance(cat, str)
1107+
and isinstance(X_col.name, str)
1108+
and cat.startswith(X_col.name + "_")
1109+
):
11021110
cat_name = cat[len(X_col.name) + 1 :]
11031111
else:
11041112
cat_name = cat
@@ -2347,7 +2355,7 @@ def plotly_residuals_vs_col(
23472355

23482356
if not is_numeric_dtype(col):
23492357
if cats_order is None:
2350-
cats_order = sorted(col.unique().tolist())
2358+
cats_order = sorted_categorical_values(col.unique().tolist())
23512359
n_cats = len(cats_order)
23522360

23532361
if points:
@@ -2506,7 +2514,7 @@ def plotly_actual_vs_col(
25062514

25072515
if not is_numeric_dtype(col):
25082516
if cats_order is None:
2509-
cats_order = sorted(col.unique().tolist())
2517+
cats_order = sorted_categorical_values(col.unique().tolist())
25102518
n_cats = len(cats_order)
25112519

25122520
if points:
@@ -2658,7 +2666,7 @@ def plotly_preds_vs_col(
26582666

26592667
if not is_numeric_dtype(col):
26602668
if cats_order is None:
2661-
cats_order = sorted(col.unique().tolist())
2669+
cats_order = sorted_categorical_values(col.unique().tolist())
26622670
n_cats = len(cats_order)
26632671

26642672
if points:

explainerdashboard/explainers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def __init__(
275275
col for col in self.regular_cols if not is_numeric_dtype(self.X[col])
276276
]
277277
self.categorical_dict = {
278-
col: sorted(self.X[col].dropna().unique().tolist())
278+
col: sorted_categorical_values(self.X[col].unique().tolist())
279279
for col in self.categorical_cols
280280
}
281281
# Include an explicit NaN option for categorical columns with missing values.
@@ -1010,7 +1010,7 @@ def ordered_cats(self, col, topx=None, sort="alphabet", pos_label=None):
10101010
values = self.categorical_dict[col]
10111011
else:
10121012
values = [v for v in X[col].unique().tolist() if not pd.isna(v)]
1013-
values = sorted(values)
1013+
values = sorted_categorical_values(values)
10141014
if topx is None:
10151015
return values
10161016
else:
@@ -2695,7 +2695,9 @@ def __init__(
26952695
if hasattr(self.model, "classes_"):
26962696
class_values = list(self.model.classes_)
26972697
elif not self.y_missing:
2698-
class_values = sorted(pd.Series(self.y).dropna().unique().tolist())
2698+
class_values = sorted_categorical_values(
2699+
pd.Series(self.y).dropna().unique().tolist()
2700+
)
26992701

27002702
if not self.y_missing:
27012703
if class_values is None:

tests/test_categorical_sorting.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
from explainerdashboard import ClassifierExplainer
6+
from explainerdashboard.explainer_plots import (
7+
plotly_actual_vs_col,
8+
plotly_preds_vs_col,
9+
plotly_residuals_vs_col,
10+
plotly_shap_violin_plot,
11+
)
12+
13+
14+
class NoClassesModel:
    """Stand-in classifier that deliberately exposes no ``classes_`` attribute."""

    def predict_proba(self, X):
        """Return an ``(len(X), 2)`` array whose rows are all ``[0.4, 0.6]``."""
        positive = np.full(len(X), 0.6)
        negative = 1 - positive
        return np.c_[negative, positive]

    def predict(self, X):
        """Return 1 for each row whose second-column probability exceeds 0.5."""
        positive = self.predict_proba(X)[:, 1]
        above_threshold = positive > 0.5
        return above_threshold.astype(int)
24+
25+
26+
@pytest.mark.parametrize(
    "plot_fn,args",
    [
        # The violin plot takes (categorical column, values).
        (
            plotly_shap_violin_plot,
            (
                pd.Series(["a", 1, np.nan, "b"], name="mixed_cat", dtype=object),
                np.array([0.1, -0.2, 0.0, 0.3]),
            ),
        ),
        # The *_vs_col plots take (two numeric arrays, categorical column).
        (
            plotly_residuals_vs_col,
            (
                np.array([1.0, 2.0, 3.0, 4.0]),
                np.array([1.1, 1.9, 2.8, 4.2]),
                pd.Series(["a", 1, np.nan, "b"], name="mixed_cat", dtype=object),
            ),
        ),
        (
            plotly_actual_vs_col,
            (
                np.array([1.0, 2.0, 3.0, 4.0]),
                np.array([1.1, 1.9, 2.8, 4.2]),
                pd.Series(["a", 1, np.nan, "b"], name="mixed_cat", dtype=object),
            ),
        ),
        (
            plotly_preds_vs_col,
            (
                np.array([1.0, 2.0, 3.0, 4.0]),
                np.array([1.1, 1.9, 2.8, 4.2]),
                pd.Series(["a", 1, np.nan, "b"], name="mixed_cat", dtype=object),
            ),
        ),
    ],
)
def test_plot_functions_do_not_crash_on_mixed_categorical_values(plot_fn, args):
    """Call each plot function with a column mixing strings, an int and NaN.

    A ``TypeError`` is reported as a test failure with an explanatory
    message; any other exception propagates unchanged.
    """
    try:
        plot_fn(*args)
    except TypeError as exc:
        pytest.fail(f"Mixed categorical sorting should not crash: {exc}")
67+
68+
69+
def test_classifier_explainer_without_classes_handles_mixed_target_labels():
    """Build an explainer whose target mixes the int ``0`` and the string ``"1"``.

    The model has no ``classes_`` attribute. A ``TypeError`` during
    construction is reported as a test failure; afterwards the explainer
    must expose exactly two labels.
    """
    features = pd.DataFrame({"a": [1, 2, 3, 4], "b": [0, 1, 0, 1]})
    target = pd.Series([0, "1", 0, "1"], dtype=object)
    model = NoClassesModel()

    try:
        explainer = ClassifierExplainer(model, features, target, shap="kernel")
    except TypeError as exc:
        pytest.fail(f"Mixed target-label sorting should not crash: {exc}")

    assert len(explainer.labels) == 2

tests/test_datasets.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,22 @@ def predict_proba(self, X):
2222
return self._model.predict_proba(X)
2323

2424

25+
class MixedCategoricalModelWrapper:
    """Wrap a model so the ``MixedCat`` column is removed before each call."""

    def __init__(self, model) -> None:
        self._model = model

    def _preprocessor(self, X):
        """Return ``X`` without its ``MixedCat`` column."""
        return X.drop(columns=["MixedCat"])

    def predict(self, X):
        """Delegate to the wrapped model's ``predict`` on the reduced frame."""
        reduced = self._preprocessor(X)
        return self._model.predict(reduced)

    def predict_proba(self, X):
        """Delegate to the wrapped model's ``predict_proba`` on the reduced frame."""
        reduced = self._preprocessor(X)
        return self._model.predict_proba(reduced)
39+
40+
2541
@pytest.fixture(scope="module")
2642
def categorical_nan_explainer(classifier_data):
2743
X_train, y_train, X_test, y_test = classifier_data
@@ -73,3 +89,39 @@ def test_get_row_from_input_preserves_nan_unmerged(categorical_nan_explainer):
7389
row["Name"] = "NaN"
7490
input_row = explainer.get_row_from_input(row.values.tolist())
7591
assert pd.isna(input_row.loc[0, "Name"])
92+
93+
94+
def test_mixed_type_categorical_with_nan_does_not_crash_sorting(classifier_data):
    """Construct an explainer on data with a string/int/NaN ``MixedCat`` column.

    The model is fitted without ``MixedCat`` and wrapped so the column is
    dropped again at prediction time. A ``TypeError`` raised while building
    the explainer is reported as a test failure.
    """
    X_train, y_train, X_test, y_test = classifier_data
    X_train = X_train.copy()
    X_test = X_test.copy()

    # Start both frames with an all-string object column.
    for frame in (X_train, X_test):
        frame["MixedCat"] = pd.Series(
            ["alpha"] * len(frame), index=frame.index, dtype=object
        )

    # Simulate mixed-type categorical values from upstream preprocessing artifacts.
    train_idx = X_train.index
    test_idx = X_test.index
    X_train.loc[train_idx[:8], "MixedCat"] = 1
    X_train.loc[train_idx[8:16], "MixedCat"] = np.nan
    X_test.loc[test_idx[:4], "MixedCat"] = 0
    X_test.loc[test_idx[4:8], "MixedCat"] = np.nan

    model = RandomForestClassifier(n_estimators=5, max_depth=2, random_state=42)
    model.fit(X_train.drop(columns=["MixedCat"]), y_train)
    wrapper = MixedCategoricalModelWrapper(model)

    try:
        ClassifierExplainer(
            wrapper,
            X_test,
            y_test,
            cats=["Sex", "Deck", "Embarked"],
            shap="kernel",
        )
    except TypeError as exc:
        pytest.fail(
            f"Categorical sorting should not crash on mixed values + NaN: {exc}"
        )

0 commit comments

Comments
 (0)