Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
- Improve compatibility with AutoGluon/custom wrappers by coercing pandas `DataFrame` outputs from `predict_proba`/`predict` to numpy arrays before indexing in classifier/regression helper paths.
- Harden one-vs-all scorer handling so `make_one_vs_all_scorer` also accepts classifiers whose `predict_proba` returns a pandas `DataFrame`.
- Fix ExplainerHub `add_dashboard_route` after first request by allowing dynamic dashboard registration/setup during route-triggered add, and add a regression test for issue #269.
- Fix issue #273: avoid crashes when sorting categorical values containing mixed types and NaNs across explainer setup, PDP/category ordering, and categorical plot paths.
- Add regression tests for mixed-type categorical sorting and mixed target-label sorting fallbacks in explainers and plots.

## Version 0.5.6:

Expand Down
3 changes: 1 addition & 2 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
- Rules: link an issue when possible; include size S/M/L; mark blockers.

**Now**
- [M][Explainers][#273] categorical columns with NaNs: sorting and column preservation.
- [M/L][Components][#262] add filters for random transaction selection in whatif tab.

**Next**
- [M/L][Components][#262] add filters for random transaction selection in whatif tab.
- [S][Methods][#220] get_contrib_df accepts list/array input.
- [M][Components][#176] FeatureInputComponent hide parameters.
- [M][Explainers][#198] LightGBM string categorical handling across SHAP/plots.
Expand Down
33 changes: 31 additions & 2 deletions explainerdashboard/explainer_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"safe_isinstance",
"unwrap_calibrated_classifier",
"align_categorical_dtypes",
"sorted_categorical_values",
"guess_shap",
"mape_score",
"parse_cats",
Expand Down Expand Up @@ -341,6 +342,32 @@ def align_categorical_dtypes(
return aligned


def sorted_categorical_values(values):
    """Sort categorical values safely when types are mixed.

    Missing values (anything ``pd.isna`` reports as missing, such as
    ``None`` or ``NaN``) are dropped. The remaining original values are
    returned in a deterministic order: booleans first, then ints/floats by
    numeric value, then every other value by ``(type name, str(value))``.

    Args:
        values: iterable of category values, possibly of mixed types.

    Returns:
        list: the non-missing input values, sorted as described above.
    """

    def _is_na(value):
        try:
            return bool(pd.isna(value))
        except Exception:
            # pd.isna may return an array for list-like values, whose truth
            # value can be ambiguous; such values are treated as present.
            return False

    def _sort_key(value):
        # Unwrap numpy scalars so they fall in the same bucket as the
        # equivalent Python bool/int/float.
        if isinstance(value, np.generic):
            value = value.item()
        # bool must be tested before int because bool is an int subclass.
        if isinstance(value, bool):
            return (0, int(value))
        if isinstance(value, (int, float)) and not _is_na(value):
            # Compare the number itself rather than float(value): Python
            # orders int and float exactly, whereas float() overflows for
            # very large ints and merges neighbouring ints above 2**53.
            return (1, value)
        return (2, type(value).__name__, str(value))

    present = (value for value in values if not _is_na(value))
    return sorted(present, key=_sort_key)


def guess_shap(model):
"""guesses which SHAP explainer to use for a particular model, based
on str(type(model)). Returns 'tree' for tree based models such as
Expand Down Expand Up @@ -1171,7 +1198,9 @@ def _model_input(data):
if grid_values is None:
if isinstance(feature, str):
if not is_numeric_dtype(X_sample[feature]):
grid_values = sorted(X_sample[feature].unique().tolist())
grid_values = sorted_categorical_values(
X_sample[feature].unique().tolist()
)
else:
grid_values = get_grid_points(
X_sample[feature],
Expand Down Expand Up @@ -1752,7 +1781,7 @@ def node_pred_proba(node):
return class_counts[pos_label] / total

# If pos_label not found, try to map it to available class keys
available_classes = sorted(class_counts.keys())
available_classes = sorted_categorical_values(class_counts.keys())
if len(available_classes) == 0:
return 0.0

Expand Down
20 changes: 14 additions & 6 deletions explainerdashboard/explainer_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from .explainer_methods import matching_cols, safe_isinstance
from .explainer_methods import (
matching_cols,
safe_isinstance,
sorted_categorical_values,
)


def plotly_prediction_piechart(predictions_df, showlegend=True, size=250):
Expand Down Expand Up @@ -1040,7 +1044,7 @@ def plotly_shap_violin_plot(
), f"{X_col.name} is not categorical! Can only plot violin plots for categorical features!"

if cats_order is None:
cats_order = sorted(X_col.unique().tolist())
cats_order = sorted_categorical_values(X_col.unique().tolist())

n_cats = len(cats_order)

Expand Down Expand Up @@ -1098,7 +1102,11 @@ def plotly_shap_violin_plot(

for i, cat in enumerate(cats_order):
col = 1 + i * 2 if points or X_color_col is not None else 1 + i
if cat.startswith(X_col.name + "_"):
if (
isinstance(cat, str)
and isinstance(X_col.name, str)
and cat.startswith(X_col.name + "_")
):
cat_name = cat[len(X_col.name) + 1 :]
else:
cat_name = cat
Expand Down Expand Up @@ -2347,7 +2355,7 @@ def plotly_residuals_vs_col(

if not is_numeric_dtype(col):
if cats_order is None:
cats_order = sorted(col.unique().tolist())
cats_order = sorted_categorical_values(col.unique().tolist())
n_cats = len(cats_order)

if points:
Expand Down Expand Up @@ -2506,7 +2514,7 @@ def plotly_actual_vs_col(

if not is_numeric_dtype(col):
if cats_order is None:
cats_order = sorted(col.unique().tolist())
cats_order = sorted_categorical_values(col.unique().tolist())
n_cats = len(cats_order)

if points:
Expand Down Expand Up @@ -2658,7 +2666,7 @@ def plotly_preds_vs_col(

if not is_numeric_dtype(col):
if cats_order is None:
cats_order = sorted(col.unique().tolist())
cats_order = sorted_categorical_values(col.unique().tolist())
n_cats = len(cats_order)

if points:
Expand Down
8 changes: 5 additions & 3 deletions explainerdashboard/explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def __init__(
col for col in self.regular_cols if not is_numeric_dtype(self.X[col])
]
self.categorical_dict = {
col: sorted(self.X[col].dropna().unique().tolist())
col: sorted_categorical_values(self.X[col].unique().tolist())
for col in self.categorical_cols
}
# Include an explicit NaN option for categorical columns with missing values.
Expand Down Expand Up @@ -1010,7 +1010,7 @@ def ordered_cats(self, col, topx=None, sort="alphabet", pos_label=None):
values = self.categorical_dict[col]
else:
values = [v for v in X[col].unique().tolist() if not pd.isna(v)]
values = sorted(values)
values = sorted_categorical_values(values)
if topx is None:
return values
else:
Expand Down Expand Up @@ -2695,7 +2695,9 @@ def __init__(
if hasattr(self.model, "classes_"):
class_values = list(self.model.classes_)
elif not self.y_missing:
class_values = sorted(pd.Series(self.y).dropna().unique().tolist())
class_values = sorted_categorical_values(
pd.Series(self.y).dropna().unique().tolist()
)

if not self.y_missing:
if class_values is None:
Expand Down
78 changes: 78 additions & 0 deletions tests/test_categorical_sorting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
import pandas as pd
import pytest

from explainerdashboard import ClassifierExplainer
from explainerdashboard.explainer_plots import (
plotly_actual_vs_col,
plotly_preds_vs_col,
plotly_residuals_vs_col,
plotly_shap_violin_plot,
)


class NoClassesModel:
    """Minimal classifier stand-in that exposes no ``classes_`` attribute."""

    def predict_proba(self, X):
        # One row per sample, with a constant 0.6 in the second column.
        second = np.full(len(X), 0.6)
        return np.column_stack((1 - second, second))

    def predict(self, X):
        second_column = self.predict_proba(X)[:, 1]
        return (second_column > 0.5).astype(int)


def _mixed_cat_series():
    """Return a fresh object-dtype Series mixing strings, an int and a NaN."""
    return pd.Series(["a", 1, np.nan, "b"], name="mixed_cat", dtype=object)


def _vs_col_args():
    """Return the positional args used for each of the ``*_vs_col`` plots."""
    return (
        np.array([1.0, 2.0, 3.0, 4.0]),
        np.array([1.1, 1.9, 2.8, 4.2]),
        _mixed_cat_series(),
    )


# Each entry pairs a plot function with the positional arguments to call it with.
_MIXED_CATEGORICAL_PLOT_CASES = [
    (
        plotly_shap_violin_plot,
        (_mixed_cat_series(), np.array([0.1, -0.2, 0.0, 0.3])),
    ),
    (plotly_residuals_vs_col, _vs_col_args()),
    (plotly_actual_vs_col, _vs_col_args()),
    (plotly_preds_vs_col, _vs_col_args()),
]


@pytest.mark.parametrize("plot_fn,args", _MIXED_CATEGORICAL_PLOT_CASES)
def test_plot_functions_do_not_crash_on_mixed_categorical_values(plot_fn, args):
    """Each plot function must not raise TypeError on a mixed-type column."""
    try:
        plot_fn(*args)
    except TypeError as err:
        pytest.fail(f"Mixed categorical sorting should not crash: {err}")


def test_classifier_explainer_without_classes_handles_mixed_target_labels():
    """Building an explainer from mixed int/str labels must not raise TypeError."""
    features = pd.DataFrame({"a": [1, 2, 3, 4], "b": [0, 1, 0, 1]})
    mixed_labels = pd.Series([0, "1", 0, "1"], dtype=object)
    model = NoClassesModel()

    try:
        explainer = ClassifierExplainer(
            model, features, mixed_labels, shap="kernel"
        )
    except TypeError as err:
        pytest.fail(f"Mixed target-label sorting should not crash: {err}")

    # Two distinct target values were supplied, so two labels are expected.
    assert len(explainer.labels) == 2
52 changes: 52 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@ def predict_proba(self, X):
return self._model.predict_proba(X)


class MixedCategoricalModelWrapper:
    """Delegate to a wrapped model after removing the ``MixedCat`` column."""

    def __init__(self, model) -> None:
        self._model = model

    def _preprocessor(self, X):
        # The wrapped model never sees the mixed-type column.
        return X.drop(columns=["MixedCat"])

    def predict(self, X):
        return self._model.predict(self._preprocessor(X))

    def predict_proba(self, X):
        return self._model.predict_proba(self._preprocessor(X))


@pytest.fixture(scope="module")
def categorical_nan_explainer(classifier_data):
X_train, y_train, X_test, y_test = classifier_data
Expand Down Expand Up @@ -73,3 +89,39 @@ def test_get_row_from_input_preserves_nan_unmerged(categorical_nan_explainer):
row["Name"] = "NaN"
input_row = explainer.get_row_from_input(row.values.tolist())
assert pd.isna(input_row.loc[0, "Name"])


def test_mixed_type_categorical_with_nan_does_not_crash_sorting(classifier_data):
    """Explainer construction must not raise TypeError on mixed values + NaN."""
    train_X, train_y, test_X, test_y = classifier_data
    train_X, test_X = train_X.copy(), test_X.copy()

    # Start both frames with an all-string object column.
    for frame in (train_X, test_X):
        frame["MixedCat"] = pd.Series(
            ["alpha"] * len(frame), index=frame.index, dtype=object
        )

    # Simulate mixed-type categorical values from upstream preprocessing artifacts.
    train_X.loc[train_X.index[:8], "MixedCat"] = 1
    train_X.loc[train_X.index[8:16], "MixedCat"] = np.nan
    test_X.loc[test_X.index[:4], "MixedCat"] = 0
    test_X.loc[test_X.index[4:8], "MixedCat"] = np.nan

    # The forest is fitted without the mixed column; the wrapper drops it
    # again at prediction time.
    forest = RandomForestClassifier(n_estimators=5, max_depth=2, random_state=42)
    forest.fit(train_X.drop(["MixedCat"], axis=1), train_y)
    wrapped = MixedCategoricalModelWrapper(forest)

    try:
        ClassifierExplainer(
            wrapped,
            test_X,
            test_y,
            cats=["Sex", "Deck", "Embarked"],
            shap="kernel",
        )
    except TypeError as err:
        pytest.fail(
            f"Categorical sorting should not crash on mixed values + NaN: {err}"
        )
Loading