Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions onnxmltools/convert/common/tree_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,21 @@ def add_node(
if mode == "LEAF":
flattened_weights = weights.flatten()
factor = tree_weight
# If the values stored at leaves are counts of possible classes, we need convert them to probabilities by
# doing a normalization.
# If the values stored at leaves are counts of possible classes, we need
# to convert them to probabilities by normalizing.
if leaf_weights_are_counts:
s = sum(flattened_weights)
factor /= float(s) if s != 0.0 else 1.0
flattened_weights = [w * factor for w in flattened_weights]
if len(flattened_weights) == 2 and is_classifier:
flattened_weights = [flattened_weights[1]]

# Note that attribute names for making prediction are different for classifiers and regressors
# Previously, binary classifiers dropped class-0 and stored only the
# class-1 weight at class_id=0, relying on an old onnxruntime behaviour
# that inferred the complementary probability. onnxruntime >=1.22
# interprets class_id literally, so that shortcut now produces wrong
# (negated) probabilities. Always emit every class weight explicitly.

# Note that attribute names for making prediction are different for
# classifiers and regressors
if is_classifier:
for i, w in enumerate(flattened_weights):
attr_pairs["class_treeids"].append(tree_id)
Expand Down Expand Up @@ -160,4 +165,4 @@ def _process_process_tree_attributes(attrs):
"Unexpected type for one or several attributes:\n" + "\n".join(wrong_types)
)
if update:
attrs.update(update)
attrs.update(update)
4 changes: 2 additions & 2 deletions onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,9 @@ def convert_lightgbm(scope, operator, container):
if "num_class" in gbm_text:
n_classes = gbm_text["num_class"]
if n_classes == 1:
attrs["post_transform"] = "LOGISTIC"
attrs["post_transform"] = "LOGISTIC" # binary → needs sigmoid
else:
attrs["post_transform"] = "NONE"
attrs["post_transform"] = "NONE" # multiclass → already softmax elsewhere
objective = "binary"
else:
raise NotImplementedError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
sparkml_tree_dataset_to_sklearn,
add_tree_to_attribute_pairs,
get_default_tree_classifier_attribute_pairs,
add_tree_ensemble_classifier_node,
)
from .tree_helper import rewrite_ids_and_process

Expand Down Expand Up @@ -42,12 +43,14 @@ def convert_decision_tree_classifier(scope, operator, container):

new_attrs = rewrite_ids_and_process(attrs, logger)

