Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions onnxmltools/convert/lightgbm/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _get_lightgbm_operator_name(model):
return lightgbm_operator_name_map[model_type]


def _parse_lightgbm_simple_model(scope, model, inputs, split=None):
def _parse_lightgbm_simple_model(scope, model, inputs, split=None, decision_leaf=False):
"""
This function handles all non-pipeline models.

Expand All @@ -111,11 +111,14 @@ def _parse_lightgbm_simple_model(scope, model, inputs, split=None):
:param inputs: A list of variables
:param split: split TreeEnsembleRegressor into multiple node to reduce
discrepancies
:param decision_leaf: if True, an additional output is added to return
the leaf indices (one per tree) for each input sample
:return: A list of output variables which will be passed to next stage
"""
operator_name = _get_lightgbm_operator_name(model)
this_operator = scope.declare_local_operator(operator_name, model)
this_operator.split = split
this_operator.decision_leaf = decision_leaf
this_operator.inputs = inputs

if operator_name == "LgbmClassifier":
Expand All @@ -133,13 +136,21 @@ def _parse_lightgbm_simple_model(scope, model, inputs, split=None):
# We assume that all scikit-learn operator can only produce a single float tensor.
variable = scope.declare_local_variable("variable", FloatTensorType())
this_operator.outputs.append(variable)
if decision_leaf:
leaf_indices_variable = scope.declare_local_variable(
"leaf_indices", Int64Type()
)
this_operator.outputs.append(leaf_indices_variable)
return this_operator.outputs


def _parse_sklearn_classifier(scope, model, inputs, zipmap=True):
probability_tensor = _parse_lightgbm_simple_model(scope, model, inputs)
def _parse_sklearn_classifier(scope, model, inputs, zipmap=True, decision_leaf=False):
probability_tensor = _parse_lightgbm_simple_model(
scope, model, inputs, decision_leaf=decision_leaf
)
this_operator = scope.declare_local_operator("LgbmZipMap")
this_operator.inputs = probability_tensor
# ZipMap only needs label and probabilities (first two outputs)
this_operator.inputs = probability_tensor[:2]
this_operator.zipmap = zipmap

classes = model.classes_
Expand Down Expand Up @@ -179,10 +190,14 @@ def _parse_sklearn_classifier(scope, model, inputs, zipmap=True):
)
this_operator.outputs.append(output_label)
this_operator.outputs.append(output_probability)
return this_operator.outputs
result = list(this_operator.outputs)
if decision_leaf:
# Pass through the leaf indices variable from the LgbmClassifier operator
result.append(probability_tensor[2])
return result


def _parse_lightgbm(scope, model, inputs, zipmap=True, split=None):
def _parse_lightgbm(scope, model, inputs, zipmap=True, split=None, decision_leaf=False):
"""
This is a delegate function. It doesn't nothing but
invoke the correct parsing function according to the input
Expand All @@ -194,13 +209,21 @@ def _parse_lightgbm(scope, model, inputs, zipmap=True, split=None):
:param zipmap: add operator ZipMap after operator TreeEnsembleClassifier
:param split: split TreeEnsembleRegressor into multiple node to reduce
discrepancies
:param decision_leaf: if True, an additional output is added to return
the leaf indices (one per tree) for each input sample
:return: The output variables produced by the input model
"""
if isinstance(model, LGBMClassifier):
return _parse_sklearn_classifier(scope, model, inputs, zipmap=zipmap)
return _parse_sklearn_classifier(
scope, model, inputs, zipmap=zipmap, decision_leaf=decision_leaf
)
if isinstance(model, WrappedBooster) and model.operator_name == "LgbmClassifier":
return _parse_sklearn_classifier(scope, model, inputs, zipmap=zipmap)
return _parse_lightgbm_simple_model(scope, model, inputs, split=split)
return _parse_sklearn_classifier(
scope, model, inputs, zipmap=zipmap, decision_leaf=decision_leaf
)
return _parse_lightgbm_simple_model(
scope, model, inputs, split=split, decision_leaf=decision_leaf
)


