Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Fix FeatureInputComponent range calculation for boolean columns (avoid np.round on bools) and add a regression test.
- Ensure save_html includes custom tabs by providing a static-export fallback for tabs without a to_html implementation.
- Support string class labels in ClassifierExplainer by preserving label mappings and avoiding float casts.
- Support CalibratedClassifierCV by using its fitted base estimator for SHAP (avoids falling back to kernel).

### Improvements
- Replace print statements with standard logging and warnings; progress messages are now INFO-level and user-actionable guidance uses warnings. A one-time warning is emitted if logging is not configured, with instructions to call `enable_default_logging()`.
Expand Down
4 changes: 1 addition & 3 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
- Rules: link an issue when possible; include size S/M/L; mark blockers.

**Now**
- [M/L][Explainers][#279] support CalibratedClassifierCV (unwrap estimator for SHAP; update logic and tests).
- [S/M][Components][#277] whatif input range/rounding customization.

**Next**
- [S/M][Components][#277] whatif input range/rounding customization.
- [S/M][Explainers][#274] support string labels without float casts.
- [M][Explainers][#273] categorical columns with NaNs: sorting and column preservation.
- [S][Explainers][#270] Autogluon integration (coerce predict_proba to ndarray).
- [M][Hub][#269] add_dashboard endpoint fails after first request (Flask blueprint lifecycle).
Expand Down
24 changes: 24 additions & 0 deletions explainerdashboard/explainer_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"IndexNotFoundError",
"append_dict_to_df",
"safe_isinstance",
"unwrap_calibrated_classifier",
"align_categorical_dtypes",
"guess_shap",
"mape_score",
Expand Down Expand Up @@ -296,6 +297,27 @@ def safe_isinstance(obj, *instance_str):
return False


def unwrap_calibrated_classifier(model):
    """Return the fitted base estimator for a CalibratedClassifierCV model.

    For any other model type the input is returned unchanged. Both the
    new-style ``estimator`` and legacy ``base_estimator`` attribute names
    are probed, first on the fitted ``calibrated_classifiers_[0]`` entry
    and then on the wrapper itself, so older and newer scikit-learn
    versions are both handled.
    """
    if not safe_isinstance(model, "CalibratedClassifierCV"):
        return model

    # Probe the first fitted calibrated classifier (if fitting happened),
    # then fall back to the wrapper's own constructor-time attribute.
    sources = []
    fitted = getattr(model, "calibrated_classifiers_", None)
    if fitted:
        sources.append(fitted[0])
    sources.append(model)

    for source in sources:
        for attr_name in ("estimator", "base_estimator"):
            base = getattr(source, attr_name, None)
            if base is not None:
                return base

    # Nothing to unwrap: return the wrapper unchanged.
    return model


def align_categorical_dtypes(
df_target: pd.DataFrame,
df_reference: pd.DataFrame,
Expand Down Expand Up @@ -332,6 +354,8 @@ def guess_shap(model):
Returns:
str: {'tree', 'linear', None}
"""
model = unwrap_calibrated_classifier(model)

tree_models = [
"RandomForestClassifier",
"RandomForestRegressor",
Expand Down
35 changes: 21 additions & 14 deletions explainerdashboard/explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def __init__(
self.X_background = None
if not hasattr(self, "model"):
self.model = model
self.model_for_shap = unwrap_calibrated_classifier(self.model)

if safe_isinstance(model, "xgboost.core.Booster"):
raise ValueError(
Expand Down Expand Up @@ -337,7 +338,7 @@ def __init__(
self.shap_kwargs = shap_kwargs or {}

if shap == "guess":
shap_guess = guess_shap(self.model)
shap_guess = guess_shap(self.model_for_shap)
model_str = (
str(type(self.model))
.replace("'", "")
Expand Down Expand Up @@ -1230,13 +1231,14 @@ def shap_explainer(self):
if not hasattr(self, "_shap_explainer"):
X_str = ", X_background" if self.X_background is not None else "X"
NoX_str = ", X_background" if self.X_background is not None else ""
model_for_shap = self.model_for_shap
if self.shap == "tree":
logger.info(
"Generating self.shap_explainer = "
f"shap.TreeExplainer(model{NoX_str})"
)
# Fix XGBoost 3.1+ base_score string format before shap accesses it
model_for_shap = self._fix_xgboost_model_for_shap(self.model)
model_for_shap = self._fix_xgboost_model_for_shap(model_for_shap)
self._shap_explainer = shap.TreeExplainer(model_for_shap)
elif self.shap == "linear":
if self.X_background is None:
Expand All @@ -1250,7 +1252,7 @@ def shap_explainer(self):
X_str,
)
self._shap_explainer = shap.LinearExplainer(
self.model,
model_for_shap,
self.X_background if self.X_background is not None else self.X,
)
elif self.shap == "deep":
Expand All @@ -1263,7 +1265,7 @@ def shap_explainer(self):
UserWarning,
)
self._shap_explainer = shap.DeepExplainer(
self.model,
model_for_shap,
self.X_background
if self.X_background is not None
else shap.sample(self.X, 5),
Expand All @@ -1280,7 +1282,7 @@ def shap_explainer(self):
import torch

self._shap_explainer = shap.DeepExplainer(
self.model.module_,
model_for_shap.module_,
torch.tensor(self.X_background.values)
if self.X_background is not None
else torch.tensor(shap.sample(self.X, 5).values),
Expand Down Expand Up @@ -1327,7 +1329,7 @@ def model_predict(data_asarray):
"Please install a CUDA-enabled SHAP build that includes "
"GPUTree support."
)
self._shap_explainer = explainer_cls(self.model, X_data)
self._shap_explainer = explainer_cls(model_for_shap, X_data)
return self._shap_explainer

@insert_pos_label
Expand Down Expand Up @@ -2861,16 +2863,17 @@ def shap_explainer(self):
Taking into account model type and model_output
"""
if not hasattr(self, "_shap_explainer"):
model_for_shap = self.model_for_shap
model_str = (
str(type(self.model))
str(type(model_for_shap))
.replace("'", "")
.replace("<", "")
.replace(">", "")
.split(".")[-1]
)
if self.shap == "tree":
if safe_isinstance(
self.model,
model_for_shap,
"XGBClassifier",
"LGBMClassifier",
"CatBoostClassifier",
Expand All @@ -2897,7 +2900,9 @@ def shap_explainer(self):
UserWarning,
)
# Fix XGBoost 3.1+ base_score string format before shap accesses it
model_for_shap = self._fix_xgboost_model_for_shap(self.model)
model_for_shap = self._fix_xgboost_model_for_shap(
model_for_shap
)
self._shap_explainer = shap.TreeExplainer(
model_for_shap,
self.X_background
Expand All @@ -2914,7 +2919,9 @@ def shap_explainer(self):
", X_background" if self.X_background is not None else "",
)
# Fix XGBoost 3.1+ base_score string format before shap accesses it
model_for_shap = self._fix_xgboost_model_for_shap(self.model)
model_for_shap = self._fix_xgboost_model_for_shap(
model_for_shap
)
self._shap_explainer = shap.TreeExplainer(
model_for_shap, self.X_background
)
Expand All @@ -2929,7 +2936,7 @@ def shap_explainer(self):
", X_background" if self.X_background is not None else "",
)
# Fix XGBoost 3.1+ base_score string format before shap accesses it
model_for_shap = self._fix_xgboost_model_for_shap(self.model)
model_for_shap = self._fix_xgboost_model_for_shap(model_for_shap)
self._shap_explainer = shap.TreeExplainer(
model_for_shap, self.X_background
)
Expand All @@ -2955,7 +2962,7 @@ def shap_explainer(self):
)

self._shap_explainer = shap.LinearExplainer(
self.model,
model_for_shap,
self.X_background if self.X_background is not None else self.X,
)
elif self.shap == "deep":
Expand All @@ -2968,7 +2975,7 @@ def shap_explainer(self):
UserWarning,
)
self._shap_explainer = shap.DeepExplainer(
self.model,
model_for_shap,
self.X_background
if self.X_background is not None
else shap.sample(self.X, 5),
Expand All @@ -2985,7 +2992,7 @@ def shap_explainer(self):
UserWarning,
)
self._shap_explainer = shap.DeepExplainer(
self.model.module_,
model_for_shap.module_,
torch.tensor(
self.X_background.values
if self.X_background is not None
Expand Down
12 changes: 12 additions & 0 deletions tests/test_classifier_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import plotly.graph_objects as go

from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier

from explainerdashboard import ClassifierExplainer, ExplainerDashboard
Expand Down Expand Up @@ -110,6 +111,17 @@ def test_string_labels_supported(classifier_data):
assert {"pred_proba", "y"}.issubset(lift_df.columns)


def test_calibrated_classifiercv_uses_tree_shap(classifier_data):
    """A CalibratedClassifierCV wrapping a tree model should guess 'tree' shap."""
    X_train, y_train, X_test, y_test = classifier_data

    rf = RandomForestClassifier(n_estimators=25, random_state=0)
    calibrated_model = CalibratedClassifierCV(estimator=rf, cv=2)
    calibrated_model.fit(X_train, y_train)

    explainer = ClassifierExplainer(calibrated_model, X_test, y_test)

    # The wrapped RandomForest should be unwrapped for shap detection,
    # so the explainer picks TreeExplainer rather than falling back to kernel.
    assert explainer.shap == "tree"


def test_row_from_input(precalculated_rf_classifier_explainer):
input_row = precalculated_rf_classifier_explainer.get_row_from_input(
precalculated_rf_classifier_explainer.X.iloc[[0]].values.tolist()
Expand Down
Loading