Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
### Bug Fixes
- Fix FeatureInputComponent range calculation for boolean columns (avoid np.round on bools) and add a regression test.
- Ensure save_html includes custom tabs by providing a static-export fallback for tabs without a to_html implementation.
- Support string class labels in ClassifierExplainer by preserving label mappings and avoiding float casts.

### Improvements
- Replace print statements with standard logging and warnings; progress messages are now INFO-level and user-actionable guidance uses warnings. A one-time warning is emitted if logging is not configured, with instructions to call `enable_default_logging()`.
Expand Down
3 changes: 0 additions & 3 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
- Rules: link an issue when possible; include size S/M/L; mark blockers.

**Now**
- [M][Dashboard][#303] save_html does not include custom tabs; fix HTML export and add regression test.
- [S][Explainers+Dashboard][#302] replace print statements with logging and add configuration docs.
- [M/L][Explainers][#279] support CalibratedClassifierCV (unwrap estimator for SHAP; update logic and tests).
- [S][Components][#224] bool columns should not break rounding; handle bool safely.

**Next**
- [S/M][Components][#277] whatif input range/rounding customization.
Expand Down
20 changes: 16 additions & 4 deletions explainerdashboard/explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,9 @@ def __init__(
"y should be a pd.Series or np.ndarray not a pd.DataFrame!"
)

self.y = pd.Series(y.squeeze()).astype(precision)
self.y = pd.Series(y.squeeze())
if is_numeric_dtype(self.y):
self.y = self.y.astype(precision)
self.y_missing = False
else:
self.y = pd.Series(np.full(len(X), np.nan))
Expand Down Expand Up @@ -2683,8 +2685,18 @@ def __init__(
**dict(labels=labels, pos_label=pos_label),
}

class_values = None
if hasattr(self.model, "classes_"):
class_values = list(self.model.classes_)
elif not self.y_missing:
class_values = sorted(pd.Series(self.y).dropna().unique().tolist())

if not self.y_missing:
self.y = self.y.astype("int16")
if class_values is None:
self.y = self.y.astype("int16")
else:
class_to_index = {cls: idx for idx, cls in enumerate(class_values)}
self.y = self.y.map(class_to_index).astype("int16")
if (
self.categorical_cols
and model_output == "probability"
Expand All @@ -2699,8 +2711,8 @@ def __init__(
self.model_output = "logodds"
if labels is not None:
self.labels = labels
elif hasattr(self.model, "classes_"):
self.labels = [str(cls) for cls in self.model.classes_]
elif class_values is not None:
self.labels = [str(cls) for cls in class_values]
else:
self.labels = [str(i) for i in range(self.y.nunique())]
self.pos_label = pos_label
Expand Down
22 changes: 22 additions & 0 deletions tests/test_classifier_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import plotly.graph_objects as go

from sklearn.ensemble import RandomForestClassifier

from explainerdashboard import ClassifierExplainer, ExplainerDashboard
from explainerdashboard.explainer_methods import IndexNotFoundError

Expand Down Expand Up @@ -88,6 +90,26 @@ def test_cats_notencoded(precalculated_rf_classifier_explainer):
)


def test_string_labels_supported(classifier_data):
X_train, y_train, X_test, y_test = classifier_data
labels = ["No", "Yes"]
label_map = {
value: labels[idx] for idx, value in enumerate(sorted(y_train.unique()))
}
y_train_str = y_train.map(label_map)
y_test_str = y_test.map(label_map)

model = RandomForestClassifier(n_estimators=50, random_state=0)
model.fit(X_train, y_train_str)

explainer = ClassifierExplainer(model, X_test, y_test_str)
lift_df = explainer.get_liftcurve_df()

assert explainer.labels == list(model.classes_)
assert set(explainer.y.unique()).issubset(set(range(len(explainer.labels))))
assert {"pred_proba", "y"}.issubset(lift_df.columns)


def test_row_from_input(precalculated_rf_classifier_explainer):
input_row = precalculated_rf_classifier_explainer.get_row_from_input(
precalculated_rf_classifier_explainer.X.iloc[[0]].values.tolist()
Expand Down
Loading