@@ -357,34 +357,6 @@ def shap_values(self, X, y=None, tree_limit=None, approximate=False, check_addit
357357 if tree_limit is None :
358358 tree_limit = - 1 if self .model .tree_limit is None else self .model .tree_limit
359359
360- # choose the most appropriate TreeSHAP algorithm
361- if self .algorithm == "auto" :
362- # check if number of samples to be explained is sufficiently large
363- num_samples = X .shape [0 ]
364- num_samples_threshold = 2 ** int (self .model .max_depth + 1 ) / self .model .max_depth
365- num_samples_check = (num_samples >= num_samples_threshold )
366- # check if memory constraint is satisfied (check Section Notes in README.md for justifications of memory check conditions in function _memory_check)
367- memory_check_1 , memory_check_2 = self ._memory_check (X )
368- if num_samples_check and (memory_check_1 or memory_check_2 ):
369- if memory_check_1 :
370- algorithm = "v2_1"
371- else :
372- algorithm = "v2_2"
373- else :
374- algorithm = "v1"
375- else :
376- algorithm = self .algorithm
377- if algorithm == "v2" :
378- # check if memory constraint is satisfied (check Section Notes in README.md for justifications of memory check conditions in function _memory_check)
379- memory_check_1 , memory_check_2 = self ._memory_check (X )
380- if memory_check_1 :
381- algorithm = "v2_1"
382- elif memory_check_2 :
383- algorithm = "v2_2"
384- else :
385- warnings .warn ("There may exist memory issue for algorithm v2. Switched to algorithm v1." )
386- algorithm = "v1"
387-
388360 # shortcut using the C++ version of Tree SHAP in XGBoost, LightGBM, and CatBoost
389361 if self .feature_perturbation == "tree_path_dependent" and self .model .model_type != "internal" and self .data is None :
390362 model_output_vals = None
@@ -447,6 +419,34 @@ def shap_values(self, X, y=None, tree_limit=None, approximate=False, check_addit
447419
448420 return out
449421
422+ # choose the most appropriate TreeSHAP algorithm
423+ if self .algorithm == "auto" :
424+ # check if number of samples to be explained is sufficiently large
425+ num_samples = X .shape [0 ]
426+ num_samples_threshold = 2 ** int (self .model .max_depth + 1 ) / self .model .max_depth
427+ num_samples_check = (num_samples >= num_samples_threshold )
428+ # check if memory constraint is satisfied (check Section Notes in README.md for justifications of memory check conditions in function _memory_check)
429+ memory_check_1 , memory_check_2 = self ._memory_check (X )
430+ if num_samples_check and (memory_check_1 or memory_check_2 ):
431+ if memory_check_1 :
432+ algorithm = "v2_1"
433+ else :
434+ algorithm = "v2_2"
435+ else :
436+ algorithm = "v1"
437+ else :
438+ algorithm = self .algorithm
439+ if algorithm == "v2" :
440+ # check if memory constraint is satisfied (check Section Notes in README.md for justifications of memory check conditions in function _memory_check)
441+ memory_check_1 , memory_check_2 = self ._memory_check (X )
442+ if memory_check_1 :
443+ algorithm = "v2_1"
444+ elif memory_check_2 :
445+ algorithm = "v2_2"
446+ else :
447+ warnings .warn ("There may exist memory issue for algorithm v2. Switched to algorithm v1." )
448+ algorithm = "v1"
449+
450450 X , y , X_missing , flat_output , tree_limit , check_additivity = self ._validate_inputs (X , y ,
451451 tree_limit ,
452452 check_additivity )
0 commit comments