Skip to content

Commit 9763d57

Browse files
committed
fix issue #7
1 parent 58e8288 commit 9763d57

1 file changed

Lines changed: 28 additions & 28 deletions

File tree

fasttreeshap/explainers/_tree.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -357,34 +357,6 @@ def shap_values(self, X, y=None, tree_limit=None, approximate=False, check_addit
357357
if tree_limit is None:
358358
tree_limit = -1 if self.model.tree_limit is None else self.model.tree_limit
359359

360-
# choose the most appropriate TreeSHAP algorithm
361-
if self.algorithm == "auto":
362-
# check if number of samples to be explained is sufficiently large
363-
num_samples = X.shape[0]
364-
num_samples_threshold = 2**int(self.model.max_depth + 1) / self.model.max_depth
365-
num_samples_check = (num_samples >= num_samples_threshold)
366-
# check if memory constraint is satisfied (check Section Notes in README.md for justifications of memory check conditions in function _memory_check)
367-
memory_check_1, memory_check_2 = self._memory_check(X)
368-
if num_samples_check and (memory_check_1 or memory_check_2):
369-
if memory_check_1:
370-
algorithm = "v2_1"
371-
else:
372-
algorithm = "v2_2"
373-
else:
374-
algorithm = "v1"
375-
else:
376-
algorithm = self.algorithm
377-
if algorithm == "v2":
378-
# check if memory constraint is satisfied (check Section Notes in README.md for justifications of memory check conditions in function _memory_check)
379-
memory_check_1, memory_check_2 = self._memory_check(X)
380-
if memory_check_1:
381-
algorithm = "v2_1"
382-
elif memory_check_2:
383-
algorithm = "v2_2"
384-
else:
385-
warnings.warn("There may exist memory issue for algorithm v2. Switched to algorithm v1.")
386-
algorithm = "v1"
387-
388360
# shortcut using the C++ version of Tree SHAP in XGBoost, LightGBM, and CatBoost
389361
if self.feature_perturbation == "tree_path_dependent" and self.model.model_type != "internal" and self.data is None:
390362
model_output_vals = None
@@ -447,6 +419,34 @@ def shap_values(self, X, y=None, tree_limit=None, approximate=False, check_addit
447419

448420
return out
449421

422+
# choose the most appropriate TreeSHAP algorithm
423+
if self.algorithm == "auto":
424+
# check if number of samples to be explained is sufficiently large
425+
num_samples = X.shape[0]
426+
num_samples_threshold = 2**int(self.model.max_depth + 1) / self.model.max_depth
427+
num_samples_check = (num_samples >= num_samples_threshold)
428+
# check if memory constraint is satisfied (check Section Notes in README.md for justifications of memory check conditions in function _memory_check)
429+
memory_check_1, memory_check_2 = self._memory_check(X)
430+
if num_samples_check and (memory_check_1 or memory_check_2):
431+
if memory_check_1:
432+
algorithm = "v2_1"
433+
else:
434+
algorithm = "v2_2"
435+
else:
436+
algorithm = "v1"
437+
else:
438+
algorithm = self.algorithm
439+
if algorithm == "v2":
440+
# check if memory constraint is satisfied (check Section Notes in README.md for justifications of memory check conditions in function _memory_check)
441+
memory_check_1, memory_check_2 = self._memory_check(X)
442+
if memory_check_1:
443+
algorithm = "v2_1"
444+
elif memory_check_2:
445+
algorithm = "v2_2"
446+
else:
447+
warnings.warn("There may exist memory issue for algorithm v2. Switched to algorithm v1.")
448+
algorithm = "v1"
449+
450450
X, y, X_missing, flat_output, tree_limit, check_additivity = self._validate_inputs(X, y,
451451
tree_limit,
452452
check_additivity)

0 commit comments

Comments
 (0)