Skip to content

Commit b08edb7

Browse files
committed
fix beeswarm plot colorbar issue (#18)
1 parent b798b1c commit b08edb7

2 files changed

Lines changed: 14 additions & 14 deletions

File tree

fasttreeshap/plots/_beeswarm.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -361,14 +361,14 @@ def beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0),
361361
import matplotlib.cm as cm
362362
m = cm.ScalarMappable(cmap=color)
363363
m.set_array([0, 1])
364-
cb = pl.colorbar(m, ticks=[0, 1], aspect=1000)
364+
cb = pl.colorbar(m, ax=pl.gca(), ticks=[0, 1], aspect=80)
365365
cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
366366
cb.set_label(color_bar_label, size=12, labelpad=0)
367367
cb.ax.tick_params(labelsize=11, length=0)
368368
cb.set_alpha(1)
369369
cb.outline.set_visible(False)
370-
bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
371-
cb.ax.set_aspect((bbox.height - 0.9) * 20)
370+
# bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
371+
# cb.ax.set_aspect((bbox.height - 0.9) * 20)
372372
# cb.draw_all()
373373

374374
pl.gca().xaxis.set_ticks_position('bottom')
@@ -861,14 +861,14 @@ def summary_legacy(shap_values, features=None, feature_names=None, max_display=N
861861
import matplotlib.cm as cm
862862
m = cm.ScalarMappable(cmap=cmap if plot_type != "layered_violin" else pl.get_cmap(color))
863863
m.set_array([0, 1])
864-
cb = pl.colorbar(m, ticks=[0, 1], aspect=1000)
864+
cb = pl.colorbar(m, ax=pl.gca(), ticks=[0, 1], aspect=80)
865865
cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']])
866866
cb.set_label(color_bar_label, size=12, labelpad=0)
867867
cb.ax.tick_params(labelsize=11, length=0)
868868
cb.set_alpha(1)
869869
cb.outline.set_visible(False)
870-
bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
871-
cb.ax.set_aspect((bbox.height - 0.9) * 20)
870+
# bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
871+
# cb.ax.set_aspect((bbox.height - 0.9) * 20)
872872
# cb.draw_all()
873873

874874
pl.gca().xaxis.set_ticks_position('bottom')

fasttreeshap/plots/_scatter.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -343,19 +343,19 @@ def scatter(shap_values, color="#1E88E5", hist=True, axis_color="#333333", cmap=
343343
tick_positions = np.array([cname_map[n] for n in cnames])
344344
tick_positions *= 1 - 1 / len(cnames)
345345
tick_positions += 0.5 * (chigh - clow) / (chigh - clow + 1)
346-
cb = pl.colorbar(p, ticks=tick_positions, ax=ax)
346+
cb = pl.colorbar(p, ticks=tick_positions, ax=ax, aspect=80)
347347
cb.set_ticklabels(cnames)
348348
else:
349-
cb = pl.colorbar(p, ax=ax)
349+
cb = pl.colorbar(p, ax=ax, aspect=80)
350350

351351
cb.set_label(feature_names[interaction_index], size=13)
352352
cb.ax.tick_params(labelsize=11)
353353
if categorical_interaction:
354354
cb.ax.tick_params(length=0)
355355
cb.set_alpha(1)
356356
cb.outline.set_visible(False)
357-
bbox = cb.ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
358-
cb.ax.set_aspect((bbox.height - 0.7) * 20)
357+
# bbox = cb.ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
358+
# cb.ax.set_aspect((bbox.height - 0.7) * 20)
359359

360360
# handles any setting of xmax and xmin
361361
# note that we handle None,float, or "percentile(float)" formats
@@ -690,19 +690,19 @@ def dependence_legacy(ind, shap_values=None, features=None, feature_names=None,
690690
if len(tick_positions) == 2:
691691
tick_positions[0] -= 0.25
692692
tick_positions[1] += 0.25
693-
cb = pl.colorbar(p, ticks=tick_positions, ax=ax)
693+
cb = pl.colorbar(p, ticks=tick_positions, ax=ax, aspect=80)
694694
cb.set_ticklabels(cnames)
695695
else:
696-
cb = pl.colorbar(p, ax=ax)
696+
cb = pl.colorbar(p, ax=ax, aspect=80)
697697

698698
cb.set_label(feature_names[interaction_index], size=13)
699699
cb.ax.tick_params(labelsize=11)
700700
if categorical_interaction:
701701
cb.ax.tick_params(length=0)
702702
cb.set_alpha(1)
703703
cb.outline.set_visible(False)
704-
bbox = cb.ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
705-
cb.ax.set_aspect((bbox.height - 0.7) * 20)
704+
# bbox = cb.ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
705+
# cb.ax.set_aspect((bbox.height - 0.7) * 20)
706706

707707
# handles any setting of xmax and xmin
708708
# note that we handle None,float, or "percentile(float)" formats

0 commit comments

Comments
 (0)