Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions onnxmltools/convert/xgboost/operator_converters/XGBoost.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,6 @@ def convert(scope, operator, container):
)

attr_pairs = XGBRegressorConverter._get_default_tree_attribute_pairs()
if isinstance(base_score, list):
attr_pairs["base_values"] = base_score
else:
attr_pairs["base_values"] = [base_score]

if best_ntree_limit and best_ntree_limit < len(js_trees):
js_trees = js_trees[:best_ntree_limit]
Expand All @@ -410,10 +406,22 @@ def convert(scope, operator, container):

# add nodes
objectives_with_loglink = {"count:poisson", "reg:gamma", "reg:tweedie"}
# For binary:logistic, base_score is stored in probability space but
# leaf values are in log-odds space. We must convert base_score to
# log-odds and then apply a sigmoid to match XGBoost's prediction.
objectives_with_logistic = {"binary:logistic"}
if objective in objectives_with_loglink:
names = [scope.get_unique_variable_name("tree")]
del attr_pairs["base_values"]
elif objective in objectives_with_logistic:
bs_list = base_score if isinstance(base_score, list) else [base_score]
logit_base = [
float(np.log(np.float32(bs) / (1.0 - np.float32(bs)))) for bs in bs_list
]
attr_pairs["base_values"] = logit_base
names = [scope.get_unique_variable_name("tree_logits")]
else:
bs_list = base_score if isinstance(base_score, list) else [base_score]
attr_pairs["base_values"] = bs_list
names = operator.output_full_names
container.add_node(
"TreeEnsembleRegressor",
Expand All @@ -432,6 +440,8 @@ def convert(scope, operator, container):
new_name = scope.get_unique_variable_name("exp")
container.add_node("Exp", names, [new_name])
container.add_node("Mul", [new_name, cst], operator.output_full_names)
elif objective in objectives_with_logistic:
container.add_node("Sigmoid", names, operator.output_full_names)


class XGBClassifierConverter(XGBConverter):
Expand Down
59 changes: 59 additions & 0 deletions tests/xgboost/test_xgboost_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,65 @@


class TestXGBoostIssues(unittest.TestCase):
@unittest.skipIf(XGBRegressor is None, "xgboost is not available")
def test_xgbregressor_binary_logistic_with_subsample(self):
    """
    XGBRegressor with binary:logistic and subsample<1.0 should produce
    results matching XGBoost's predict() output. The base_score in the
    model config is stored in probability space, but leaf values are in
    log-odds space; the converter must apply logit(base_score) and then
    a Sigmoid node to replicate XGBoost's sigmoid(logit(base) + leaves).

    Both ``subsample=1.0`` and ``subsample=0.95`` are exercised. Each
    configuration runs in its own ``subTest`` so that a failure in one
    does not prevent the other from running and being reported.
    """
    import numpy as np
    import pandas as pd
    import onnxruntime
    from onnxmltools.convert import convert_xgboost
    from onnxmltools.convert.common.data_types import FloatTensorType

    df = pd.DataFrame(
        {
            "f1": [1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 1.0, 2.0],
            "label": [1, 0, 1, 0, 1, 1, 0, 1],
        }
    )
    X_df = df.drop(columns=["label"])
    y = df["label"]
    # The ONNX graph is declared with a float tensor input, so feed it
    # the same feature column as a float32 (n_samples, 1) array.
    X_np = df["f1"].values.reshape(-1, 1).astype(np.float32)

    for subsample in [1.0, 0.95]:
        # subTest keeps the two configurations independent: without it,
        # a failure at subsample=1.0 would abort the loop and the
        # subsample=0.95 case would never be checked.
        with self.subTest(subsample=subsample):
            params = {
                "max_depth": 1,
                "n_estimators": 3,
                "subsample": subsample,
                "objective": "binary:logistic",
                "random_state": 42,
            }
            model = XGBRegressor(**params)
            model.fit(X_df, y)

            initial_types = [("f1", FloatTensorType([None, 1]))]
            onnx_model = convert_xgboost(
                model,
                "XGBoostXGBRegressor",
                initial_types,
                target_opset=13,
            )

            sess = onnxruntime.InferenceSession(
                onnx_model.SerializeToString(),
                providers=["CPUExecutionProvider"],
            )
            onnx_output = sess.run(None, {"f1": X_np})[0]
            # Reshape XGBoost's prediction to a float32 column vector so
            # it is compared element-wise against the ONNX output.
            expected = model.predict(X_df).reshape(-1, 1).astype(np.float32)

            np.testing.assert_allclose(
                onnx_output,
                expected,
                rtol=1e-5,
                atol=1e-5,
                err_msg=f"ONNX output mismatch for subsample={subsample}",
            )

@unittest.skipIf(XGBRegressor is None, "xgboost is not available")
def test_issue_676(self):
import json
Expand Down
Loading