def parse_lightgbm(
Expand All @@ -211,6 +234,7 @@ def parse_lightgbm(
custom_shape_calculators=None,
zipmap=True,
split=None,
decision_leaf=False,
):
raw_model_container = LightGbmModelContainer(model)
topology = Topology(
Expand All @@ -230,7 +254,9 @@ def parse_lightgbm(
for variable in inputs:
raw_model_container.add_input(variable)

outputs = _parse_lightgbm(scope, model, inputs, zipmap=zipmap, split=split)
outputs = _parse_lightgbm(
scope, model, inputs, zipmap=zipmap, split=split, decision_leaf=decision_leaf
)

for variable in outputs:
raw_model_container.add_output(variable)
Expand Down
11 changes: 10 additions & 1 deletion onnxmltools/convert/lightgbm/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def convert(
without_onnx_ml=False,
zipmap=True,
split=None,
decision_leaf=False,
):
"""
This function produces an equivalent ONNX model of the given lightgbm model.
Expand Down Expand Up @@ -65,7 +66,14 @@ def convert(
Parameter *split* is the number of trees per node. It could be possible to
do the same with TreeEnsembleClassifier. However, the normalization of the
probabilities significantly reduces the discrepancies.
to use ONNX-ML operators as well.
:param decision_leaf: if True, an additional output named *leaf_indices* is
added to the ONNX model. It returns an int64 tensor of shape [N, n_trees]
where each element is the index of the leaf node reached in the corresponding
tree for each input sample. This is equivalent to LightGBM's
``predict(X, pred_leaf=True)``.
:param without_onnx_ml: set to True to generate a model composed of ONNX
operators only (requires hummingbird-ml); set to False (default) to allow
ONNX-ML operators as well.
:return: An ONNX model (type: ModelProto) which is equivalent to the input lightgbm model
"""
if initial_types is None:
Expand All @@ -92,6 +100,7 @@ def convert(
custom_shape_calculators,
zipmap=zipmap,
split=split,
decision_leaf=decision_leaf,
)
topology.compile()
onnx_ml_model = convert_topology(
Expand Down
85 changes: 82 additions & 3 deletions onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,79 @@ def _split_tree_ensemble_atts(attrs, split):
return results


def _add_decision_leaf_output(scope, operator, container, attrs, n_classes, gbm_text):
"""
Adds an additional ONNX ``TreeEnsembleRegressor`` node that produces leaf
indices (one per tree) for each input sample. The output is cast to int64
and written to the last output variable of *operator*.

:param scope: Scope object
:param operator: the operator being converted
:param container: the container to add nodes to
:param attrs: the attribute dict already populated with nodes_* entries
:param n_classes: number of classes / output targets in the main model
:param gbm_text: the raw JSON dump of the booster
"""
n_trees = len(gbm_text["tree_info"])

# Build a TreeEnsembleRegressor that outputs the sequential leaf index
# reached in each tree. One output column per tree (n_targets = n_trees).
leaf_attrs = {k: copy.deepcopy(v) for k, v in attrs.items() if k.startswith("nodes_")}
leaf_attrs["name"] = operator.full_name + "_leaf"
leaf_attrs["n_targets"] = n_trees
leaf_attrs["post_transform"] = "NONE"
leaf_attrs["aggregate_function"] = "SUM"

# For each leaf in the original attrs (class_treeids / class_nodeids),
# map tree i → output column i and assign a sequential leaf index
# within that tree as the weight.
leaf_count_per_tree: dict = {}
target_treeids = []
target_nodeids = []
target_ids = []
target_weights = []

for tree_id, node_id in zip(attrs["class_treeids"], attrs["class_nodeids"]):
count = leaf_count_per_tree.get(tree_id, 0)
target_treeids.append(tree_id)
target_nodeids.append(node_id)
target_ids.append(tree_id)
target_weights.append(float(count))
leaf_count_per_tree[tree_id] = count + 1

leaf_attrs["target_treeids"] = target_treeids
leaf_attrs["target_nodeids"] = target_nodeids
leaf_attrs["target_ids"] = target_ids
leaf_attrs["target_weights"] = target_weights

# onnx does not support mixed int/float lists
update = {}
for k, v in leaf_attrs.items():
if not isinstance(v, list):
continue
tps = set(map(type, v))
if len(tps) == 2 and tps == {int, float}:
update[k] = [float(x) for x in v]
leaf_attrs.update(update)

leaf_output_name = scope.get_unique_variable_name("leaf_output")
container.add_node(
"TreeEnsembleRegressor",
operator.input_full_names,
leaf_output_name,
op_domain="ai.onnx.ml",
**leaf_attrs,
)
# Cast float output to int64 for the leaf_indices output variable
apply_cast(
scope,
leaf_output_name,
operator.outputs[-1].full_name,
container,
to=TensorProto.INT64,
)


def convert_lightgbm(scope, operator, container):
"""
Converters for *lightgbm*.
Expand Down Expand Up @@ -624,6 +697,12 @@ def convert_lightgbm(scope, operator, container):
]
attrs[k] = sorted_list

# Build leaf index output when decision_leaf=True
if getattr(operator, "decision_leaf", False):
_add_decision_leaf_output(
scope, operator, container, attrs, n_classes, gbm_text
)

# Create ONNX object
if objective.startswith("binary") or objective.startswith("multiclass"):
# Prepare label information for both of TreeEnsembleClassifier
Expand Down Expand Up @@ -833,22 +912,22 @@ def convert_lightgbm(scope, operator, container):
apply_div(
scope,
[output_name, denominator_name],
operator.output_full_names,
operator.outputs[0].full_name,
container,
broadcast=1,
)
elif post_transform:
container.add_node(
post_transform,
output_name,
operator.output_full_names,
operator.outputs[0].full_name,
name=scope.get_unique_operator_name(post_transform),
)
else:
container.add_node(
"Identity",
output_name,
operator.output_full_names,
operator.outputs[0].full_name,
name=scope.get_unique_operator_name("Identity"),
)

Expand Down
13 changes: 11 additions & 2 deletions onnxmltools/convert/lightgbm/shape_calculators/Classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,13 @@ def calculate_lightgbm_classifier_output_shapes(operator):

Note that the second case is not allowed as long as ZipMap only produces dictionary.
"""
op = operator.raw_operator
decision_leaf = getattr(operator, "decision_leaf", False)
expected_outputs = 3 if decision_leaf else 2
check_input_and_output_numbers(
operator, input_count_range=1, output_count_range=[1, 2]
operator,
input_count_range=1,
output_count_range=[expected_outputs, expected_outputs],
)
check_input_and_output_types(
operator, good_input_types=[FloatTensorType, Int64TensorType]
Expand All @@ -36,7 +41,7 @@ def calculate_lightgbm_classifier_output_shapes(operator):

N = operator.inputs[0].type.shape[0]

class_labels = operator.raw_operator.classes_
class_labels = op.classes_
if all(isinstance(i, np.ndarray) for i in class_labels):
class_labels = np.concatenate(class_labels)
if all(isinstance(i, str) for i in class_labels):
Expand All @@ -48,6 +53,10 @@ def calculate_lightgbm_classifier_output_shapes(operator):
else:
operator.outputs[1].type = FloatTensorType(shape=[N, 1])

if decision_leaf:
n_trees = op.booster_.num_trees()
operator.outputs[2].type = Int64TensorType(shape=[N, n_trees])


def calculate_lgbm_zipmap(operator):
check_input_and_output_numbers(operator, output_count_range=2)
Expand Down
5 changes: 3 additions & 2 deletions onnxmltools/convert/lightgbm/shape_calculators/Ranker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from ...common._registration import register_shape_calculator
from ...common.shape_calculator import calculate_linear_regressor_output_shapes
from .Regressor import calculate_lightgbm_regressor_output_shapes

register_shape_calculator("LgbmRanker", calculate_lightgbm_regressor_output_shapes)

register_shape_calculator("LgbmRanker", calculate_linear_regressor_output_shapes)
34 changes: 33 additions & 1 deletion onnxmltools/convert/lightgbm/shape_calculators/Regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,37 @@

from ...common._registration import register_shape_calculator
from ...common.shape_calculator import calculate_linear_regressor_output_shapes
from ...common.data_types import FloatTensorType, Int64TensorType


def calculate_lightgbm_regressor_output_shapes(operator):
"""
Allowed input/output patterns are
1. [N, C] ---> [N, 1]
2. [N, C] ---> [N, 1], [N, n_trees] (when decision_leaf=True)

This operator produces a scalar prediction for every example in a batch. If the input batch size is N, the output
shape may be [N, 1]. If decision_leaf is True, a second output of shape [N, n_trees] is produced.
"""
decision_leaf = getattr(operator, "decision_leaf", False)
if decision_leaf:
from ...common.shape_calculator import check_input_and_output_numbers

check_input_and_output_numbers(
operator, input_count_range=1, output_count_range=2
)
N = operator.inputs[0].type.shape[0]
op = operator.raw_operator
if hasattr(op, "n_outputs_"):
nout = op.n_outputs_
else:
nout = 1
operator.outputs[0].type = FloatTensorType([N, nout])
n_trees = op.booster_.num_trees()
operator.outputs[1].type = Int64TensorType(shape=[N, n_trees])
else:
calculate_linear_regressor_output_shapes(operator)


register_shape_calculator("LgbmRegressor", calculate_lightgbm_regressor_output_shapes)

register_shape_calculator("LgbmRegressor", calculate_linear_regressor_output_shapes)
2 changes: 2 additions & 0 deletions onnxmltools/convert/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def convert_lightgbm(
without_onnx_ml=False,
zipmap=True,
split=None,
decision_leaf=False,
):
if targeted_onnx is not None:
warnings.warn(
Expand All @@ -237,6 +238,7 @@ def convert_lightgbm(
without_onnx_ml,
zipmap=zipmap,
split=split,
decision_leaf=decision_leaf,
)


Expand Down
Loading
Loading