
Commit bf89386

add argument memory_tolerance to enable more flexible memory control

1 parent 9763d57

3 files changed: 12 additions & 3 deletions

README.md (3 additions & 1 deletion)
@@ -58,7 +58,7 @@ brew install libomp
 
 ## Usage
 
-The following screenshot shows a typical use case of FastTreeSHAP on [Census Income Data](https://archive.ics.uci.edu/ml/datasets/census+income). Note that the usage of FastTreeSHAP is exactly the same as the usage of [SHAP](https://github.com/slundberg/shap), except for three additional arguments in the class `TreeExplainer`: `algorithm`, `n_jobs`, and `shortcut`.
+The following screenshot shows a typical use case of FastTreeSHAP on [Census Income Data](https://archive.ics.uci.edu/ml/datasets/census+income). Note that the usage of FastTreeSHAP is exactly the same as the usage of [SHAP](https://github.com/slundberg/shap), except for four additional arguments in the class `TreeExplainer`: `algorithm`, `n_jobs`, `memory_tolerance`, and `shortcut`.
 
 `algorithm`: This argument specifies the TreeSHAP algorithm used to run FastTreeSHAP. It can take values `"v0"`, `"v1"`, `"v2"` or `"auto"`, and its default value is `"auto"`:
 * `"v0"`: Original TreeSHAP algorithm in [SHAP](https://github.com/slundberg/shap) package.

@@ -68,6 +68,8 @@ The following screenshot shows a typical use case of FastTreeSHAP on [Census Inc
 
 `n_jobs`: This argument specifies the number of parallel threads used to run FastTreeSHAP. It can take values `-1` or a positive integer. Its default value is `-1`, which means utilizing all available cores in parallel computing.
 
+`memory_tolerance`: This argument specifies the upper limit of memory allocation (in GB) to run FastTreeSHAP v2. It can take values `-1` or a positive number. Its default value is `-1`, which means allocating a maximum of 0.25 * total memory of the machine to run FastTreeSHAP v2.
+
 `shortcut`: This argument determines whether to use the TreeSHAP algorithm embedded in [XGBoost](https://github.com/dmlc/xgboost), [LightGBM](https://github.com/microsoft/LightGBM), and [CatBoost](https://github.com/catboost/catboost) packages directly when computing SHAP values for XGBoost, LightGBM, and CatBoost models and when computing SHAP interaction values for XGBoost models. Its default value is `False`, which means bypassing the "shortcut" and using the code in FastTreeSHAP package directly to compute SHAP values for XGBoost, LightGBM, and CatBoost models. Note that currently `shortcut` is automatically set to be True for CatBoost model, as we are working on CatBoost component in FastTreeSHAP package. More details of the usage of "shortcut" can be found in the notebooks [Census Income](notebooks/FastTreeSHAP_Census_Income.ipynb), [Superconductor](notebooks/FastTreeSHAP_Superconductor.ipynb), and [Crop Mapping](notebooks/FastTreeSHAP_Crop_Mapping.ipynb).
 
 ![FastTreeSHAP Adult Screenshot1](docs/images/fasttreeshap_adult_screenshot1.png)
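A hedged usage sketch of the four arguments documented in this README change. The helper name `explain`, the 8 GB cap, and the `model`/`X` placeholders are illustrative, not part of the package; a trained tree model and `pip install fasttreeshap` are assumed:

```python
# Sketch only: `explain` is a hypothetical wrapper, not a package API.
def explain(model, X, memory_cap_gb=8):
    import fasttreeshap  # imported lazily so the sketch stands alone

    explainer = fasttreeshap.TreeExplainer(
        model,
        algorithm="auto",                # let FastTreeSHAP pick the version
        n_jobs=-1,                       # use all available cores
        memory_tolerance=memory_cap_gb,  # cap FastTreeSHAP v2's memory (GB)
        shortcut=False,                  # use FastTreeSHAP's own code path
    )
    return explainer.shap_values(X)
```

Passing `memory_tolerance=-1` (the default) instead would fall back to the quarter-of-total-RAM cap described above.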

fasttreeshap/__init__.py (1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 import warnings
 import sys
 
-__version__ = '0.1.1'
+__version__ = '0.1.2'
 
 # check python version
 if (sys.version_info < (3, 0)):

fasttreeshap/explainers/_tree.py (8 additions & 1 deletion)
@@ -60,7 +60,7 @@ class Tree(Explainer):
     implementations either inside an external model package or in the local compiled C extension.
     """
 
-    def __init__(self, model, data = None, model_output="raw", feature_perturbation="interventional", algorithm="auto", n_jobs=-1, feature_names=None, approximate=False, shortcut=False, **deprecated_options):
+    def __init__(self, model, data = None, model_output="raw", feature_perturbation="tree_path_dependent", algorithm="auto", n_jobs=-1, memory_tolerance=-1, feature_names=None, approximate=False, shortcut=False, **deprecated_options):
        """ Build a new Tree explainer for the passed model.
 
        Parameters

@@ -99,6 +99,10 @@ def __init__(self, model, data = None, model_output="raw", feature_perturbation=
            all available cores in parallel computing (Setting OMP_NUM_THREADS is unnecessary since n_jobs will
            overwrite this parameter).
 
+        memory_tolerance : -1 (default), or a positive number
+            Upper limit of memory allocation (in GB) to run Fast TreeSHAP v2. The default value of memory_tolerance is -1,
+            which allocates a maximum of 0.25 * total memory of the machine to run Fast TreeSHAP v2.
+
        model_output : "raw", "probability", "log_loss", or model method name
            What output of the model should be explained. If "raw" then we explain the raw output of the
            trees, which varies by model. For regression models "raw" is the standard output, for binary

@@ -182,6 +186,7 @@ def __init__(self, model, data = None, model_output="raw", feature_perturbation=
            self.n_jobs = 1
        else:
            self.n_jobs = min(int(n_jobs), os.cpu_count())
+        self.memory_tolerance = memory_tolerance
        self.expected_value = None
        self.model = TreeEnsemble(model, self.data, self.data_missing, model_output)
        self.model_output = model_output

@@ -490,6 +495,8 @@ def _memory_check(self, X):
            memory_tolerance = 0.25 * psutil.virtual_memory().total
        except:
            memory_tolerance = 4294967296 # 4GB
+        if self.memory_tolerance > 0:
+            memory_tolerance = min(memory_tolerance, self.memory_tolerance * 1073741824)
        return memory_usage_1 <= memory_tolerance, memory_usage_2 <= memory_tolerance
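The two lines added to `_memory_check` can only tighten the default cap. The effect can be sketched as a standalone function (the name `effective_memory_cap` is hypothetical, mirroring the diff's logic rather than quoting a package API):

```python
GIB = 1073741824  # bytes per GB as used in _memory_check (2**30)

def effective_memory_cap(memory_tolerance, total_memory_bytes=None):
    """Memory cap (bytes) for FastTreeSHAP v2, mirroring _memory_check.

    memory_tolerance: -1 (default) or a positive number of GB.
    total_memory_bytes: machine RAM, or None if psutil's query failed.
    """
    if total_memory_bytes is not None:
        cap = 0.25 * total_memory_bytes  # default: a quarter of total RAM
    else:
        cap = 4 * GIB                    # fallback when psutil is unavailable
    if memory_tolerance > 0:
        # A user-supplied limit can only tighten the default cap.
        cap = min(cap, memory_tolerance * GIB)
    return cap
```

Because of the `min`, asking for more than a quarter of total RAM (or more than the 4 GB fallback) has no effect; `memory_tolerance` lowers the cap but never raises it.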