container.add_node(
op_type,
add_tree_ensemble_classifier_node(
scope,
container,
operator.input_full_names,
[operator.outputs[0].full_name, operator.outputs[1].full_name],
op_domain="ai.onnx.ml",
**new_attrs,
operator.outputs[0].full_name,
operator.outputs[1].full_name,
new_attrs,
op.numClasses,
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from ...common.tree_ensemble import (
get_default_tree_classifier_attribute_pairs,
add_tree_to_attribute_pairs,
)
from ...common._registration import register_converter, register_shape_calculator
from .tree_ensemble_common import (
save_read_sparkml_model_data,
sparkml_tree_dataset_to_sklearn,
get_default_tree_classifier_attribute_pairs,
add_tree_to_attribute_pairs,
add_tree_ensemble_classifier_node,
)
from .decision_tree_classifier import calculate_decision_tree_classifier_output_shapes
from .tree_helper import rewrite_ids_and_process
Expand Down Expand Up @@ -46,12 +44,14 @@ def convert_random_forest_classifier(scope, operator, container):
if isinstance(v, list) and k not in {"classlabels_int64s"}:
main_attr_pairs[k].extend(v)

container.add_node(
op_type,
add_tree_ensemble_classifier_node(
scope,
container,
operator.input_full_names,
[operator.outputs[0].full_name, operator.outputs[1].full_name],
op_domain="ai.onnx.ml",
**main_attr_pairs,
operator.outputs[0].full_name,
operator.outputs[1].full_name,
main_attr_pairs,
op.numClasses,
)


Expand All @@ -63,4 +63,4 @@ def convert_random_forest_classifier(scope, operator, container):
register_shape_calculator(
"pyspark.ml.classification.RandomForestClassificationModel",
calculate_decision_tree_classifier_output_shapes,
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
import numpy
import re
from onnx import TensorProto
from pyspark.sql import SparkSession


Expand Down Expand Up @@ -183,18 +184,21 @@ def add_node(
if mode == "LEAF":
flattened_weights = weights.flatten()
factor = tree_weight
# If the values stored at leaves are counts of possible classes,
# we need convert them to probabilities by
# doing a normalization.
# If the values stored at leaves are counts of possible classes, we need
# to convert them to probabilities by normalizing.
if leaf_weights_are_counts:
s = sum(flattened_weights)
factor /= float(s) if s != 0.0 else 1.0
flattened_weights = [w * factor for w in flattened_weights]
if len(flattened_weights) == 2 and is_classifier:
flattened_weights = [flattened_weights[1]]

# Note that attribute names for making prediction
# are different for classifiers and regressors
# Previously, binary classifiers dropped class-0 and stored only the
# class-1 weight at class_id=0, relying on an old onnxruntime behaviour
# that inferred the complementary probability. onnxruntime >=1.22
# interprets class_id literally, so that shortcut now produces wrong
# (negated) probabilities. Always emit every class weight explicitly.

# Note that attribute names for making prediction are different for
# classifiers and regressors
if is_classifier:
for i, w in enumerate(flattened_weights):
attr_pairs["class_treeids"].append(tree_id)
Expand Down Expand Up @@ -250,3 +254,58 @@ def add_tree_to_attribute_pairs(
weight_id_bias,
leaf_weights_are_counts,
)


def add_tree_ensemble_classifier_node(
    scope, container, input_full_names, label_full_name, prob_full_name, attrs, num_classes
):
    """
    Emit a TreeEnsembleClassifier node built from ``attrs`` (as produced by
    add_tree_to_attribute_pairs / rewrite_ids_and_process).

    :param scope: object used to generate unique variable/operator names
    :param container: object receiving the nodes and initializers
    :param input_full_names: input names of the TreeEnsembleClassifier node
    :param label_full_name: name of the final label output
    :param prob_full_name: name of the probability output
    :param attrs: attributes forwarded to the TreeEnsembleClassifier node;
        ``attrs["classlabels_int64s"]`` is read when ``num_classes == 2``
    :param num_classes: number of classes of the model

    When ``num_classes`` is not 2, the operator's two outputs are wired
    directly to ``label_full_name`` and ``prob_full_name``.

    When ``num_classes`` is 2, the operator's own label output is sent to a
    fresh, otherwise unused variable and the final label is recomputed as
    ``Gather(classlabels, ArgMax(probabilities, axis=1))``. The rationale
    recorded by the author: onnxruntime's native label for this operator
    only checks whether the explicit class_id=1 score is positive and
    ignores the class_id=0 score, which gives a wrong label for fractional
    leaf weights (trees averaged in an ensemble, or impure leaves) even
    though the probabilities themselves are correct.
    """
    is_binary = num_classes == 2

    # For binary models the native label is parked in a scratch variable so
    # that label_full_name stays free for the Gather node added below.
    first_output = (
        scope.get_unique_variable_name("tree_ensemble_raw_label")
        if is_binary
        else label_full_name
    )

    container.add_node(
        "TreeEnsembleClassifier",
        input_full_names,
        [first_output, prob_full_name],
        op_domain="ai.onnx.ml",
        **attrs,
    )

    if not is_binary:
        return

    # Index of the most probable class for every row (axis 1 is reduced and
    # dropped because keepdims=0).
    argmax_name = scope.get_unique_variable_name("tree_ensemble_argmax")
    argmax_op_name = scope.get_unique_operator_name("ArgMax")
    container.add_node(
        "ArgMax",
        [prob_full_name],
        [argmax_name],
        axis=1,
        keepdims=0,
        name=argmax_op_name,
    )

    # Class labels stored as an INT64 initializer so that the ArgMax index
    # can be translated back into the actual label value.
    labels_name = scope.get_unique_variable_name("tree_ensemble_classlabels")
    class_labels = [int(label) for label in attrs["classlabels_int64s"]]
    container.add_initializer(
        labels_name,
        TensorProto.INT64,
        [len(class_labels)],
        class_labels,
    )

    gather_op_name = scope.get_unique_operator_name("Gather")
    container.add_node(
        "Gather",
        [labels_name, argmax_name],
        [label_full_name],
        name=gather_op_name,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,12 @@ def to_attrs(self, **kwargs):
field = "nodes_treeids"
for i in range(len(attrs[k])):
nid = attrs[k][i]
if nid == 0 and k in {"nodes_truenodeids", "nodes_falsenodeids"}:
continue
if k in {"nodes_truenodeids", "nodes_falsenodeids"}:
# Skip only genuine placeholder children on LEAF nodes.
# A branch node whose true/false child happens to be node 0
# (e.g. the root's left child) must still be remapped.
if attrs["nodes_modes"][i] == "LEAF":
continue
tid = attrs[field][i]
new_id = new_numbers[tid, nid]
attrs[k][i] = new_id
Expand Down
Loading