diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 8e84a73..3386f66 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -7,6 +7,7 @@ - Fix FeatureInputComponent range calculation for boolean columns (avoid np.round on bools) and add a regression test. - Ensure save_html includes custom tabs by providing a static-export fallback for tabs without a to_html implementation. - Support string class labels in ClassifierExplainer by preserving label mappings and avoiding float casts. +- Support CalibratedClassifierCV by using its fitted base estimator for SHAP (avoids falling back to kernel). ### Improvements - Replace print statements with standard logging and warnings; progress messages are now INFO-level and user-actionable guidance uses warnings. A one-time warning is emitted if logging is not configured, with instructions to call `enable_default_logging()`. diff --git a/TODO.md b/TODO.md index 56d9ece..8328f8c 100644 --- a/TODO.md +++ b/TODO.md @@ -6,11 +6,9 @@ - Rules: link an issue when possible; include size S/M/L; mark blockers. **Now** -- [M/L][Explainers][#279] support CalibratedClassifierCV (unwrap estimator for SHAP; update logic and tests). +- [S/M][Components][#277] whatif input range/rounding customization. **Next** -- [S/M][Components][#277] whatif input range/rounding customization. -- [S/M][Explainers][#274] support string labels without float casts. - [M][Explainers][#273] categorical columns with NaNs: sorting and column preservation. - [S][Explainers][#270] Autogluon integration (coerce predict_proba to ndarray). - [M][Hub][#269] add_dashboard endpoint fails after first request (Flask blueprint lifecycle). diff --git a/explainerdashboard/explainer_methods.py b/explainerdashboard/explainer_methods.py index 74a6422..02005a0 100644 --- a/explainerdashboard/explainer_methods.py +++ b/explainerdashboard/explainer_methods.py @@ -2,6 +2,7 @@ "IndexNotFoundError", "append_dict_to_df", "safe_isinstance", + "unwrap_calibrated_classifier", "align_categorical_dtypes", "guess_shap", "mape_score", @@ -296,6 +297,27 @@ def safe_isinstance(obj, *instance_str): return False +def unwrap_calibrated_classifier(model): + """Return the fitted base estimator for a CalibratedClassifierCV model.""" + if not safe_isinstance(model, "CalibratedClassifierCV"): + return model + + calibrated_classifiers = getattr(model, "calibrated_classifiers_", None) + if calibrated_classifiers: + calibrated = calibrated_classifiers[0] + for attr in ("estimator", "base_estimator"): + estimator = getattr(calibrated, attr, None) + if estimator is not None: + return estimator + + for attr in ("estimator", "base_estimator"): + estimator = getattr(model, attr, None) + if estimator is not None: + return estimator + + return model + + def align_categorical_dtypes( df_target: pd.DataFrame, df_reference: pd.DataFrame, @@ -332,6 +354,8 @@ def guess_shap(model): Returns: str: {'tree', 'linear', None} """ + model = unwrap_calibrated_classifier(model) + tree_models = [ "RandomForestClassifier", "RandomForestRegressor", diff --git a/explainerdashboard/explainers.py b/explainerdashboard/explainers.py index bf91046..e07b695 100644 --- a/explainerdashboard/explainers.py +++ b/explainerdashboard/explainers.py @@ -250,6 +250,7 @@ def __init__( self.X_background = None if not hasattr(self, "model"): self.model = model + self.model_for_shap = unwrap_calibrated_classifier(self.model) if safe_isinstance(model, "xgboost.core.Booster"): raise ValueError( @@ -337,7 +338,7 @@ def __init__( self.shap_kwargs = shap_kwargs or {} if shap == "guess": - shap_guess = guess_shap(self.model) + shap_guess = guess_shap(self.model_for_shap) model_str = ( str(type(self.model)) .replace("'", "") @@ -1230,13 +1231,14 @@ def shap_explainer(self): if not hasattr(self, "_shap_explainer"): X_str = ", X_background" if self.X_background is not None else "X" NoX_str = ", X_background" if self.X_background is not None else "" + model_for_shap = self.model_for_shap if self.shap == "tree": logger.info( "Generating self.shap_explainer = " f"shap.TreeExplainer(model{NoX_str})" ) # Fix XGBoost 3.1+ base_score string format before shap accesses it - model_for_shap = self._fix_xgboost_model_for_shap(self.model) + model_for_shap = self._fix_xgboost_model_for_shap(model_for_shap) self._shap_explainer = shap.TreeExplainer(model_for_shap) elif self.shap == "linear": if self.X_background is None: @@ -1250,7 +1252,7 @@ def shap_explainer(self): X_str, ) self._shap_explainer = shap.LinearExplainer( - self.model, + model_for_shap, self.X_background if self.X_background is not None else self.X, ) elif self.shap == "deep": @@ -1263,7 +1265,7 @@ def shap_explainer(self): UserWarning, ) self._shap_explainer = shap.DeepExplainer( - self.model, + model_for_shap, self.X_background if self.X_background is not None else shap.sample(self.X, 5), @@ -1280,7 +1282,7 @@ def shap_explainer(self): import torch self._shap_explainer = shap.DeepExplainer( - self.model.module_, + model_for_shap.module_, torch.tensor(self.X_background.values) if self.X_background is not None else torch.tensor(shap.sample(self.X, 5).values), @@ -1327,7 +1329,7 @@ def model_predict(data_asarray): "Please install a CUDA-enabled SHAP build that includes " "GPUTree support." ) - self._shap_explainer = explainer_cls(self.model, X_data) + self._shap_explainer = explainer_cls(model_for_shap, X_data) return self._shap_explainer @insert_pos_label @@ -2861,8 +2863,9 @@ def shap_explainer(self): Taking into account model type and model_output """ if not hasattr(self, "_shap_explainer"): + model_for_shap = self.model_for_shap model_str = ( - str(type(self.model)) + str(type(model_for_shap)) .replace("'", "") .replace("<", "") .replace(">", "") @@ -2870,7 +2873,7 @@ def shap_explainer(self): ) if self.shap == "tree": if safe_isinstance( - self.model, + model_for_shap, "XGBClassifier", "LGBMClassifier", "CatBoostClassifier", @@ -2897,7 +2900,9 @@ def shap_explainer(self): UserWarning, ) # Fix XGBoost 3.1+ base_score string format before shap accesses it - model_for_shap = self._fix_xgboost_model_for_shap(self.model) + model_for_shap = self._fix_xgboost_model_for_shap( + model_for_shap + ) self._shap_explainer = shap.TreeExplainer( model_for_shap, self.X_background @@ -2914,7 +2919,9 @@ def shap_explainer(self): ", X_background" if self.X_background is not None else "", ) # Fix XGBoost 3.1+ base_score string format before shap accesses it - model_for_shap = self._fix_xgboost_model_for_shap(self.model) + model_for_shap = self._fix_xgboost_model_for_shap( + model_for_shap + ) self._shap_explainer = shap.TreeExplainer( model_for_shap, self.X_background ) @@ -2929,7 +2936,7 @@ def shap_explainer(self): ", X_background" if self.X_background is not None else "", ) # Fix XGBoost 3.1+ base_score string format before shap accesses it - model_for_shap = self._fix_xgboost_model_for_shap(self.model) + model_for_shap = self._fix_xgboost_model_for_shap(model_for_shap) self._shap_explainer = shap.TreeExplainer( model_for_shap, self.X_background ) @@ -2955,7 +2962,7 @@ def shap_explainer(self): ) self._shap_explainer = shap.LinearExplainer( - self.model, + model_for_shap, self.X_background if self.X_background is not None else self.X, ) elif self.shap == "deep": @@ -2968,7 +2975,7 @@ def shap_explainer(self): UserWarning, ) self._shap_explainer = shap.DeepExplainer( - self.model, + model_for_shap, self.X_background if self.X_background is not None else shap.sample(self.X, 5), @@ -2985,7 +2992,7 @@ def shap_explainer(self): UserWarning, ) self._shap_explainer = shap.DeepExplainer( - self.model.module_, + model_for_shap.module_, torch.tensor( self.X_background.values if self.X_background is not None diff --git a/tests/test_classifier_base.py b/tests/test_classifier_base.py index 9a4e14f..e6e90df 100644 --- a/tests/test_classifier_base.py +++ b/tests/test_classifier_base.py @@ -7,6 +7,7 @@ import plotly.graph_objects as go +from sklearn.calibration import CalibratedClassifierCV from sklearn.ensemble import RandomForestClassifier from explainerdashboard import ClassifierExplainer, ExplainerDashboard @@ -110,6 +111,17 @@ def test_string_labels_supported(classifier_data): assert {"pred_proba", "y"}.issubset(lift_df.columns) +def test_calibrated_classifiercv_uses_tree_shap(classifier_data): + X_train, y_train, X_test, y_test = classifier_data + base_estimator = RandomForestClassifier(n_estimators=25, random_state=0) + model = CalibratedClassifierCV(estimator=base_estimator, cv=2) + model.fit(X_train, y_train) + + explainer = ClassifierExplainer(model, X_test, y_test) + + assert explainer.shap == "tree" + + def test_row_from_input(precalculated_rf_classifier_explainer): input_row = precalculated_rf_classifier_explainer.get_row_from_input( precalculated_rf_classifier_explainer.X.iloc[[0]].values.tolist()