diff --git a/backends/mlx/_generated_inspector.py b/backends/mlx/_generated_inspector.py new file mode 100644 index 00000000000..03fcef4e706 --- /dev/null +++ b/backends/mlx/_generated_inspector.py @@ -0,0 +1,902 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# ============================================================================ +# AUTO-GENERATED FILE - DO NOT EDIT MANUALLY +# ============================================================================ +# +# This file was generated from schema.fbs by the MLX delegate code generator. +# +# Source: backends/mlx/serialization/schema.fbs +# Generator: backends/mlx/serialization/generate.py +# +# To regenerate, run from the executorch root: +# python backends/mlx/serialization/generate.py +# +# ============================================================================ + +""" +Auto-generated inspector field mappings for MLX delegate. + +This module provides field metadata for each op node type, enabling +the pte_inspector to parse FlatBuffer op nodes without manually +maintaining field mappings. 
+""" + +from __future__ import annotations + +from typing import Dict, List, Tuple + + +# Field kinds and their extractors +# Each field is a tuple of (display_name, accessor_name, kind) +# where kind is one of: 'tid', 'vid', 'int_or_vid', 'float_or_vid', +# 'int_list', 'int_or_vid_list', 'tid_list', 'string_list', 'scalar', 'string' + +FieldSpec = Tuple[str, str, str] # (display_name, accessor_name, kind) + + +# Mapping from op node name to list of field specs +OP_NODE_FIELDS: Dict[str, List[FieldSpec]] = { + "NoopNode": [ + ], + "IdCopyNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "AddmmNode": [ + ("mat1", "Mat1", "tid"), + ("mat2", "Mat2", "tid"), + ("out", "Out", "tid"), + ("bias", "Bias", "tid"), + ("alpha", "Alpha", "scalar"), + ("beta", "Beta", "scalar"), + ], + "ItemIntNode": [ + ("x", "X", "tid"), + ("out", "Out", "vid"), + ], + "ExpandDimsNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axis", "Axis", "scalar"), + ], + "TileNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("reps", "Reps", "int_or_vid_list"), + ], + "TakeAlongAxisNode": [ + ("x", "X", "tid"), + ("indices", "Indices", "tid"), + ("out", "Out", "tid"), + ("axis", "Axis", "scalar"), + ], + "TakeNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("index", "Index", "int_or_vid_or_tid"), + ("axis", "Axis", "scalar"), + ], + "RMSNormNode": [ + ("x", "X", "tid"), + ("weight", "Weight", "tid"), + ("out", "Out", "tid"), + ("eps", "Eps", "scalar"), + ], + "LayerNormNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("weight", "Weight", "tid"), + ("bias", "Bias", "tid"), + ("eps", "Eps", "scalar"), + ], + "RopeNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("dims", "Dims", "scalar"), + ("offset", "Offset", "vid_or_tid"), + ("freqs", "Freqs", "tid"), + ("traditional", "Traditional", "scalar"), + ("base", "Base", "scalar"), + ("scale", "Scale", "scalar"), + ], + "SdpaNode": [ + ("q", "Q", "tid"), + ("k", "K", "tid"), + ("v", "V", "tid"), + ("out", 
"Out", "tid"), + ("scale", "Scale", "scalar"), + ("mask", "Mask", "tid"), + ("causal", "Causal", "scalar"), + ], + "AddNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "AddIntNode": [ + ("a", "A", "int_or_vid"), + ("b", "B", "int_or_vid"), + ("out", "Out", "vid"), + ], + "SubtractIntNode": [ + ("a", "A", "int_or_vid"), + ("b", "B", "int_or_vid"), + ("out", "Out", "vid"), + ], + "MultiplyIntNode": [ + ("a", "A", "int_or_vid"), + ("b", "B", "int_or_vid"), + ("out", "Out", "vid"), + ], + "FloorDivideIntNode": [ + ("a", "A", "int_or_vid"), + ("b", "B", "int_or_vid"), + ("out", "Out", "vid"), + ], + "ModIntNode": [ + ("a", "A", "int_or_vid"), + ("b", "B", "int_or_vid"), + ("out", "Out", "vid"), + ], + "SymSizeNode": [ + ("a", "A", "tid"), + ("dim", "Dim", "scalar"), + ("out", "Out", "vid"), + ], + "MultiplyNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "DivideNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "SubtractNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "Conv1DNode": [ + ("x", "X", "tid"), + ("w", "W", "tid"), + ("out", "Out", "tid"), + ("stride", "Stride", "scalar"), + ("padding", "Padding", "scalar"), + ("dilation", "Dilation", "scalar"), + ("groups", "Groups", "scalar"), + ], + "Conv2DNode": [ + ("x", "X", "tid"), + ("w", "W", "tid"), + ("out", "Out", "tid"), + ("stride_h", "StrideH", "scalar"), + ("stride_w", "StrideW", "scalar"), + ("padding_h", "PaddingH", "scalar"), + ("padding_w", "PaddingW", "scalar"), + ("dilation_h", "DilationH", "scalar"), + ("dilation_w", "DilationW", "scalar"), + ("groups", "Groups", "scalar"), + ], + "Conv3DNode": [ + ("x", "X", "tid"), + ("w", "W", "tid"), + ("out", "Out", "tid"), + ("stride_d", "StrideD", "scalar"), + ("stride_h", "StrideH", "scalar"), + ("stride_w", "StrideW", "scalar"), + ("padding_d", "PaddingD", "scalar"), + ("padding_h", "PaddingH", "scalar"), + ("padding_w", 
"PaddingW", "scalar"), + ("dilation_d", "DilationD", "scalar"), + ("dilation_h", "DilationH", "scalar"), + ("dilation_w", "DilationW", "scalar"), + ("groups", "Groups", "scalar"), + ], + "ConvTranspose1DNode": [ + ("x", "X", "tid"), + ("w", "W", "tid"), + ("out", "Out", "tid"), + ("stride", "Stride", "scalar"), + ("padding", "Padding", "scalar"), + ("dilation", "Dilation", "scalar"), + ("output_padding", "OutputPadding", "scalar"), + ("groups", "Groups", "scalar"), + ], + "ConvTranspose2DNode": [ + ("x", "X", "tid"), + ("w", "W", "tid"), + ("out", "Out", "tid"), + ("stride_h", "StrideH", "scalar"), + ("stride_w", "StrideW", "scalar"), + ("padding_h", "PaddingH", "scalar"), + ("padding_w", "PaddingW", "scalar"), + ("dilation_h", "DilationH", "scalar"), + ("dilation_w", "DilationW", "scalar"), + ("output_padding_h", "OutputPaddingH", "scalar"), + ("output_padding_w", "OutputPaddingW", "scalar"), + ("groups", "Groups", "scalar"), + ], + "ConvTranspose3DNode": [ + ("x", "X", "tid"), + ("w", "W", "tid"), + ("out", "Out", "tid"), + ("stride_d", "StrideD", "scalar"), + ("stride_h", "StrideH", "scalar"), + ("stride_w", "StrideW", "scalar"), + ("padding_d", "PaddingD", "scalar"), + ("padding_h", "PaddingH", "scalar"), + ("padding_w", "PaddingW", "scalar"), + ("dilation_d", "DilationD", "scalar"), + ("dilation_h", "DilationH", "scalar"), + ("dilation_w", "DilationW", "scalar"), + ("output_padding_d", "OutputPaddingD", "scalar"), + ("output_padding_h", "OutputPaddingH", "scalar"), + ("output_padding_w", "OutputPaddingW", "scalar"), + ("groups", "Groups", "scalar"), + ], + "GeluNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("approximate", "Approximate", "string"), + ], + "ARangeNode": [ + ("out", "Out", "tid"), + ("start", "Start", "int_or_vid"), + ("stop", "Stop", "int_or_vid"), + ("step", "Step", "int_or_vid"), + ("scalar_type", "ScalarType", "scalar"), + ], + "SiluNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "SigmoidNode": [ + ("x", "X", "tid"), 
+ ("out", "Out", "tid"), + ], + "TanhNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "SqueezeNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("dims", "Dims", "int_list"), + ], + "SplitNode": [ + ("x", "X", "tid"), + ("outs", "Outs", "tid_list"), + ("sizes", "Sizes", "int_or_vid_list"), + ("axis", "Axis", "scalar"), + ], + "RsqrtNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "MaximumNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "MinimumNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "LogNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "SoftmaxNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axis", "Axis", "scalar"), + ("precise", "Precise", "scalar"), + ], + "BroadcastToNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("shape", "Shape", "int_or_vid_list"), + ], + "PadNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("pad_width", "PadWidth", "int_or_vid_list"), + ("mode", "Mode", "string"), + ("constant_value", "ConstantValue", "scalar"), + ], + "WhereNode": [ + ("condition", "Condition", "tid"), + ("x", "X", "tid"), + ("y", "Y", "tid"), + ("out", "Out", "tid"), + ], + "ReshapeNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("shape", "Shape", "int_or_vid_list"), + ], + "TransposeNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("perm", "Perm", "int_list"), + ], + "AsStridedNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("shape", "Shape", "int_or_vid_list"), + ("strides", "Strides", "int_or_vid_list"), + ("offset", "Offset", "scalar"), + ], + "ContiguousNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "GatherNode": [ + ("x", "X", "tid"), + ("indices", "Indices", "tid_list"), + ("out", "Out", "tid"), + ("axes", "Axes", "int_list"), + ("slice_sizes", "SliceSizes", "int_list"), + ], + "SliceNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axis", "Axis", "int_or_vid"), + 
("start", "Start", "int_or_vid"), + ("stop", "Stop", "int_or_vid"), + ("step", "Step", "scalar"), + ], + "AsTypeNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("scalar_type", "ScalarType", "scalar"), + ], + "QuantizedMatmulNode": [ + ("x", "X", "tid"), + ("w", "W", "tid"), + ("scales", "Scales", "tid"), + ("out", "Out", "tid"), + ("biases", "Biases", "tid"), + ("group_size", "GroupSize", "scalar"), + ("bits", "Bits", "scalar"), + ("mode", "Mode", "string"), + ("transpose", "Transpose", "scalar"), + ], + "ScatterAddNode": [ + ("x", "X", "tid"), + ("indices", "Indices", "tid"), + ("updates", "Updates", "tid"), + ("out", "Out", "tid"), + ("axis", "Axis", "scalar"), + ], + "ConcatenateNode": [ + ("tensors", "Tensors", "tid_list"), + ("out", "Out", "tid"), + ("axis", "Axis", "scalar"), + ], + "FullNode": [ + ("out", "Out", "tid"), + ("shape", "Shape", "int_or_vid_list"), + ("v", "V", "float_or_vid"), + ("scalar_type", "ScalarType", "scalar"), + ], + "FullLikeNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("v", "V", "float_or_vid"), + ("scalar_type", "ScalarType", "scalar"), + ], + "ArgmaxNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axis", "Axis", "scalar"), + ("keepdims", "Keepdims", "scalar"), + ], + "SliceUpdateNode": [ + ("dst", "Dst", "tid"), + ("update", "Update", "tid"), + ("out", "Out", "tid"), + ("axis", "Axis", "int_or_vid"), + ("start", "Start", "int_or_vid"), + ("stop", "Stop", "int_or_vid"), + ("step", "Step", "scalar"), + ], + "IndexCopyNode": [ + ("dst", "Dst", "tid"), + ("update", "Update", "tid"), + ("indices", "Indices", "tid"), + ("out", "Out", "tid"), + ("axis", "Axis", "scalar"), + ], + "DequantizeNode": [ + ("w", "W", "tid"), + ("scales", "Scales", "tid"), + ("out", "Out", "tid"), + ("biases", "Biases", "tid"), + ("group_size", "GroupSize", "scalar"), + ("bits", "Bits", "scalar"), + ("mode", "Mode", "string"), + ("global_scale", "GlobalScale", "tid"), + ("dtype", "Dtype", "scalar"), + ], + "LessNode": [ + ("a", "A", 
"tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "LessEqualNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "GreaterNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "GreaterEqualNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "EqualNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "NotEqualNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "LogicalNotNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "LogicalAndNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "LogicalOrNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "TriNode": [ + ("out", "Out", "tid"), + ("n", "N", "int_or_vid"), + ("m", "M", "int_or_vid"), + ("k", "K", "scalar"), + ("scalar_type", "ScalarType", "scalar"), + ], + "TrilNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("k", "K", "scalar"), + ], + "TriuNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("k", "K", "scalar"), + ], + "ClipNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("a_min", "AMin", "tid"), + ("a_max", "AMax", "tid"), + ], + "CumsumNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axis", "Axis", "scalar"), + ("reverse", "Reverse", "scalar"), + ("inclusive", "Inclusive", "scalar"), + ], + "StackNode": [ + ("tensors", "Tensors", "tid_list"), + ("out", "Out", "tid"), + ("axis", "Axis", "scalar"), + ], + "SignNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "AnyNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axes", "Axes", "int_list"), + ("keepdims", "Keepdims", "scalar"), + ], + "AllNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axes", "Axes", "int_list"), + ("keepdims", "Keepdims", "scalar"), + ], + "RepeatNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("repeats", "Repeats", "int_or_vid"), + ("axis", 
"Axis", "scalar"), + ], + "SortNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axis", "Axis", "scalar"), + ], + "ArgsortNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axis", "Axis", "scalar"), + ], + "PartitionNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("kth", "Kth", "int_or_vid"), + ("axis", "Axis", "scalar"), + ], + "ArgPartitionNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("kth", "Kth", "int_or_vid"), + ("axis", "Axis", "scalar"), + ], + "FloorNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "CeilNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "SquareNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "ExpNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "SinNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "CosNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "TanNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "ArcsinNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "ArccosNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "ArctanNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "SinhNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "CoshNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "ArcsinhNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "ArccoshNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "ArctanhNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "Log2Node": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "Log10Node": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "Log1pNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "ErfNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "Expm1Node": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "RoundNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("decimals", "Decimals", "scalar"), + ], + "ReciprocalNode": [ + ("x", "X", "tid"), + 
("out", "Out", "tid"), + ], + "SqrtNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "AbsNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "NegNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ], + "Atan2Node": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "LogAddExpNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "FloorDivideNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "RemainderNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "PowerNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ], + "LogSumExpNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axes", "Axes", "int_list"), + ("keepdims", "Keepdims", "scalar"), + ], + "SumNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axes", "Axes", "int_list"), + ("keepdims", "Keepdims", "scalar"), + ], + "MeanNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axes", "Axes", "int_list"), + ("keepdims", "Keepdims", "scalar"), + ], + "VarNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axes", "Axes", "int_list"), + ("keepdims", "Keepdims", "scalar"), + ("ddof", "Ddof", "scalar"), + ], + "StdNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axes", "Axes", "int_list"), + ("keepdims", "Keepdims", "scalar"), + ("ddof", "Ddof", "scalar"), + ], + "ProdNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axes", "Axes", "int_list"), + ("keepdims", "Keepdims", "scalar"), + ], + "MaxNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axes", "Axes", "int_list"), + ("keepdims", "Keepdims", "scalar"), + ], + "MinNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axes", "Axes", "int_list"), + ("keepdims", "Keepdims", "scalar"), + ], + "ArgminNode": [ + ("x", "X", "tid"), + ("out", "Out", "tid"), + ("axis", "Axis", "scalar"), + ("keepdims", "Keepdims", "scalar"), + ], + "MedianNode": [ + 
("x", "X", "tid"), + ("out", "Out", "tid"), + ("axes", "Axes", "int_list"), + ("keepdims", "Keepdims", "scalar"), + ], + "GatherMmNode": [ + ("a", "A", "tid"), + ("b", "B", "tid"), + ("out", "Out", "tid"), + ("lhs_indices", "LhsIndices", "tid"), + ("rhs_indices", "RhsIndices", "tid"), + ("sorted_indices", "SortedIndices", "scalar"), + ], + "GatherQmmNode": [ + ("x", "X", "tid"), + ("w", "W", "tid"), + ("scales", "Scales", "tid"), + ("out", "Out", "tid"), + ("mode", "Mode", "string"), + ("biases", "Biases", "tid"), + ("lhs_indices", "LhsIndices", "tid"), + ("rhs_indices", "RhsIndices", "tid"), + ("transpose", "Transpose", "scalar"), + ("group_size", "GroupSize", "scalar"), + ("bits", "Bits", "scalar"), + ("sorted_indices", "SortedIndices", "scalar"), + ], + "ScanNode": [ + ("originals", "Originals", "tid_list"), + ("sliced", "Sliced", "tid_list"), + ("outputs", "Outputs", "tid_list"), + ("carry", "Carry", "tid_list"), + ("body_chain_idx", "BodyChainIdx", "scalar"), + ("scan_axis", "ScanAxis", "scalar"), + ], +} + + +# List of all op node names (for dynamic imports) +OP_NODE_NAMES: List[str] = [ + "NoopNode", + "IdCopyNode", + "AddmmNode", + "ItemIntNode", + "ExpandDimsNode", + "TileNode", + "TakeAlongAxisNode", + "TakeNode", + "RMSNormNode", + "LayerNormNode", + "RopeNode", + "SdpaNode", + "AddNode", + "AddIntNode", + "SubtractIntNode", + "MultiplyIntNode", + "FloorDivideIntNode", + "ModIntNode", + "SymSizeNode", + "MultiplyNode", + "DivideNode", + "SubtractNode", + "Conv1DNode", + "Conv2DNode", + "Conv3DNode", + "ConvTranspose1DNode", + "ConvTranspose2DNode", + "ConvTranspose3DNode", + "GeluNode", + "ARangeNode", + "SiluNode", + "SigmoidNode", + "TanhNode", + "SqueezeNode", + "SplitNode", + "RsqrtNode", + "MaximumNode", + "MinimumNode", + "LogNode", + "SoftmaxNode", + "BroadcastToNode", + "PadNode", + "WhereNode", + "ReshapeNode", + "TransposeNode", + "AsStridedNode", + "ContiguousNode", + "GatherNode", + "SliceNode", + "AsTypeNode", + "QuantizedMatmulNode", + 
"ScatterAddNode", + "ConcatenateNode", + "FullNode", + "FullLikeNode", + "ArgmaxNode", + "SliceUpdateNode", + "IndexCopyNode", + "DequantizeNode", + "LessNode", + "LessEqualNode", + "GreaterNode", + "GreaterEqualNode", + "EqualNode", + "NotEqualNode", + "LogicalNotNode", + "LogicalAndNode", + "LogicalOrNode", + "TriNode", + "TrilNode", + "TriuNode", + "ClipNode", + "CumsumNode", + "StackNode", + "SignNode", + "AnyNode", + "AllNode", + "RepeatNode", + "SortNode", + "ArgsortNode", + "PartitionNode", + "ArgPartitionNode", + "FloorNode", + "CeilNode", + "SquareNode", + "ExpNode", + "SinNode", + "CosNode", + "TanNode", + "ArcsinNode", + "ArccosNode", + "ArctanNode", + "SinhNode", + "CoshNode", + "ArcsinhNode", + "ArccoshNode", + "ArctanhNode", + "Log2Node", + "Log10Node", + "Log1pNode", + "ErfNode", + "Expm1Node", + "RoundNode", + "ReciprocalNode", + "SqrtNode", + "AbsNode", + "NegNode", + "Atan2Node", + "LogAddExpNode", + "FloorDivideNode", + "RemainderNode", + "PowerNode", + "LogSumExpNode", + "SumNode", + "MeanNode", + "VarNode", + "StdNode", + "ProdNode", + "MaxNode", + "MinNode", + "ArgminNode", + "MedianNode", + "GatherMmNode", + "GatherQmmNode", + "ScanNode", +] diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index 9fa08ab722d..da59b0aa661 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -1395,6 +1395,12 @@ exec_logical_or(const LogicalOrNode& n, ExecutionState& st, StreamOrDevice s) { n.out, logical_or(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); } +inline void +exec_bitwise_or(const BitwiseOrNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, bitwise_or(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + inline void exec_tri(const TriNode& n, ExecutionState& st, StreamOrDevice s) { int rows = resolve_int(n.n, st); int cols = resolve_int(n.m, st); @@ -2034,6 +2040,9 @@ class Interpreter { case OpCode::LOGICAL_OR: 
ops::exec_logical_or(std::get(instr.node), st, s); break; + case OpCode::BITWISE_OR: + ops::exec_bitwise_or(std::get(instr.node), st, s); + break; case OpCode::TRI: ops::exec_tri(std::get(instr.node), st, s); break; diff --git a/backends/mlx/runtime/MLXLoader.cpp b/backends/mlx/runtime/MLXLoader.cpp new file mode 100644 index 00000000000..93494ea5692 --- /dev/null +++ b/backends/mlx/runtime/MLXLoader.cpp @@ -0,0 +1,2295 @@ +// +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// +// ============================================================================ +// AUTO-GENERATED FILE - DO NOT EDIT MANUALLY +// ============================================================================ +// +// This file was generated from schema.fbs by the MLX delegate code generator. +// +// Source: backends/mlx/serialization/schema.fbs +// Generator: backends/mlx/serialization/generate.py +// +// To regenerate, run from the executorch root: +// python backends/mlx/serialization/generate.py +// +// ============================================================================ +// -*- c++ -*- + +#include "MLXLoader.h" + +#include +#include + +namespace executorch { +namespace backends { +namespace mlx { +namespace loader { + +namespace { + +// Header structure for MLX payload +constexpr size_t kHeaderSize = 24; +constexpr uint32_t kMagic = 0x30584C4D; // "MLX0" in little-endian + +struct MLXHeader { + uint32_t padding; + uint32_t magic; + uint64_t data_offset; + uint64_t data_size; +}; +static_assert(sizeof(MLXHeader) == kHeaderSize, "MLXHeader size mismatch"); + +bool parse_header(const void* data, size_t size, MLXHeader& header) { + if (size < kHeaderSize) { + return false; + } + std::memcpy(&header, data, sizeof(MLXHeader)); + if (header.magic != kMagic) { + return false; + } + // Validate data_offset: must be 
strictly greater than kHeaderSize (so the + // FlatBuffer region is non-empty) and must not exceed the total buffer size. + if (header.data_offset <= kHeaderSize || header.data_offset > size) { + return false; + } + return true; +} + +// Helper to convert FlatBuffer vectors to std::vector. +// Caps size to prevent unbounded allocations from malformed payloads. +template +std::vector to_vector(const flatbuffers::Vector* fb_vec) { + if (!fb_vec) { + return {}; + } + constexpr size_t kMaxVectorSize = 1'000'000; + if (fb_vec->size() > kMaxVectorSize) { + throw std::runtime_error( + "FlatBuffer vector size " + std::to_string(fb_vec->size()) + + " exceeds maximum of " + std::to_string(kMaxVectorSize)); + } + return std::vector(fb_vec->begin(), fb_vec->end()); +} + +} // namespace + +// ============================================================================= +// load_instruction - AUTO-GENERATED switch statement +// ============================================================================= + +Instruction load_instruction(const mlx_delegate::Instruction* fb_instr) { + Instruction instr; + + if (!fb_instr || !fb_instr->op()) { + instr.op = OpCode::NOOP; + instr.node = NoopNode{}; + return instr; + } + + auto op_type = fb_instr->op_type(); + + switch (op_type) { + case mlx_delegate::OpNode_NoopNode: { + instr.op = OpCode::NOOP; + instr.node = NoopNode{}; + break; + } + + case mlx_delegate::OpNode_IdCopyNode: { + auto fb = fb_instr->op_as_IdCopyNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + IdCopyNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::ID_COPY; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_AddmmNode: { + auto fb = fb_instr->op_as_AddmmNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + AddmmNode node; + node.mat1 = convert_tid(fb->mat1()); + node.mat2 = 
convert_tid(fb->mat2()); + node.out = convert_tid(fb->out()); + if (fb->bias()) { + node.bias = convert_tid(fb->bias()); + } + node.alpha = fb->alpha(); + node.beta = fb->beta(); + instr.op = OpCode::ADDMM; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ItemIntNode: { + auto fb = fb_instr->op_as_ItemIntNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ItemIntNode node; + node.x = convert_tid(fb->x()); + node.out = convert_vid(fb->out()); + instr.op = OpCode::ITEM_INT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ExpandDimsNode: { + auto fb = fb_instr->op_as_ExpandDimsNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ExpandDimsNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axis = fb->axis(); + instr.op = OpCode::EXPAND_DIMS; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_TileNode: { + auto fb = fb_instr->op_as_TileNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + TileNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + if (fb->reps()) { + for (size_t i = 0; i < fb->reps()->size(); ++i) { + node.reps.push_back(convert_int_or_vid(fb->reps()->Get(static_cast(i)))); + } + } + instr.op = OpCode::TILE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_TakeAlongAxisNode: { + auto fb = fb_instr->op_as_TakeAlongAxisNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + TakeAlongAxisNode node; + node.x = convert_tid(fb->x()); + node.indices = convert_tid(fb->indices()); + node.out = convert_tid(fb->out()); + node.axis = fb->axis(); + instr.op = OpCode::TAKE_ALONG_AXIS; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_TakeNode: { 
+ auto fb = fb_instr->op_as_TakeNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + TakeNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.index = convert_int_or_vid_or_tid(fb->index()); + node.axis = fb->axis(); + instr.op = OpCode::TAKE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_RMSNormNode: { + auto fb = fb_instr->op_as_RMSNormNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + RMSNormNode node; + node.x = convert_tid(fb->x()); + if (fb->weight()) { + node.weight = convert_tid(fb->weight()); + } + node.out = convert_tid(fb->out()); + node.eps = fb->eps(); + instr.op = OpCode::RMS_NORM; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_LayerNormNode: { + auto fb = fb_instr->op_as_LayerNormNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + LayerNormNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + if (fb->weight()) { + node.weight = convert_tid(fb->weight()); + } + if (fb->bias()) { + node.bias = convert_tid(fb->bias()); + } + node.eps = fb->eps(); + instr.op = OpCode::LAYER_NORM; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_RopeNode: { + auto fb = fb_instr->op_as_RopeNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + RopeNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.dims = fb->dims(); + node.offset = convert_vid_or_tid(fb->offset()); + if (fb->freqs()) { + node.freqs = convert_tid(fb->freqs()); + } + node.traditional = fb->traditional(); + node.base = fb->base(); + node.scale = fb->scale(); + instr.op = OpCode::ROPE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SdpaNode: { + auto fb = 
fb_instr->op_as_SdpaNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SdpaNode node; + node.q = convert_tid(fb->q()); + node.k = convert_tid(fb->k()); + node.v = convert_tid(fb->v()); + node.out = convert_tid(fb->out()); + node.scale = fb->scale(); + if (fb->mask()) { + node.mask = convert_tid(fb->mask()); + } + node.causal = fb->causal(); + instr.op = OpCode::SDPA; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_AddNode: { + auto fb = fb_instr->op_as_AddNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + AddNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::ADD; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_AddIntNode: { + auto fb = fb_instr->op_as_AddIntNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + AddIntNode node; + node.a = convert_int_or_vid(fb->a()); + node.b = convert_int_or_vid(fb->b()); + node.out = convert_vid(fb->out()); + instr.op = OpCode::ADD_INT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SubtractIntNode: { + auto fb = fb_instr->op_as_SubtractIntNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SubtractIntNode node; + node.a = convert_int_or_vid(fb->a()); + node.b = convert_int_or_vid(fb->b()); + node.out = convert_vid(fb->out()); + instr.op = OpCode::SUBTRACT_INT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_MultiplyIntNode: { + auto fb = fb_instr->op_as_MultiplyIntNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + MultiplyIntNode node; + node.a = convert_int_or_vid(fb->a()); + node.b = convert_int_or_vid(fb->b()); + node.out = 
convert_vid(fb->out()); + instr.op = OpCode::MULTIPLY_INT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_FloorDivideIntNode: { + auto fb = fb_instr->op_as_FloorDivideIntNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + FloorDivideIntNode node; + node.a = convert_int_or_vid(fb->a()); + node.b = convert_int_or_vid(fb->b()); + node.out = convert_vid(fb->out()); + instr.op = OpCode::FLOOR_DIVIDE_INT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ModIntNode: { + auto fb = fb_instr->op_as_ModIntNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ModIntNode node; + node.a = convert_int_or_vid(fb->a()); + node.b = convert_int_or_vid(fb->b()); + node.out = convert_vid(fb->out()); + instr.op = OpCode::MOD_INT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SymSizeNode: { + auto fb = fb_instr->op_as_SymSizeNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SymSizeNode node; + node.a = convert_tid(fb->a()); + node.dim = fb->dim(); + node.out = convert_vid(fb->out()); + instr.op = OpCode::SYM_SIZE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_MultiplyNode: { + auto fb = fb_instr->op_as_MultiplyNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + MultiplyNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::MULTIPLY; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_DivideNode: { + auto fb = fb_instr->op_as_DivideNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + DivideNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = 
convert_tid(fb->out()); + instr.op = OpCode::DIVIDE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SubtractNode: { + auto fb = fb_instr->op_as_SubtractNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SubtractNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::SUBTRACT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_Conv1DNode: { + auto fb = fb_instr->op_as_Conv1DNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + Conv1DNode node; + node.x = convert_tid(fb->x()); + node.w = convert_tid(fb->w()); + node.out = convert_tid(fb->out()); + node.stride = fb->stride(); + node.padding = fb->padding(); + node.dilation = fb->dilation(); + node.groups = fb->groups(); + instr.op = OpCode::CONV1D; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_Conv2DNode: { + auto fb = fb_instr->op_as_Conv2DNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + Conv2DNode node; + node.x = convert_tid(fb->x()); + node.w = convert_tid(fb->w()); + node.out = convert_tid(fb->out()); + node.stride_h = fb->stride_h(); + node.stride_w = fb->stride_w(); + node.padding_h = fb->padding_h(); + node.padding_w = fb->padding_w(); + node.dilation_h = fb->dilation_h(); + node.dilation_w = fb->dilation_w(); + node.groups = fb->groups(); + instr.op = OpCode::CONV2D; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_Conv3DNode: { + auto fb = fb_instr->op_as_Conv3DNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + Conv3DNode node; + node.x = convert_tid(fb->x()); + node.w = convert_tid(fb->w()); + node.out = convert_tid(fb->out()); + node.stride_d = fb->stride_d(); + node.stride_h = 
fb->stride_h(); + node.stride_w = fb->stride_w(); + node.padding_d = fb->padding_d(); + node.padding_h = fb->padding_h(); + node.padding_w = fb->padding_w(); + node.dilation_d = fb->dilation_d(); + node.dilation_h = fb->dilation_h(); + node.dilation_w = fb->dilation_w(); + node.groups = fb->groups(); + instr.op = OpCode::CONV3D; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ConvTranspose1DNode: { + auto fb = fb_instr->op_as_ConvTranspose1DNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ConvTranspose1DNode node; + node.x = convert_tid(fb->x()); + node.w = convert_tid(fb->w()); + node.out = convert_tid(fb->out()); + node.stride = fb->stride(); + node.padding = fb->padding(); + node.dilation = fb->dilation(); + node.output_padding = fb->output_padding(); + node.groups = fb->groups(); + instr.op = OpCode::CONV_TRANSPOSE1D; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ConvTranspose2DNode: { + auto fb = fb_instr->op_as_ConvTranspose2DNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ConvTranspose2DNode node; + node.x = convert_tid(fb->x()); + node.w = convert_tid(fb->w()); + node.out = convert_tid(fb->out()); + node.stride_h = fb->stride_h(); + node.stride_w = fb->stride_w(); + node.padding_h = fb->padding_h(); + node.padding_w = fb->padding_w(); + node.dilation_h = fb->dilation_h(); + node.dilation_w = fb->dilation_w(); + node.output_padding_h = fb->output_padding_h(); + node.output_padding_w = fb->output_padding_w(); + node.groups = fb->groups(); + instr.op = OpCode::CONV_TRANSPOSE2D; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ConvTranspose3DNode: { + auto fb = fb_instr->op_as_ConvTranspose3DNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ConvTranspose3DNode node; + node.x = 
convert_tid(fb->x()); + node.w = convert_tid(fb->w()); + node.out = convert_tid(fb->out()); + node.stride_d = fb->stride_d(); + node.stride_h = fb->stride_h(); + node.stride_w = fb->stride_w(); + node.padding_d = fb->padding_d(); + node.padding_h = fb->padding_h(); + node.padding_w = fb->padding_w(); + node.dilation_d = fb->dilation_d(); + node.dilation_h = fb->dilation_h(); + node.dilation_w = fb->dilation_w(); + node.output_padding_d = fb->output_padding_d(); + node.output_padding_h = fb->output_padding_h(); + node.output_padding_w = fb->output_padding_w(); + node.groups = fb->groups(); + instr.op = OpCode::CONV_TRANSPOSE3D; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_GeluNode: { + auto fb = fb_instr->op_as_GeluNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + GeluNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.approximate = fb->approximate() ? fb->approximate()->str() : ""; + instr.op = OpCode::GELU; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ARangeNode: { + auto fb = fb_instr->op_as_ARangeNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ARangeNode node; + node.out = convert_tid(fb->out()); + node.start = convert_int_or_vid(fb->start()); + node.stop = convert_int_or_vid(fb->stop()); + node.step = convert_int_or_vid(fb->step()); + auto scalar_type_opt = fb->scalar_type(); + if (scalar_type_opt.has_value()) { + node.scalar_type = scalar_type_opt.value(); + } + instr.op = OpCode::ARANGE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SiluNode: { + auto fb = fb_instr->op_as_SiluNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SiluNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::SILU; + instr.node = 
std::move(node); + break; + } + + case mlx_delegate::OpNode_SigmoidNode: { + auto fb = fb_instr->op_as_SigmoidNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SigmoidNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::SIGMOID; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_TanhNode: { + auto fb = fb_instr->op_as_TanhNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + TanhNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::TANH; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SqueezeNode: { + auto fb = fb_instr->op_as_SqueezeNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SqueezeNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.dims = to_vector(fb->dims()); + instr.op = OpCode::SQUEEZE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SplitNode: { + auto fb = fb_instr->op_as_SplitNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SplitNode node; + node.x = convert_tid(fb->x()); + if (fb->outs()) { + for (auto fb_tid : *fb->outs()) { + node.outs.push_back(convert_tid(fb_tid)); + } + } + if (fb->sizes()) { + for (size_t i = 0; i < fb->sizes()->size(); ++i) { + node.sizes.push_back(convert_int_or_vid(fb->sizes()->Get(static_cast(i)))); + } + } + node.axis = fb->axis(); + instr.op = OpCode::SPLIT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_RsqrtNode: { + auto fb = fb_instr->op_as_RsqrtNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + RsqrtNode node; + node.x = convert_tid(fb->x()); + node.out = 
convert_tid(fb->out()); + instr.op = OpCode::RSQRT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_MaximumNode: { + auto fb = fb_instr->op_as_MaximumNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + MaximumNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::MAXIMUM; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_MinimumNode: { + auto fb = fb_instr->op_as_MinimumNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + MinimumNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::MINIMUM; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_LogNode: { + auto fb = fb_instr->op_as_LogNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + LogNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::LOG; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SoftmaxNode: { + auto fb = fb_instr->op_as_SoftmaxNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SoftmaxNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axis = fb->axis(); + node.precise = fb->precise(); + instr.op = OpCode::SOFTMAX; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_BroadcastToNode: { + auto fb = fb_instr->op_as_BroadcastToNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + BroadcastToNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + if (fb->shape()) { + for (size_t i = 0; i < fb->shape()->size(); ++i) { + 
node.shape.push_back(convert_int_or_vid(fb->shape()->Get(static_cast(i)))); + } + } + instr.op = OpCode::BROADCAST_TO; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_PadNode: { + auto fb = fb_instr->op_as_PadNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + PadNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + if (fb->pad_width()) { + for (size_t i = 0; i < fb->pad_width()->size(); ++i) { + node.pad_width.push_back(convert_int_or_vid(fb->pad_width()->Get(static_cast(i)))); + } + } + node.mode = fb->mode() ? fb->mode()->str() : ""; + node.constant_value = fb->constant_value(); + instr.op = OpCode::PAD; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_WhereNode: { + auto fb = fb_instr->op_as_WhereNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + WhereNode node; + node.condition = convert_tid(fb->condition()); + node.x = convert_tid(fb->x()); + node.y = convert_tid(fb->y()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::WHERE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ReshapeNode: { + auto fb = fb_instr->op_as_ReshapeNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ReshapeNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + if (fb->shape()) { + for (size_t i = 0; i < fb->shape()->size(); ++i) { + node.shape.push_back(convert_int_or_vid(fb->shape()->Get(static_cast(i)))); + } + } + instr.op = OpCode::RESHAPE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_TransposeNode: { + auto fb = fb_instr->op_as_TransposeNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + TransposeNode node; + node.x = convert_tid(fb->x()); + node.out = 
convert_tid(fb->out()); + node.perm = to_vector(fb->perm()); + instr.op = OpCode::TRANSPOSE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_AsStridedNode: { + auto fb = fb_instr->op_as_AsStridedNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + AsStridedNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + if (fb->shape()) { + for (size_t i = 0; i < fb->shape()->size(); ++i) { + node.shape.push_back(convert_int_or_vid(fb->shape()->Get(static_cast(i)))); + } + } + if (fb->strides()) { + for (size_t i = 0; i < fb->strides()->size(); ++i) { + node.strides.push_back(convert_int_or_vid(fb->strides()->Get(static_cast(i)))); + } + } + node.offset = fb->offset(); + instr.op = OpCode::AS_STRIDED; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ContiguousNode: { + auto fb = fb_instr->op_as_ContiguousNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ContiguousNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::CONTIGUOUS; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_GatherNode: { + auto fb = fb_instr->op_as_GatherNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + GatherNode node; + node.x = convert_tid(fb->x()); + if (fb->indices()) { + for (auto fb_tid : *fb->indices()) { + node.indices.push_back(convert_tid(fb_tid)); + } + } + node.out = convert_tid(fb->out()); + node.axes = to_vector(fb->axes()); + node.slice_sizes = to_vector(fb->slice_sizes()); + instr.op = OpCode::GATHER; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SliceNode: { + auto fb = fb_instr->op_as_SliceNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SliceNode node; + node.x 
= convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axis = convert_int_or_vid(fb->axis()); + node.start = convert_int_or_vid(fb->start()); + node.stop = convert_int_or_vid(fb->stop()); + node.step = fb->step(); + instr.op = OpCode::SLICE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_AsTypeNode: { + auto fb = fb_instr->op_as_AsTypeNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + AsTypeNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.scalar_type = fb->scalar_type(); + instr.op = OpCode::ASTYPE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_QuantizedMatmulNode: { + auto fb = fb_instr->op_as_QuantizedMatmulNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + QuantizedMatmulNode node; + node.x = convert_tid(fb->x()); + node.w = convert_tid(fb->w()); + node.scales = convert_tid(fb->scales()); + node.out = convert_tid(fb->out()); + if (fb->biases()) { + node.biases = convert_tid(fb->biases()); + } + node.group_size = fb->group_size(); + node.bits = fb->bits(); + node.mode = fb->mode() ? 
fb->mode()->str() : ""; + node.transpose = fb->transpose(); + instr.op = OpCode::QUANTIZED_MATMUL; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ScatterAddNode: { + auto fb = fb_instr->op_as_ScatterAddNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ScatterAddNode node; + node.x = convert_tid(fb->x()); + node.indices = convert_tid(fb->indices()); + node.updates = convert_tid(fb->updates()); + node.out = convert_tid(fb->out()); + node.axis = fb->axis(); + instr.op = OpCode::SCATTER_ADD; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ConcatenateNode: { + auto fb = fb_instr->op_as_ConcatenateNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ConcatenateNode node; + if (fb->tensors()) { + for (auto fb_tid : *fb->tensors()) { + node.tensors.push_back(convert_tid(fb_tid)); + } + } + node.out = convert_tid(fb->out()); + node.axis = fb->axis(); + instr.op = OpCode::CONCATENATE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_FullNode: { + auto fb = fb_instr->op_as_FullNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + FullNode node; + node.out = convert_tid(fb->out()); + if (fb->shape()) { + for (size_t i = 0; i < fb->shape()->size(); ++i) { + node.shape.push_back(convert_int_or_vid(fb->shape()->Get(static_cast(i)))); + } + } + node.v = convert_float_or_vid(fb->v()); + node.scalar_type = fb->scalar_type(); + instr.op = OpCode::FULL; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_FullLikeNode: { + auto fb = fb_instr->op_as_FullLikeNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + FullLikeNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.v = convert_float_or_vid(fb->v()); + auto 
scalar_type_opt = fb->scalar_type(); + if (scalar_type_opt.has_value()) { + node.scalar_type = scalar_type_opt.value(); + } + instr.op = OpCode::FULL_LIKE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ArgmaxNode: { + auto fb = fb_instr->op_as_ArgmaxNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ArgmaxNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axis = fb->axis(); + node.keepdims = fb->keepdims(); + instr.op = OpCode::ARGMAX; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SliceUpdateNode: { + auto fb = fb_instr->op_as_SliceUpdateNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SliceUpdateNode node; + node.dst = convert_tid(fb->dst()); + node.update = convert_tid(fb->update()); + node.out = convert_tid(fb->out()); + node.axis = convert_int_or_vid(fb->axis()); + node.start = convert_int_or_vid(fb->start()); + node.stop = convert_int_or_vid(fb->stop()); + node.step = fb->step(); + instr.op = OpCode::SLICE_UPDATE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_IndexCopyNode: { + auto fb = fb_instr->op_as_IndexCopyNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + IndexCopyNode node; + node.dst = convert_tid(fb->dst()); + node.update = convert_tid(fb->update()); + node.indices = convert_tid(fb->indices()); + node.out = convert_tid(fb->out()); + node.axis = fb->axis(); + instr.op = OpCode::INDEX_COPY; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_DequantizeNode: { + auto fb = fb_instr->op_as_DequantizeNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + DequantizeNode node; + node.w = convert_tid(fb->w()); + node.scales = convert_tid(fb->scales()); + node.out = 
convert_tid(fb->out()); + if (fb->biases()) { + node.biases = convert_tid(fb->biases()); + } + node.group_size = fb->group_size(); + node.bits = fb->bits(); + node.mode = fb->mode() ? fb->mode()->str() : ""; + if (fb->global_scale()) { + node.global_scale = convert_tid(fb->global_scale()); + } + auto dtype_opt = fb->dtype(); + if (dtype_opt.has_value()) { + node.dtype = dtype_opt.value(); + } + instr.op = OpCode::DEQUANTIZE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_LessNode: { + auto fb = fb_instr->op_as_LessNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + LessNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::LESS; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_LessEqualNode: { + auto fb = fb_instr->op_as_LessEqualNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + LessEqualNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::LESS_EQUAL; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_GreaterNode: { + auto fb = fb_instr->op_as_GreaterNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + GreaterNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::GREATER; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_GreaterEqualNode: { + auto fb = fb_instr->op_as_GreaterEqualNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + GreaterEqualNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = 
OpCode::GREATER_EQUAL; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_EqualNode: { + auto fb = fb_instr->op_as_EqualNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + EqualNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::EQUAL; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_NotEqualNode: { + auto fb = fb_instr->op_as_NotEqualNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + NotEqualNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::NOT_EQUAL; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_LogicalNotNode: { + auto fb = fb_instr->op_as_LogicalNotNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + LogicalNotNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::LOGICAL_NOT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_LogicalAndNode: { + auto fb = fb_instr->op_as_LogicalAndNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + LogicalAndNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::LOGICAL_AND; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_LogicalOrNode: { + auto fb = fb_instr->op_as_LogicalOrNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + LogicalOrNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::LOGICAL_OR; + instr.node = std::move(node); 
+ break; + } + + case mlx_delegate::OpNode_TriNode: { + auto fb = fb_instr->op_as_TriNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + TriNode node; + node.out = convert_tid(fb->out()); + node.n = convert_int_or_vid(fb->n()); + node.m = convert_int_or_vid(fb->m()); + node.k = fb->k(); + node.scalar_type = fb->scalar_type(); + instr.op = OpCode::TRI; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_TrilNode: { + auto fb = fb_instr->op_as_TrilNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + TrilNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.k = fb->k(); + instr.op = OpCode::TRIL; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_TriuNode: { + auto fb = fb_instr->op_as_TriuNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + TriuNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.k = fb->k(); + instr.op = OpCode::TRIU; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ClipNode: { + auto fb = fb_instr->op_as_ClipNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ClipNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + if (fb->a_min()) { + node.a_min = convert_tid(fb->a_min()); + } + if (fb->a_max()) { + node.a_max = convert_tid(fb->a_max()); + } + instr.op = OpCode::CLIP; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_CumsumNode: { + auto fb = fb_instr->op_as_CumsumNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + CumsumNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axis = fb->axis(); + node.reverse = fb->reverse(); 
+ node.inclusive = fb->inclusive(); + instr.op = OpCode::CUMSUM; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_StackNode: { + auto fb = fb_instr->op_as_StackNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + StackNode node; + if (fb->tensors()) { + for (auto fb_tid : *fb->tensors()) { + node.tensors.push_back(convert_tid(fb_tid)); + } + } + node.out = convert_tid(fb->out()); + node.axis = fb->axis(); + instr.op = OpCode::STACK; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SignNode: { + auto fb = fb_instr->op_as_SignNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SignNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::SIGN; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_AnyNode: { + auto fb = fb_instr->op_as_AnyNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + AnyNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axes = to_vector(fb->axes()); + node.keepdims = fb->keepdims(); + instr.op = OpCode::ANY; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_AllNode: { + auto fb = fb_instr->op_as_AllNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + AllNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axes = to_vector(fb->axes()); + node.keepdims = fb->keepdims(); + instr.op = OpCode::ALL; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_RepeatNode: { + auto fb = fb_instr->op_as_RepeatNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + RepeatNode node; + node.x = convert_tid(fb->x()); + node.out = 
convert_tid(fb->out()); + node.repeats = convert_int_or_vid(fb->repeats()); + node.axis = fb->axis(); + instr.op = OpCode::REPEAT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SortNode: { + auto fb = fb_instr->op_as_SortNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SortNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axis = fb->axis(); + instr.op = OpCode::SORT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ArgsortNode: { + auto fb = fb_instr->op_as_ArgsortNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ArgsortNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axis = fb->axis(); + instr.op = OpCode::ARGSORT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_PartitionNode: { + auto fb = fb_instr->op_as_PartitionNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + PartitionNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.kth = convert_int_or_vid(fb->kth()); + node.axis = fb->axis(); + instr.op = OpCode::PARTITION; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ArgPartitionNode: { + auto fb = fb_instr->op_as_ArgPartitionNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ArgPartitionNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.kth = convert_int_or_vid(fb->kth()); + node.axis = fb->axis(); + instr.op = OpCode::ARG_PARTITION; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_FloorNode: { + auto fb = fb_instr->op_as_FloorNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + 
FloorNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::FLOOR; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_CeilNode: { + auto fb = fb_instr->op_as_CeilNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + CeilNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::CEIL; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SquareNode: { + auto fb = fb_instr->op_as_SquareNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SquareNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::SQUARE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ExpNode: { + auto fb = fb_instr->op_as_ExpNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ExpNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::EXP; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SinNode: { + auto fb = fb_instr->op_as_SinNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SinNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::SIN; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_CosNode: { + auto fb = fb_instr->op_as_CosNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + CosNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::COS; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_TanNode: { + auto fb = fb_instr->op_as_TanNode(); + if (!fb) {{ + throw 
std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + TanNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::TAN; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ArcsinNode: { + auto fb = fb_instr->op_as_ArcsinNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ArcsinNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::ARCSIN; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ArccosNode: { + auto fb = fb_instr->op_as_ArccosNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ArccosNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::ARCCOS; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ArctanNode: { + auto fb = fb_instr->op_as_ArctanNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ArctanNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::ARCTAN; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SinhNode: { + auto fb = fb_instr->op_as_SinhNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SinhNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::SINH; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_CoshNode: { + auto fb = fb_instr->op_as_CoshNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + CoshNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::COSH; + instr.node = std::move(node); + break; + } + + case 
mlx_delegate::OpNode_ArcsinhNode: { + auto fb = fb_instr->op_as_ArcsinhNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ArcsinhNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::ARCSINH; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ArccoshNode: { + auto fb = fb_instr->op_as_ArccoshNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ArccoshNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::ARCCOSH; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ArctanhNode: { + auto fb = fb_instr->op_as_ArctanhNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ArctanhNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::ARCTANH; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_Log2Node: { + auto fb = fb_instr->op_as_Log2Node(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + Log2Node node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::LOG2; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_Log10Node: { + auto fb = fb_instr->op_as_Log10Node(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + Log10Node node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::LOG10; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_Log1pNode: { + auto fb = fb_instr->op_as_Log1pNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + Log1pNode node; + node.x = convert_tid(fb->x()); + node.out 
= convert_tid(fb->out()); + instr.op = OpCode::LOG1P; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ErfNode: { + auto fb = fb_instr->op_as_ErfNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ErfNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::ERF; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_Expm1Node: { + auto fb = fb_instr->op_as_Expm1Node(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + Expm1Node node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::EXPM1; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_RoundNode: { + auto fb = fb_instr->op_as_RoundNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + RoundNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.decimals = fb->decimals(); + instr.op = OpCode::ROUND; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ReciprocalNode: { + auto fb = fb_instr->op_as_ReciprocalNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ReciprocalNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::RECIPROCAL; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SqrtNode: { + auto fb = fb_instr->op_as_SqrtNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SqrtNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::SQRT; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_AbsNode: { + auto fb = fb_instr->op_as_AbsNode(); + if (!fb) {{ + throw 
std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + AbsNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::ABS; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_NegNode: { + auto fb = fb_instr->op_as_NegNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + NegNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::NEG; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_Atan2Node: { + auto fb = fb_instr->op_as_Atan2Node(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + Atan2Node node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::ATAN2; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_LogAddExpNode: { + auto fb = fb_instr->op_as_LogAddExpNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + LogAddExpNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::LOG_ADD_EXP; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_FloorDivideNode: { + auto fb = fb_instr->op_as_FloorDivideNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + FloorDivideNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::FLOOR_DIVIDE; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_RemainderNode: { + auto fb = fb_instr->op_as_RemainderNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + RemainderNode node; + node.a = 
convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::REMAINDER; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_PowerNode: { + auto fb = fb_instr->op_as_PowerNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + PowerNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + instr.op = OpCode::POWER; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_LogSumExpNode: { + auto fb = fb_instr->op_as_LogSumExpNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + LogSumExpNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axes = to_vector(fb->axes()); + node.keepdims = fb->keepdims(); + instr.op = OpCode::LOG_SUM_EXP; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_SumNode: { + auto fb = fb_instr->op_as_SumNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + SumNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axes = to_vector(fb->axes()); + node.keepdims = fb->keepdims(); + instr.op = OpCode::SUM; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_MeanNode: { + auto fb = fb_instr->op_as_MeanNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + MeanNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axes = to_vector(fb->axes()); + node.keepdims = fb->keepdims(); + instr.op = OpCode::MEAN; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_VarNode: { + auto fb = fb_instr->op_as_VarNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + 
VarNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axes = to_vector(fb->axes()); + node.keepdims = fb->keepdims(); + node.ddof = fb->ddof(); + instr.op = OpCode::VAR; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_StdNode: { + auto fb = fb_instr->op_as_StdNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + StdNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axes = to_vector(fb->axes()); + node.keepdims = fb->keepdims(); + node.ddof = fb->ddof(); + instr.op = OpCode::STD; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ProdNode: { + auto fb = fb_instr->op_as_ProdNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ProdNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axes = to_vector(fb->axes()); + node.keepdims = fb->keepdims(); + instr.op = OpCode::PROD; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_MaxNode: { + auto fb = fb_instr->op_as_MaxNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + MaxNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axes = to_vector(fb->axes()); + node.keepdims = fb->keepdims(); + instr.op = OpCode::MAX; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_MinNode: { + auto fb = fb_instr->op_as_MinNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + MinNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axes = to_vector(fb->axes()); + node.keepdims = fb->keepdims(); + instr.op = OpCode::MIN; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ArgminNode: { + auto fb = 
fb_instr->op_as_ArgminNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ArgminNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axis = fb->axis(); + node.keepdims = fb->keepdims(); + instr.op = OpCode::ARGMIN; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_MedianNode: { + auto fb = fb_instr->op_as_MedianNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + MedianNode node; + node.x = convert_tid(fb->x()); + node.out = convert_tid(fb->out()); + node.axes = to_vector(fb->axes()); + node.keepdims = fb->keepdims(); + instr.op = OpCode::MEDIAN; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_GatherMmNode: { + auto fb = fb_instr->op_as_GatherMmNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + GatherMmNode node; + node.a = convert_tid(fb->a()); + node.b = convert_tid(fb->b()); + node.out = convert_tid(fb->out()); + if (fb->lhs_indices()) { + node.lhs_indices = convert_tid(fb->lhs_indices()); + } + if (fb->rhs_indices()) { + node.rhs_indices = convert_tid(fb->rhs_indices()); + } + node.sorted_indices = fb->sorted_indices(); + instr.op = OpCode::GATHER_MM; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_GatherQmmNode: { + auto fb = fb_instr->op_as_GatherQmmNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + GatherQmmNode node; + node.x = convert_tid(fb->x()); + node.w = convert_tid(fb->w()); + node.scales = convert_tid(fb->scales()); + node.out = convert_tid(fb->out()); + node.mode = fb->mode() ? 
fb->mode()->str() : ""; + if (fb->biases()) { + node.biases = convert_tid(fb->biases()); + } + if (fb->lhs_indices()) { + node.lhs_indices = convert_tid(fb->lhs_indices()); + } + if (fb->rhs_indices()) { + node.rhs_indices = convert_tid(fb->rhs_indices()); + } + node.transpose = fb->transpose(); + node.group_size = fb->group_size(); + node.bits = fb->bits(); + node.sorted_indices = fb->sorted_indices(); + instr.op = OpCode::GATHER_QMM; + instr.node = std::move(node); + break; + } + + case mlx_delegate::OpNode_ScanNode: { + auto fb = fb_instr->op_as_ScanNode(); + if (!fb) {{ + throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}"); + }} + ScanNode node; + if (fb->originals()) { + for (auto fb_tid : *fb->originals()) { + node.originals.push_back(convert_tid(fb_tid)); + } + } + if (fb->sliced()) { + for (auto fb_tid : *fb->sliced()) { + node.sliced.push_back(convert_tid(fb_tid)); + } + } + if (fb->outputs()) { + for (auto fb_tid : *fb->outputs()) { + node.outputs.push_back(convert_tid(fb_tid)); + } + } + if (fb->carry()) { + for (auto fb_tid : *fb->carry()) { + node.carry.push_back(convert_tid(fb_tid)); + } + } + node.body_chain_idx = fb->body_chain_idx(); + node.scan_axis = fb->scan_axis(); + instr.op = OpCode::SCAN; + instr.node = std::move(node); + break; + } + + default: + throw std::runtime_error( + "Unknown op_type in load_instruction: " + + std::to_string(static_cast(op_type)) + + ". The .pte was built with a newer schema than this binary. 
" + "Rebuild with the latest runtime."); + } + + return instr; +} + +// ============================================================================= +// load_program +// ============================================================================= + +MLXProgram load_program(const void* data, size_t size) { + MLXHeader header; + if (!parse_header(data, size, header)) { + throw std::runtime_error("Invalid MLX header"); + } + + // Defense-in-depth: parse_header already validates this, but guard the + // unsigned subtraction against underflow in case the call site ever changes. + if (header.data_offset <= kHeaderSize || header.data_offset > size) { + throw std::runtime_error("data_offset out of range"); + } + const uint8_t* fb_data = static_cast(data) + kHeaderSize; + size_t fb_size = header.data_offset - kHeaderSize; + + flatbuffers::Verifier verifier(fb_data, fb_size); + if (!mlx_delegate::VerifyMLXGraphBuffer(verifier)) { + throw std::runtime_error("Invalid FlatBuffer data"); + } + + const auto* fb_graph = mlx_delegate::GetMLXGraph(fb_data); + if (!fb_graph) { + throw std::runtime_error("Failed to parse MLXGraph"); + } + + MLXProgram program; + + if (fb_graph->version()) { + program.version = fb_graph->version()->str(); + } + + program.num_constant_tensors = fb_graph->num_constant_tensors(); + program.num_input_tensors = fb_graph->num_input_tensors(); + program.num_output_tensors = fb_graph->num_output_tensors(); + program.num_mutable_buffer_tensors = fb_graph->num_mutable_buffer_tensors(); + program.num_temp_tensors = fb_graph->num_temp_tensors(); + program.num_values = fb_graph->num_values(); + + // Cap all counts/collection sizes to prevent unbounded allocations from + // malformed FlatBuffer payloads + constexpr size_t kMaxCollectionSize = 1'000'000; + auto check_collection_size = [](size_t sz, const char* name) { + if (sz > kMaxCollectionSize) { + throw std::runtime_error( + std::string("Malformed program: ") + name + " size " + + std::to_string(sz) + " 
exceeds maximum of " + + std::to_string(kMaxCollectionSize)); + } + }; + + check_collection_size(program.num_tensors(), "num_tensors()"); + check_collection_size(program.num_values, "num_values"); + + if (fb_graph->instruction_chains()) { + check_collection_size(fb_graph->instruction_chains()->size(), "instruction_chains"); + program.instruction_chains.reserve(fb_graph->instruction_chains()->size()); + for (size_t c = 0; c < fb_graph->instruction_chains()->size(); ++c) { + const auto* fb_chain = fb_graph->instruction_chains()->Get(static_cast(c)); + std::vector chain; + if (fb_chain && fb_chain->instructions()) { + check_collection_size(fb_chain->instructions()->size(), "instructions in chain"); + chain.reserve(fb_chain->instructions()->size()); + for (size_t i = 0; i < fb_chain->instructions()->size(); ++i) { + chain.push_back(load_instruction(fb_chain->instructions()->Get(static_cast(i)))); + } + } + program.instruction_chains.push_back(std::move(chain)); + } + } + + program.main_chain_idx = fb_graph->main_chain_idx(); + program.init_chain_idx = fb_graph->init_chain_idx(); + + // Validate chain indices against actual instruction_chains size. 
+ if (program.main_chain_idx >= program.instruction_chains.size()) { + throw std::runtime_error( + "Invalid main_chain_idx " + + std::to_string(program.main_chain_idx) + + " (only " + std::to_string(program.instruction_chains.size()) + + " chains loaded)"); + } + if (program.init_chain_idx >= 0 && + static_cast(program.init_chain_idx) >= + program.instruction_chains.size()) { + throw std::runtime_error( + "Invalid init_chain_idx " + + std::to_string(program.init_chain_idx) + + " (only " + std::to_string(program.instruction_chains.size()) + + " chains loaded)"); + } + + if (fb_graph->input_map()) { + check_collection_size(fb_graph->input_map()->size(), "input_map"); + for (size_t i = 0; i < fb_graph->input_map()->size(); ++i) { + const auto* slot = fb_graph->input_map()->Get(static_cast(i)); + auto sv = convert_slot_variant(slot); + if (sv.slot_type == SlotType::TensorSlot && + sv.idx >= program.num_tensors()) { + throw std::runtime_error( + "input_map: slot index " + std::to_string(sv.idx) + + " exceeds num_tensors " + + std::to_string(program.num_tensors())); + } + program.input_map.push_back(sv); + } + } + + if (fb_graph->output_map()) { + check_collection_size(fb_graph->output_map()->size(), "output_map"); + for (size_t i = 0; i < fb_graph->output_map()->size(); ++i) { + const auto* slot = fb_graph->output_map()->Get(static_cast(i)); + auto sv = convert_slot_variant(slot); + if (sv.slot_type == SlotType::TensorSlot && + sv.idx >= program.num_tensors()) { + throw std::runtime_error( + "output_map: slot index " + std::to_string(sv.idx) + + " exceeds num_tensors " + + std::to_string(program.num_tensors())); + } + program.output_map.push_back(sv); + } + } + + if (fb_graph->mutable_buffer_map()) { + check_collection_size(fb_graph->mutable_buffer_map()->size(), "mutable_buffer_map"); + for (size_t i = 0; i < fb_graph->mutable_buffer_map()->size(); ++i) { + const auto* slot = fb_graph->mutable_buffer_map()->Get(static_cast(i)); + auto sv = convert_slot_variant(slot); + 
if (sv.slot_type == SlotType::TensorSlot && + sv.idx >= program.num_tensors()) { + throw std::runtime_error( + "mutable_buffer_map: slot index " + std::to_string(sv.idx) + + " exceeds num_tensors " + + std::to_string(program.num_tensors())); + } + program.mutable_buffer_map.push_back(sv); + } + } + + if (fb_graph->named_slots()) { + check_collection_size(fb_graph->named_slots()->size(), "named_slots"); + for (size_t i = 0; i < fb_graph->named_slots()->size(); ++i) { + const auto* fb_slot = fb_graph->named_slots()->Get(static_cast(i)); + if (!fb_slot || !fb_slot->name()) { + throw std::runtime_error( + "Malformed program: named_slot at index " + std::to_string(i) + + " is null or has null name"); + } + NamedSlot slot; + slot.name = fb_slot->name()->str(); + slot.slot = convert_slot_variant(fb_slot->slot()); + program.named_slots.push_back(std::move(slot)); + } + } + + if (fb_graph->tensor_meta()) { + check_collection_size(fb_graph->tensor_meta()->size(), "tensor_meta"); + for (size_t i = 0; i < fb_graph->tensor_meta()->size(); ++i) { + const auto* fb_meta = fb_graph->tensor_meta()->Get(static_cast(i)); + if (fb_meta) { + TensorMeta meta; + if (fb_meta->shape()) { + // Validate tensor rank against kTensorDimensionLimit to prevent + // stack overflows from unchecked rank + constexpr size_t kTensorDimensionLimit = 16; + if (fb_meta->shape()->size() > kTensorDimensionLimit) { + throw std::runtime_error( + "Tensor at index " + std::to_string(i) + + " has rank " + std::to_string(fb_meta->shape()->size()) + + " exceeding kTensorDimensionLimit (" + + std::to_string(kTensorDimensionLimit) + ")"); + } + for (size_t j = 0; j < fb_meta->shape()->size(); ++j) { + const auto* fb_dim = fb_meta->shape()->Get(static_cast(j)); + if (!fb_dim) { + throw std::runtime_error( + "Null ShapeDim at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + ShapeDim dim; + dim.value = fb_dim->value(); + dim.min_value = fb_dim->min_value(); + dim.max_value = 
fb_dim->max_value(); + if (dim.value < -1) { + throw std::runtime_error( + "Invalid ShapeDim value " + std::to_string(dim.value) + + " at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + if (dim.is_dynamic()) { + if (dim.min_value < 0) { + throw std::runtime_error( + "Invalid ShapeDim min_value " + std::to_string(dim.min_value) + + " at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + if (dim.max_value != -1 && dim.max_value < dim.min_value) { + throw std::runtime_error( + "ShapeDim max_value " + std::to_string(dim.max_value) + + " < min_value " + std::to_string(dim.min_value) + + " at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + } + meta.shape.push_back(dim); + } + } + auto raw_scalar_type = fb_meta->scalar_type(); + if (raw_scalar_type < 0 || + raw_scalar_type >= + static_cast(ScalarType::NumOptions)) { + throw std::runtime_error( + "Invalid scalar_type " + std::to_string(raw_scalar_type) + + " in tensor_meta at index " + std::to_string(i)); + } + meta.scalar_type = static_cast(raw_scalar_type); + if (fb_meta->dim_order()) { + meta.dim_order = to_vector(fb_meta->dim_order()); + } + program.tensor_meta.push_back(std::move(meta)); + } else { + program.tensor_meta.push_back(std::nullopt); + } + } + } + + return program; +} + +} // namespace loader +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/runtime/MLXLoader.h b/backends/mlx/runtime/MLXLoader.h new file mode 100644 index 00000000000..b1eb34cb83f --- /dev/null +++ b/backends/mlx/runtime/MLXLoader.h @@ -0,0 +1,1712 @@ +// +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+// +// ============================================================================ +// AUTO-GENERATED FILE - DO NOT EDIT MANUALLY +// ============================================================================ +// +// This file was generated from schema.fbs by the MLX delegate code generator. +// +// Source: backends/mlx/serialization/schema.fbs +// Generator: backends/mlx/serialization/generate.py +// +// To regenerate, run from the executorch root: +// python backends/mlx/serialization/generate.py +// +// ============================================================================ +// +// -*- c++ -*- + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "schema_generated.h" + +// ExecuTorch scalar type for dtype representation +#include + +namespace executorch { +namespace backends { +namespace mlx { + +// ============================================================================= +// Core types matching the Python side +// ============================================================================= + +struct Tid { + uint32_t idx{}; +}; + +struct Vid { + uint32_t idx{}; +}; + +// ============================================================================= +// Tensor metadata +// ============================================================================= + +// Import ScalarType from ExecuTorch +using ScalarType = ::executorch::runtime::etensor::ScalarType; + +struct ShapeDim { + int32_t value{-1}; // Static dim (>= 0), or -1 for dynamic + int32_t min_value{0}; // Lower bound (when value == -1) + int32_t max_value{-1}; // Upper bound (-1 = unbounded, when value == -1) + + bool is_dynamic() const { return value < 0; } +}; + +struct TensorMeta { + std::vector shape; + ScalarType scalar_type{ScalarType::Float}; // ET ScalarType + std::vector dim_order; +}; + +// VidOrTid: either a scalar value (Vid) or a tensor (Tid) +struct VidOrTid { + Vid vid{}; + Tid tid{}; + bool is_vid{false}; // false = use tid, true = use 
vid +}; + +// IntOrVidOrTid: a literal int, a runtime Vid, or a tensor (Tid) +struct IntOrVidOrTid { + int64_t literal{0}; + Vid vid{}; + Tid tid{}; + uint8_t kind{0}; // 0 = literal int, 1 = vid, 2 = tid +}; + +// ============================================================================= +// Op node types (AUTO-GENERATED from schema.fbs) +// ============================================================================= + +struct NoopNode { +}; + +struct IdCopyNode { + Tid x; + Tid out; +}; + +struct AddmmNode { + Tid mat1; + Tid mat2; + Tid out; + std::optional bias; + float alpha; + float beta; +}; + +struct ItemIntNode { + Tid x; + Vid out; +}; + +struct ExpandDimsNode { + Tid x; + Tid out; + int32_t axis; +}; + +struct TileNode { + Tid x; + Tid out; + std::vector> reps; +}; + +struct TakeAlongAxisNode { + Tid x; + Tid indices; + Tid out; + int32_t axis; +}; + +struct TakeNode { + Tid x; + Tid out; + IntOrVidOrTid index; + int32_t axis; +}; + +struct RMSNormNode { + Tid x; + std::optional weight; + Tid out; + float eps; +}; + +struct LayerNormNode { + Tid x; + Tid out; + std::optional weight; + std::optional bias; + float eps; +}; + +struct RopeNode { + Tid x; + Tid out; + int32_t dims; + VidOrTid offset; + std::optional freqs; + bool traditional; + float base; + float scale; +}; + +struct SdpaNode { + Tid q; + Tid k; + Tid v; + Tid out; + float scale; + std::optional mask; + bool causal; +}; + +struct AddNode { + Tid a; + Tid b; + Tid out; +}; + +struct AddIntNode { + std::variant a; + std::variant b; + Vid out; +}; + +struct SubtractIntNode { + std::variant a; + std::variant b; + Vid out; +}; + +struct MultiplyIntNode { + std::variant a; + std::variant b; + Vid out; +}; + +struct FloorDivideIntNode { + std::variant a; + std::variant b; + Vid out; +}; + +struct ModIntNode { + std::variant a; + std::variant b; + Vid out; +}; + +struct SymSizeNode { + Tid a; + int32_t dim; + Vid out; +}; + +struct MultiplyNode { + Tid a; + Tid b; + Tid out; +}; + +struct 
DivideNode { + Tid a; + Tid b; + Tid out; +}; + +struct SubtractNode { + Tid a; + Tid b; + Tid out; +}; + +struct Conv1DNode { + Tid x; + Tid w; + Tid out; + int32_t stride; + int32_t padding; + int32_t dilation; + int32_t groups; +}; + +struct Conv2DNode { + Tid x; + Tid w; + Tid out; + int32_t stride_h; + int32_t stride_w; + int32_t padding_h; + int32_t padding_w; + int32_t dilation_h; + int32_t dilation_w; + int32_t groups; +}; + +struct Conv3DNode { + Tid x; + Tid w; + Tid out; + int32_t stride_d; + int32_t stride_h; + int32_t stride_w; + int32_t padding_d; + int32_t padding_h; + int32_t padding_w; + int32_t dilation_d; + int32_t dilation_h; + int32_t dilation_w; + int32_t groups; +}; + +struct ConvTranspose1DNode { + Tid x; + Tid w; + Tid out; + int32_t stride; + int32_t padding; + int32_t dilation; + int32_t output_padding; + int32_t groups; +}; + +struct ConvTranspose2DNode { + Tid x; + Tid w; + Tid out; + int32_t stride_h; + int32_t stride_w; + int32_t padding_h; + int32_t padding_w; + int32_t dilation_h; + int32_t dilation_w; + int32_t output_padding_h; + int32_t output_padding_w; + int32_t groups; +}; + +struct ConvTranspose3DNode { + Tid x; + Tid w; + Tid out; + int32_t stride_d; + int32_t stride_h; + int32_t stride_w; + int32_t padding_d; + int32_t padding_h; + int32_t padding_w; + int32_t dilation_d; + int32_t dilation_h; + int32_t dilation_w; + int32_t output_padding_d; + int32_t output_padding_h; + int32_t output_padding_w; + int32_t groups; +}; + +struct GeluNode { + Tid x; + Tid out; + std::string approximate; +}; + +struct ARangeNode { + Tid out; + std::variant start; + std::variant stop; + std::variant step; + std::optional scalar_type; +}; + +struct SiluNode { + Tid x; + Tid out; +}; + +struct SigmoidNode { + Tid x; + Tid out; +}; + +struct TanhNode { + Tid x; + Tid out; +}; + +struct SqueezeNode { + Tid x; + Tid out; + std::vector dims; +}; + +struct SplitNode { + Tid x; + std::vector outs; + std::vector> sizes; + int32_t axis; +}; + +struct 
RsqrtNode { + Tid x; + Tid out; +}; + +struct MaximumNode { + Tid a; + Tid b; + Tid out; +}; + +struct MinimumNode { + Tid a; + Tid b; + Tid out; +}; + +struct LogNode { + Tid x; + Tid out; +}; + +struct SoftmaxNode { + Tid x; + Tid out; + int32_t axis; + bool precise; +}; + +struct BroadcastToNode { + Tid x; + Tid out; + std::vector> shape; +}; + +struct PadNode { + Tid x; + Tid out; + std::vector> pad_width; + std::string mode; + float constant_value; +}; + +struct WhereNode { + Tid condition; + Tid x; + Tid y; + Tid out; +}; + +struct ReshapeNode { + Tid x; + Tid out; + std::vector> shape; +}; + +struct TransposeNode { + Tid x; + Tid out; + std::vector perm; +}; + +struct AsStridedNode { + Tid x; + Tid out; + std::vector> shape; + std::vector> strides; + uint64_t offset; +}; + +struct ContiguousNode { + Tid x; + Tid out; +}; + +struct GatherNode { + Tid x; + std::vector indices; + Tid out; + std::vector axes; + std::vector slice_sizes; +}; + +struct SliceNode { + Tid x; + Tid out; + std::variant axis; + std::variant start; + std::variant stop; + int32_t step; +}; + +struct AsTypeNode { + Tid x; + Tid out; + int8_t scalar_type; +}; + +struct QuantizedMatmulNode { + Tid x; + Tid w; + Tid scales; + Tid out; + std::optional biases; + int32_t group_size; + int32_t bits; + std::string mode; + bool transpose; +}; + +struct ScatterAddNode { + Tid x; + Tid indices; + Tid updates; + Tid out; + int32_t axis; +}; + +struct ConcatenateNode { + std::vector tensors; + Tid out; + int32_t axis; +}; + +struct FullNode { + Tid out; + std::vector> shape; + std::variant v; + int8_t scalar_type; +}; + +struct FullLikeNode { + Tid x; + Tid out; + std::variant v; + std::optional scalar_type; +}; + +struct ArgmaxNode { + Tid x; + Tid out; + int32_t axis; + bool keepdims; +}; + +struct SliceUpdateNode { + Tid dst; + Tid update; + Tid out; + std::variant axis; + std::variant start; + std::variant stop; + int32_t step; +}; + +struct IndexCopyNode { + Tid dst; + Tid update; + Tid indices; + 
Tid out; + int32_t axis; +}; + +struct DequantizeNode { + Tid w; + Tid scales; + Tid out; + std::optional biases; + int32_t group_size; + int32_t bits; + std::string mode; + std::optional global_scale; + std::optional dtype; +}; + +struct LessNode { + Tid a; + Tid b; + Tid out; +}; + +struct LessEqualNode { + Tid a; + Tid b; + Tid out; +}; + +struct GreaterNode { + Tid a; + Tid b; + Tid out; +}; + +struct GreaterEqualNode { + Tid a; + Tid b; + Tid out; +}; + +struct EqualNode { + Tid a; + Tid b; + Tid out; +}; + +struct NotEqualNode { + Tid a; + Tid b; + Tid out; +}; + +struct LogicalNotNode { + Tid x; + Tid out; +}; + +struct LogicalAndNode { + Tid a; + Tid b; + Tid out; +}; + +struct LogicalOrNode { + Tid a; + Tid b; + Tid out; +}; + +struct TriNode { + Tid out; + std::variant n; + std::variant m; + int32_t k; + int8_t scalar_type; +}; + +struct TrilNode { + Tid x; + Tid out; + int32_t k; +}; + +struct TriuNode { + Tid x; + Tid out; + int32_t k; +}; + +struct ClipNode { + Tid x; + Tid out; + std::optional a_min; + std::optional a_max; +}; + +struct CumsumNode { + Tid x; + Tid out; + int32_t axis; + bool reverse; + bool inclusive; +}; + +struct StackNode { + std::vector tensors; + Tid out; + int32_t axis; +}; + +struct SignNode { + Tid x; + Tid out; +}; + +struct AnyNode { + Tid x; + Tid out; + std::vector axes; + bool keepdims; +}; + +struct AllNode { + Tid x; + Tid out; + std::vector axes; + bool keepdims; +}; + +struct RepeatNode { + Tid x; + Tid out; + std::variant repeats; + int32_t axis; +}; + +struct SortNode { + Tid x; + Tid out; + int32_t axis; +}; + +struct ArgsortNode { + Tid x; + Tid out; + int32_t axis; +}; + +struct PartitionNode { + Tid x; + Tid out; + std::variant kth; + int32_t axis; +}; + +struct ArgPartitionNode { + Tid x; + Tid out; + std::variant kth; + int32_t axis; +}; + +struct FloorNode { + Tid x; + Tid out; +}; + +struct CeilNode { + Tid x; + Tid out; +}; + +struct SquareNode { + Tid x; + Tid out; +}; + +struct ExpNode { + Tid x; + Tid 
out; +}; + +struct SinNode { + Tid x; + Tid out; +}; + +struct CosNode { + Tid x; + Tid out; +}; + +struct TanNode { + Tid x; + Tid out; +}; + +struct ArcsinNode { + Tid x; + Tid out; +}; + +struct ArccosNode { + Tid x; + Tid out; +}; + +struct ArctanNode { + Tid x; + Tid out; +}; + +struct SinhNode { + Tid x; + Tid out; +}; + +struct CoshNode { + Tid x; + Tid out; +}; + +struct ArcsinhNode { + Tid x; + Tid out; +}; + +struct ArccoshNode { + Tid x; + Tid out; +}; + +struct ArctanhNode { + Tid x; + Tid out; +}; + +struct Log2Node { + Tid x; + Tid out; +}; + +struct Log10Node { + Tid x; + Tid out; +}; + +struct Log1pNode { + Tid x; + Tid out; +}; + +struct ErfNode { + Tid x; + Tid out; +}; + +struct Expm1Node { + Tid x; + Tid out; +}; + +struct RoundNode { + Tid x; + Tid out; + int32_t decimals; +}; + +struct ReciprocalNode { + Tid x; + Tid out; +}; + +struct SqrtNode { + Tid x; + Tid out; +}; + +struct AbsNode { + Tid x; + Tid out; +}; + +struct NegNode { + Tid x; + Tid out; +}; + +struct Atan2Node { + Tid a; + Tid b; + Tid out; +}; + +struct LogAddExpNode { + Tid a; + Tid b; + Tid out; +}; + +struct FloorDivideNode { + Tid a; + Tid b; + Tid out; +}; + +struct RemainderNode { + Tid a; + Tid b; + Tid out; +}; + +struct PowerNode { + Tid a; + Tid b; + Tid out; +}; + +struct LogSumExpNode { + Tid x; + Tid out; + std::vector axes; + bool keepdims; +}; + +struct SumNode { + Tid x; + Tid out; + std::vector axes; + bool keepdims; +}; + +struct MeanNode { + Tid x; + Tid out; + std::vector axes; + bool keepdims; +}; + +struct VarNode { + Tid x; + Tid out; + std::vector axes; + bool keepdims; + int32_t ddof; +}; + +struct StdNode { + Tid x; + Tid out; + std::vector axes; + bool keepdims; + int32_t ddof; +}; + +struct ProdNode { + Tid x; + Tid out; + std::vector axes; + bool keepdims; +}; + +struct MaxNode { + Tid x; + Tid out; + std::vector axes; + bool keepdims; +}; + +struct MinNode { + Tid x; + Tid out; + std::vector axes; + bool keepdims; +}; + +struct ArgminNode { + Tid 
x; + Tid out; + int32_t axis; + bool keepdims; +}; + +struct MedianNode { + Tid x; + Tid out; + std::vector axes; + bool keepdims; +}; + +struct GatherMmNode { + Tid a; + Tid b; + Tid out; + std::optional lhs_indices; + std::optional rhs_indices; + bool sorted_indices; +}; + +struct GatherQmmNode { + Tid x; + Tid w; + Tid scales; + Tid out; + std::string mode; + std::optional biases; + std::optional lhs_indices; + std::optional rhs_indices; + bool transpose; + int32_t group_size; + int32_t bits; + bool sorted_indices; +}; + +struct ScanNode { + std::vector originals; + std::vector sliced; + std::vector outputs; + std::vector carry; + int32_t body_chain_idx; + int32_t scan_axis; +}; + + +// ============================================================================= +// OpCode enum (AUTO-GENERATED from schema.fbs) +// ============================================================================= + +enum class OpCode : uint8_t { + NOOP, + ID_COPY, + ADDMM, + ITEM_INT, + EXPAND_DIMS, + TILE, + TAKE_ALONG_AXIS, + TAKE, + RMS_NORM, + LAYER_NORM, + ROPE, + SDPA, + ADD, + ADD_INT, + SUBTRACT_INT, + MULTIPLY_INT, + FLOOR_DIVIDE_INT, + MOD_INT, + SYM_SIZE, + MULTIPLY, + DIVIDE, + SUBTRACT, + CONV1D, + CONV2D, + CONV3D, + CONV_TRANSPOSE1D, + CONV_TRANSPOSE2D, + CONV_TRANSPOSE3D, + GELU, + ARANGE, + SILU, + SIGMOID, + TANH, + SQUEEZE, + SPLIT, + RSQRT, + MAXIMUM, + MINIMUM, + LOG, + SOFTMAX, + BROADCAST_TO, + PAD, + WHERE, + RESHAPE, + TRANSPOSE, + AS_STRIDED, + CONTIGUOUS, + GATHER, + SLICE, + ASTYPE, + QUANTIZED_MATMUL, + SCATTER_ADD, + CONCATENATE, + FULL, + FULL_LIKE, + ARGMAX, + SLICE_UPDATE, + INDEX_COPY, + DEQUANTIZE, + LESS, + LESS_EQUAL, + GREATER, + GREATER_EQUAL, + EQUAL, + NOT_EQUAL, + LOGICAL_NOT, + LOGICAL_AND, + LOGICAL_OR, + TRI, + TRIL, + TRIU, + CLIP, + CUMSUM, + STACK, + SIGN, + ANY, + ALL, + REPEAT, + SORT, + ARGSORT, + PARTITION, + ARG_PARTITION, + FLOOR, + CEIL, + SQUARE, + EXP, + SIN, + COS, + TAN, + ARCSIN, + ARCCOS, + ARCTAN, + SINH, + COSH, + 
ARCSINH, + ARCCOSH, + ARCTANH, + LOG2, + LOG10, + LOG1P, + ERF, + EXPM1, + ROUND, + RECIPROCAL, + SQRT, + ABS, + NEG, + ATAN2, + LOG_ADD_EXP, + FLOOR_DIVIDE, + REMAINDER, + POWER, + LOG_SUM_EXP, + SUM, + MEAN, + VAR, + STD, + PROD, + MAX, + MIN, + ARGMIN, + MEDIAN, + GATHER_MM, + GATHER_QMM, + SCAN, +}; + +// OpCode to string conversion (for logging) +inline const char* op_name(OpCode op) { + switch (op) { + case OpCode::NOOP: + return "NOOP"; + case OpCode::ID_COPY: + return "ID_COPY"; + case OpCode::ADDMM: + return "ADDMM"; + case OpCode::ITEM_INT: + return "ITEM_INT"; + case OpCode::EXPAND_DIMS: + return "EXPAND_DIMS"; + case OpCode::TILE: + return "TILE"; + case OpCode::TAKE_ALONG_AXIS: + return "TAKE_ALONG_AXIS"; + case OpCode::TAKE: + return "TAKE"; + case OpCode::RMS_NORM: + return "RMS_NORM"; + case OpCode::LAYER_NORM: + return "LAYER_NORM"; + case OpCode::ROPE: + return "ROPE"; + case OpCode::SDPA: + return "SDPA"; + case OpCode::ADD: + return "ADD"; + case OpCode::ADD_INT: + return "ADD_INT"; + case OpCode::SUBTRACT_INT: + return "SUBTRACT_INT"; + case OpCode::MULTIPLY_INT: + return "MULTIPLY_INT"; + case OpCode::FLOOR_DIVIDE_INT: + return "FLOOR_DIVIDE_INT"; + case OpCode::MOD_INT: + return "MOD_INT"; + case OpCode::SYM_SIZE: + return "SYM_SIZE"; + case OpCode::MULTIPLY: + return "MULTIPLY"; + case OpCode::DIVIDE: + return "DIVIDE"; + case OpCode::SUBTRACT: + return "SUBTRACT"; + case OpCode::CONV1D: + return "CONV1D"; + case OpCode::CONV2D: + return "CONV2D"; + case OpCode::CONV3D: + return "CONV3D"; + case OpCode::CONV_TRANSPOSE1D: + return "CONV_TRANSPOSE1D"; + case OpCode::CONV_TRANSPOSE2D: + return "CONV_TRANSPOSE2D"; + case OpCode::CONV_TRANSPOSE3D: + return "CONV_TRANSPOSE3D"; + case OpCode::GELU: + return "GELU"; + case OpCode::ARANGE: + return "ARANGE"; + case OpCode::SILU: + return "SILU"; + case OpCode::SIGMOID: + return "SIGMOID"; + case OpCode::TANH: + return "TANH"; + case OpCode::SQUEEZE: + return "SQUEEZE"; + case OpCode::SPLIT: + return 
"SPLIT"; + case OpCode::RSQRT: + return "RSQRT"; + case OpCode::MAXIMUM: + return "MAXIMUM"; + case OpCode::MINIMUM: + return "MINIMUM"; + case OpCode::LOG: + return "LOG"; + case OpCode::SOFTMAX: + return "SOFTMAX"; + case OpCode::BROADCAST_TO: + return "BROADCAST_TO"; + case OpCode::PAD: + return "PAD"; + case OpCode::WHERE: + return "WHERE"; + case OpCode::RESHAPE: + return "RESHAPE"; + case OpCode::TRANSPOSE: + return "TRANSPOSE"; + case OpCode::AS_STRIDED: + return "AS_STRIDED"; + case OpCode::CONTIGUOUS: + return "CONTIGUOUS"; + case OpCode::GATHER: + return "GATHER"; + case OpCode::SLICE: + return "SLICE"; + case OpCode::ASTYPE: + return "ASTYPE"; + case OpCode::QUANTIZED_MATMUL: + return "QUANTIZED_MATMUL"; + case OpCode::SCATTER_ADD: + return "SCATTER_ADD"; + case OpCode::CONCATENATE: + return "CONCATENATE"; + case OpCode::FULL: + return "FULL"; + case OpCode::FULL_LIKE: + return "FULL_LIKE"; + case OpCode::ARGMAX: + return "ARGMAX"; + case OpCode::SLICE_UPDATE: + return "SLICE_UPDATE"; + case OpCode::INDEX_COPY: + return "INDEX_COPY"; + case OpCode::DEQUANTIZE: + return "DEQUANTIZE"; + case OpCode::LESS: + return "LESS"; + case OpCode::LESS_EQUAL: + return "LESS_EQUAL"; + case OpCode::GREATER: + return "GREATER"; + case OpCode::GREATER_EQUAL: + return "GREATER_EQUAL"; + case OpCode::EQUAL: + return "EQUAL"; + case OpCode::NOT_EQUAL: + return "NOT_EQUAL"; + case OpCode::LOGICAL_NOT: + return "LOGICAL_NOT"; + case OpCode::LOGICAL_AND: + return "LOGICAL_AND"; + case OpCode::LOGICAL_OR: + return "LOGICAL_OR"; + case OpCode::TRI: + return "TRI"; + case OpCode::TRIL: + return "TRIL"; + case OpCode::TRIU: + return "TRIU"; + case OpCode::CLIP: + return "CLIP"; + case OpCode::CUMSUM: + return "CUMSUM"; + case OpCode::STACK: + return "STACK"; + case OpCode::SIGN: + return "SIGN"; + case OpCode::ANY: + return "ANY"; + case OpCode::ALL: + return "ALL"; + case OpCode::REPEAT: + return "REPEAT"; + case OpCode::SORT: + return "SORT"; + case OpCode::ARGSORT: + return 
"ARGSORT"; + case OpCode::PARTITION: + return "PARTITION"; + case OpCode::ARG_PARTITION: + return "ARG_PARTITION"; + case OpCode::FLOOR: + return "FLOOR"; + case OpCode::CEIL: + return "CEIL"; + case OpCode::SQUARE: + return "SQUARE"; + case OpCode::EXP: + return "EXP"; + case OpCode::SIN: + return "SIN"; + case OpCode::COS: + return "COS"; + case OpCode::TAN: + return "TAN"; + case OpCode::ARCSIN: + return "ARCSIN"; + case OpCode::ARCCOS: + return "ARCCOS"; + case OpCode::ARCTAN: + return "ARCTAN"; + case OpCode::SINH: + return "SINH"; + case OpCode::COSH: + return "COSH"; + case OpCode::ARCSINH: + return "ARCSINH"; + case OpCode::ARCCOSH: + return "ARCCOSH"; + case OpCode::ARCTANH: + return "ARCTANH"; + case OpCode::LOG2: + return "LOG2"; + case OpCode::LOG10: + return "LOG10"; + case OpCode::LOG1P: + return "LOG1P"; + case OpCode::ERF: + return "ERF"; + case OpCode::EXPM1: + return "EXPM1"; + case OpCode::ROUND: + return "ROUND"; + case OpCode::RECIPROCAL: + return "RECIPROCAL"; + case OpCode::SQRT: + return "SQRT"; + case OpCode::ABS: + return "ABS"; + case OpCode::NEG: + return "NEG"; + case OpCode::ATAN2: + return "ATAN2"; + case OpCode::LOG_ADD_EXP: + return "LOG_ADD_EXP"; + case OpCode::FLOOR_DIVIDE: + return "FLOOR_DIVIDE"; + case OpCode::REMAINDER: + return "REMAINDER"; + case OpCode::POWER: + return "POWER"; + case OpCode::LOG_SUM_EXP: + return "LOG_SUM_EXP"; + case OpCode::SUM: + return "SUM"; + case OpCode::MEAN: + return "MEAN"; + case OpCode::VAR: + return "VAR"; + case OpCode::STD: + return "STD"; + case OpCode::PROD: + return "PROD"; + case OpCode::MAX: + return "MAX"; + case OpCode::MIN: + return "MIN"; + case OpCode::ARGMIN: + return "ARGMIN"; + case OpCode::MEDIAN: + return "MEDIAN"; + case OpCode::GATHER_MM: + return "GATHER_MM"; + case OpCode::GATHER_QMM: + return "GATHER_QMM"; + case OpCode::SCAN: + return "SCAN"; + } + return "UNKNOWN"; +} + +// ============================================================================= +// NodeVariant for 
type-erased op storage (AUTO-GENERATED) +// ============================================================================= + +using NodeVariant = std::variant< + NoopNode, + IdCopyNode, + AddmmNode, + ItemIntNode, + ExpandDimsNode, + TileNode, + TakeAlongAxisNode, + TakeNode, + RMSNormNode, + LayerNormNode, + RopeNode, + SdpaNode, + AddNode, + AddIntNode, + SubtractIntNode, + MultiplyIntNode, + FloorDivideIntNode, + ModIntNode, + SymSizeNode, + MultiplyNode, + DivideNode, + SubtractNode, + Conv1DNode, + Conv2DNode, + Conv3DNode, + ConvTranspose1DNode, + ConvTranspose2DNode, + ConvTranspose3DNode, + GeluNode, + ARangeNode, + SiluNode, + SigmoidNode, + TanhNode, + SqueezeNode, + SplitNode, + RsqrtNode, + MaximumNode, + MinimumNode, + LogNode, + SoftmaxNode, + BroadcastToNode, + PadNode, + WhereNode, + ReshapeNode, + TransposeNode, + AsStridedNode, + ContiguousNode, + GatherNode, + SliceNode, + AsTypeNode, + QuantizedMatmulNode, + ScatterAddNode, + ConcatenateNode, + FullNode, + FullLikeNode, + ArgmaxNode, + SliceUpdateNode, + IndexCopyNode, + DequantizeNode, + LessNode, + LessEqualNode, + GreaterNode, + GreaterEqualNode, + EqualNode, + NotEqualNode, + LogicalNotNode, + LogicalAndNode, + LogicalOrNode, + TriNode, + TrilNode, + TriuNode, + ClipNode, + CumsumNode, + StackNode, + SignNode, + AnyNode, + AllNode, + RepeatNode, + SortNode, + ArgsortNode, + PartitionNode, + ArgPartitionNode, + FloorNode, + CeilNode, + SquareNode, + ExpNode, + SinNode, + CosNode, + TanNode, + ArcsinNode, + ArccosNode, + ArctanNode, + SinhNode, + CoshNode, + ArcsinhNode, + ArccoshNode, + ArctanhNode, + Log2Node, + Log10Node, + Log1pNode, + ErfNode, + Expm1Node, + RoundNode, + ReciprocalNode, + SqrtNode, + AbsNode, + NegNode, + Atan2Node, + LogAddExpNode, + FloorDivideNode, + RemainderNode, + PowerNode, + LogSumExpNode, + SumNode, + MeanNode, + VarNode, + StdNode, + ProdNode, + MaxNode, + MinNode, + ArgminNode, + MedianNode, + GatherMmNode, + GatherQmmNode, + ScanNode +>; + +// 
============================================================================= +// Instruction +// ============================================================================= + +struct Instruction { + OpCode op{OpCode::NOOP}; + NodeVariant node; + + template + T& get() { + return std::get(node); + } + + template + const T& get() const { + return std::get(node); + } +}; + +// ============================================================================= +// SlotVariant for I/O mapping +// ============================================================================= + +enum class SlotType : uint8_t { + TensorSlot = 0, + IntValueSlot = 1, + FloatValueSlot = 2, + BoolValueSlot = 3, +}; + +struct SlotVariant { + uint32_t idx; + SlotType slot_type; +}; + +// ============================================================================= +// Named slot (name -> slot mapping) +// ============================================================================= + +struct NamedSlot { + std::string name; + SlotVariant slot; +}; + +// ============================================================================= +// MLXProgram - the loaded program ready for execution +// ============================================================================= + +struct MLXProgram { + std::string version; + + // Tensor/value slot counts (in Tid assignment order) + uint32_t num_constant_tensors{0}; + uint32_t num_input_tensors{0}; + uint32_t num_output_tensors{0}; + uint32_t num_mutable_buffer_tensors{0}; + uint32_t num_temp_tensors{0}; + uint32_t num_values{0}; + + // Instruction chains + std::vector> instruction_chains; + uint32_t main_chain_idx{0}; + int32_t init_chain_idx{-1}; // -1 = no init chain + + // I/O mappings + std::vector input_map; + std::vector output_map; + std::vector mutable_buffer_map; + + // Name to slot lookup + std::vector named_slots; + + // Tensor metadata + std::vector> tensor_meta; + + // Helper methods + inline uint64_t num_tensors() const { + return 
static_cast(num_constant_tensors) + + num_input_tensors + num_output_tensors + + num_mutable_buffer_tensors + num_temp_tensors; + } + + inline bool is_constant_tensor(Tid id) const { + return id.idx < num_constant_tensors; + } + + inline size_t num_inputs() const { + return input_map.size(); + } + + inline size_t num_outputs() const { + return output_map.size(); + } +}; + +// ============================================================================= +// FlatBuffer loading functions +// ============================================================================= + +namespace loader { + +// Convert FlatBuffer SlotType to our SlotType +inline SlotType convert_slot_type(mlx_delegate::SlotType fb_type) { + switch (fb_type) { + case mlx_delegate::SlotType_TensorSlot: + return SlotType::TensorSlot; + case mlx_delegate::SlotType_IntValueSlot: + return SlotType::IntValueSlot; + case mlx_delegate::SlotType_FloatValueSlot: + return SlotType::FloatValueSlot; + case mlx_delegate::SlotType_BoolValueSlot: + return SlotType::BoolValueSlot; + default: + throw std::runtime_error("Unknown SlotType: " + + std::to_string(static_cast(fb_type))); + } +} + +// Convert FlatBuffer Tid +inline Tid convert_tid(const mlx_delegate::Tid* fb_tid) { + if (!fb_tid) { + throw std::runtime_error("Null Tid in FlatBuffer"); + } + return Tid{fb_tid->idx()}; +} + +// Convert FlatBuffer Vid +inline Vid convert_vid(const mlx_delegate::Vid* fb_vid) { + if (!fb_vid) { + throw std::runtime_error("Null Vid in FlatBuffer"); + } + return Vid{fb_vid->idx()}; +} + +// Convert FlatBuffer IntOrVid +inline std::variant convert_int_or_vid( + const mlx_delegate::IntOrVid* fb) { + if (!fb) { + throw std::runtime_error("Null IntOrVid in FlatBuffer"); + } + if (!fb->is_vid()) { + return fb->literal(); + } + const auto* vid_ptr = fb->vid(); + if (!vid_ptr) { + throw std::runtime_error("IntOrVid has is_vid=true but vid pointer is null"); + } + return Vid{vid_ptr->idx()}; +} + +// Convert FlatBuffer FloatOrVid +inline 
std::variant convert_float_or_vid( + const mlx_delegate::FloatOrVid* fb) { + if (!fb) { + throw std::runtime_error("Null FloatOrVid in FlatBuffer"); + } + if (!fb->is_vid()) { + return fb->literal(); + } + const auto* vid_ptr = fb->vid(); + if (!vid_ptr) { + throw std::runtime_error("FloatOrVid has is_vid=true but vid pointer is null"); + } + return Vid{vid_ptr->idx()}; +} + +// Convert FlatBuffer VidOrTid (scalar value or tensor) +inline VidOrTid convert_vid_or_tid( + const mlx_delegate::VidOrTid* fb) { + if (!fb) { + throw std::runtime_error("Null VidOrTid in FlatBuffer"); + } + VidOrTid result; + result.is_vid = fb->is_vid(); + if (result.is_vid) { + if (!fb->vid()) { + throw std::runtime_error("VidOrTid has is_vid=true but vid pointer is null"); + } + result.vid = Vid{fb->vid()->idx()}; + } else { + if (!fb->tid()) { + throw std::runtime_error("VidOrTid has is_vid=false but tid pointer is null"); + } + result.tid = Tid{fb->tid()->idx()}; + } + return result; +} + +// Convert FlatBuffer IntOrVidOrTid (literal int, Vid, or Tid) +inline IntOrVidOrTid convert_int_or_vid_or_tid( + const mlx_delegate::IntOrVidOrTid* fb) { + if (!fb) { + throw std::runtime_error("Null IntOrVidOrTid in FlatBuffer"); + } + IntOrVidOrTid result; + result.kind = fb->kind(); + switch (result.kind) { + case 0: // literal int + result.literal = fb->literal(); + break; + case 1: { // Vid + const auto* vid_ptr = fb->vid(); + if (!vid_ptr) { + throw std::runtime_error( + "IntOrVidOrTid has kind=1 (Vid) but vid pointer is null"); + } + result.vid = Vid{vid_ptr->idx()}; + break; + } + case 2: { // Tid + const auto* tid_ptr = fb->tid(); + if (!tid_ptr) { + throw std::runtime_error( + "IntOrVidOrTid has kind=2 (Tid) but tid pointer is null"); + } + result.tid = Tid{tid_ptr->idx()}; + break; + } + default: + throw std::runtime_error( + "IntOrVidOrTid has invalid kind: " + std::to_string(result.kind)); + } + return result; +} + +// Convert FlatBuffer SlotVariant +inline SlotVariant 
convert_slot_variant(const mlx_delegate::SlotVariant* fb) { + if (!fb) { + throw std::runtime_error("Null SlotVariant in FlatBuffer"); + } + return SlotVariant{fb->idx(), convert_slot_type(fb->slot_type())}; +} + +// Load an instruction from FlatBuffer +Instruction load_instruction(const mlx_delegate::Instruction* fb_instr); + +// Load the full MLXProgram from FlatBuffer data +MLXProgram load_program(const void* data, size_t size); + +} // namespace loader + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/serialization/_generated/__init__.py b/backends/mlx/serialization/_generated/__init__.py new file mode 100644 index 00000000000..999f694a9f9 --- /dev/null +++ b/backends/mlx/serialization/_generated/__init__.py @@ -0,0 +1,147 @@ +# Auto-generated FlatBuffer bindings +# Re-exports from mlx_delegate namespace for convenient imports + +from executorch.backends.mlx.serialization._generated.mlx_delegate.ARangeNode import ARangeNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.AbsNode import AbsNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.AddIntNode import AddIntNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.AddNode import AddNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.AddmmNode import AddmmNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.AllNode import AllNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.AnyNode import AnyNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ArccosNode import ArccosNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ArccoshNode import ArccoshNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ArcsinNode import ArcsinNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ArcsinhNode import ArcsinhNode +from 
executorch.backends.mlx.serialization._generated.mlx_delegate.ArctanNode import ArctanNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ArctanhNode import ArctanhNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ArgPartitionNode import ArgPartitionNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ArgmaxNode import ArgmaxNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ArgminNode import ArgminNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ArgsortNode import ArgsortNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.AsStridedNode import AsStridedNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.AsTypeNode import AsTypeNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.Atan2Node import Atan2Node +from executorch.backends.mlx.serialization._generated.mlx_delegate.BitwiseOrNode import BitwiseOrNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.BroadcastToNode import BroadcastToNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.CeilNode import CeilNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ClipNode import ClipNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ConcatenateNode import ConcatenateNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ContiguousNode import ContiguousNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.Conv1DNode import Conv1DNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.Conv2DNode import Conv2DNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.Conv3DNode import Conv3DNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ConvTranspose1DNode import ConvTranspose1DNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ConvTranspose2DNode import 
ConvTranspose2DNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ConvTranspose3DNode import ConvTranspose3DNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.CosNode import CosNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.CoshNode import CoshNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.CumsumNode import CumsumNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.DequantizeNode import DequantizeNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.DivideNode import DivideNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.EqualNode import EqualNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ErfNode import ErfNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ExpNode import ExpNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ExpandDimsNode import ExpandDimsNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.Expm1Node import Expm1Node +from executorch.backends.mlx.serialization._generated.mlx_delegate.FloatOrVid import FloatOrVid +from executorch.backends.mlx.serialization._generated.mlx_delegate.FloorDivideIntNode import FloorDivideIntNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.FloorDivideNode import FloorDivideNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.FloorNode import FloorNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.FullLikeNode import FullLikeNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.FullNode import FullNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.GatherMmNode import GatherMmNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.GatherNode import GatherNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.GatherQmmNode import 
GatherQmmNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.GeluNode import GeluNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.GreaterEqualNode import GreaterEqualNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.GreaterNode import GreaterNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.IdCopyNode import IdCopyNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.IndexCopyNode import IndexCopyNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.Instruction import Instruction +from executorch.backends.mlx.serialization._generated.mlx_delegate.InstructionChain import InstructionChain +from executorch.backends.mlx.serialization._generated.mlx_delegate.IntOrVid import IntOrVid +from executorch.backends.mlx.serialization._generated.mlx_delegate.IntOrVidOrTid import IntOrVidOrTid +from executorch.backends.mlx.serialization._generated.mlx_delegate.ItemIntNode import ItemIntNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.LayerNormNode import LayerNormNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.LessEqualNode import LessEqualNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.LessNode import LessNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.Log10Node import Log10Node +from executorch.backends.mlx.serialization._generated.mlx_delegate.Log1pNode import Log1pNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.Log2Node import Log2Node +from executorch.backends.mlx.serialization._generated.mlx_delegate.LogAddExpNode import LogAddExpNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.LogNode import LogNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.LogSumExpNode import LogSumExpNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.LogicalAndNode import 
LogicalAndNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.LogicalNotNode import LogicalNotNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.LogicalOrNode import LogicalOrNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.MLXGraph import MLXGraph +from executorch.backends.mlx.serialization._generated.mlx_delegate.MaxNode import MaxNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.MaximumNode import MaximumNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.MeanNode import MeanNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.MedianNode import MedianNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.MetalKernelNode import MetalKernelNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.MinNode import MinNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.MinimumNode import MinimumNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ModIntNode import ModIntNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.MultiplyIntNode import MultiplyIntNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.MultiplyNode import MultiplyNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.NamedSlot import NamedSlot +from executorch.backends.mlx.serialization._generated.mlx_delegate.NegNode import NegNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.NoopNode import NoopNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.NotEqualNode import NotEqualNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.OpNode import OpNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.PadNode import PadNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.PartitionNode import PartitionNode +from 
executorch.backends.mlx.serialization._generated.mlx_delegate.PowerNode import PowerNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ProdNode import ProdNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.QuantizedMatmulNode import QuantizedMatmulNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.RMSNormNode import RMSNormNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ReciprocalNode import ReciprocalNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.RemainderNode import RemainderNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.RepeatNode import RepeatNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ReshapeNode import ReshapeNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.RopeNode import RopeNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.RoundNode import RoundNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.RsqrtNode import RsqrtNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ScanNode import ScanNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ScatterAddNode import ScatterAddNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SdpaNode import SdpaNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.ShapeDim import ShapeDim +from executorch.backends.mlx.serialization._generated.mlx_delegate.SigmoidNode import SigmoidNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SignNode import SignNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SiluNode import SiluNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SinNode import SinNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SinhNode import SinhNode +from 
executorch.backends.mlx.serialization._generated.mlx_delegate.SliceNode import SliceNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SliceUpdateNode import SliceUpdateNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SlotType import SlotType +from executorch.backends.mlx.serialization._generated.mlx_delegate.SlotVariant import SlotVariant +from executorch.backends.mlx.serialization._generated.mlx_delegate.SoftmaxNode import SoftmaxNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SortNode import SortNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SplitNode import SplitNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SqrtNode import SqrtNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SquareNode import SquareNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SqueezeNode import SqueezeNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.StackNode import StackNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.StdNode import StdNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SubtractIntNode import SubtractIntNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SubtractNode import SubtractNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SumNode import SumNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.SymSizeNode import SymSizeNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.TakeAlongAxisNode import TakeAlongAxisNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.TakeNode import TakeNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.TanNode import TanNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.TanhNode import TanhNode +from 
executorch.backends.mlx.serialization._generated.mlx_delegate.TensorMeta import TensorMeta +from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import Tid +from executorch.backends.mlx.serialization._generated.mlx_delegate.TileNode import TileNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.TransposeNode import TransposeNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.TriNode import TriNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.TrilNode import TrilNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.TriuNode import TriuNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.VarNode import VarNode +from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import Vid +from executorch.backends.mlx.serialization._generated.mlx_delegate.VidOrTid import VidOrTid +from executorch.backends.mlx.serialization._generated.mlx_delegate.WhereNode import WhereNode + +__all__ = ['ARangeNode', 'AbsNode', 'AddIntNode', 'AddNode', 'AddmmNode', 'AllNode', 'AnyNode', 'ArccosNode', 'ArccoshNode', 'ArcsinNode', 'ArcsinhNode', 'ArctanNode', 'ArctanhNode', 'ArgPartitionNode', 'ArgmaxNode', 'ArgminNode', 'ArgsortNode', 'AsStridedNode', 'AsTypeNode', 'Atan2Node', 'BitwiseOrNode', 'BroadcastToNode', 'CeilNode', 'ClipNode', 'ConcatenateNode', 'ContiguousNode', 'Conv1DNode', 'Conv2DNode', 'Conv3DNode', 'ConvTranspose1DNode', 'ConvTranspose2DNode', 'ConvTranspose3DNode', 'CosNode', 'CoshNode', 'CumsumNode', 'DequantizeNode', 'DivideNode', 'EqualNode', 'ErfNode', 'ExpNode', 'ExpandDimsNode', 'Expm1Node', 'FloatOrVid', 'FloorDivideIntNode', 'FloorDivideNode', 'FloorNode', 'FullLikeNode', 'FullNode', 'GatherMmNode', 'GatherNode', 'GatherQmmNode', 'GeluNode', 'GreaterEqualNode', 'GreaterNode', 'IdCopyNode', 'IndexCopyNode', 'Instruction', 'InstructionChain', 'IntOrVid', 'IntOrVidOrTid', 'ItemIntNode', 'LayerNormNode', 'LessEqualNode', 'LessNode', 
'Log10Node', 'Log1pNode', 'Log2Node', 'LogAddExpNode', 'LogNode', 'LogSumExpNode', 'LogicalAndNode', 'LogicalNotNode', 'LogicalOrNode', 'MLXGraph', 'MaxNode', 'MaximumNode', 'MeanNode', 'MedianNode', 'MetalKernelNode', 'MinNode', 'MinimumNode', 'ModIntNode', 'MultiplyIntNode', 'MultiplyNode', 'NamedSlot', 'NegNode', 'NoopNode', 'NotEqualNode', 'OpNode', 'PadNode', 'PartitionNode', 'PowerNode', 'ProdNode', 'QuantizedMatmulNode', 'RMSNormNode', 'ReciprocalNode', 'RemainderNode', 'RepeatNode', 'ReshapeNode', 'RopeNode', 'RoundNode', 'RsqrtNode', 'ScanNode', 'ScatterAddNode', 'SdpaNode', 'ShapeDim', 'SigmoidNode', 'SignNode', 'SiluNode', 'SinNode', 'SinhNode', 'SliceNode', 'SliceUpdateNode', 'SlotType', 'SlotVariant', 'SoftmaxNode', 'SortNode', 'SplitNode', 'SqrtNode', 'SquareNode', 'SqueezeNode', 'StackNode', 'StdNode', 'SubtractIntNode', 'SubtractNode', 'SumNode', 'SymSizeNode', 'TakeAlongAxisNode', 'TakeNode', 'TanNode', 'TanhNode', 'TensorMeta', 'Tid', 'TileNode', 'TransposeNode', 'TriNode', 'TrilNode', 'TriuNode', 'VarNode', 'Vid', 'VidOrTid', 'WhereNode'] diff --git a/backends/mlx/serialization/_generated_serializers.py b/backends/mlx/serialization/_generated_serializers.py new file mode 100644 index 00000000000..57aa3365700 --- /dev/null +++ b/backends/mlx/serialization/_generated_serializers.py @@ -0,0 +1,2777 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# ============================================================================ +# AUTO-GENERATED FILE - DO NOT EDIT MANUALLY +# ============================================================================ +# +# This file was generated from schema.fbs by the MLX delegate code generator. 
+# +# Source: backends/mlx/serialization/schema.fbs +# Generator: backends/mlx/serialization/generate.py +# +# To regenerate, run from the executorch root: +# python backends/mlx/serialization/generate.py +# +# ============================================================================ +# +# This file contains auto-generated serializer methods for all op types. + +from __future__ import annotations + +from typing import List, Tuple, Dict + +import flatbuffers + +# FlatBuffer union indices: 0 = NONE, then 1-indexed from union order +MLX_OP_TYPE_NAMES = { + 0: "NONE", + 1: "NoopNode", + 2: "IdCopyNode", + 3: "AddmmNode", + 4: "ItemIntNode", + 5: "ExpandDimsNode", + 6: "TileNode", + 7: "TakeAlongAxisNode", + 8: "TakeNode", + 9: "RMSNormNode", + 10: "LayerNormNode", + 11: "RopeNode", + 12: "SdpaNode", + 13: "AddNode", + 14: "AddIntNode", + 15: "SubtractIntNode", + 16: "MultiplyIntNode", + 17: "FloorDivideIntNode", + 18: "SymSizeNode", + 19: "MultiplyNode", + 20: "DivideNode", + 21: "SubtractNode", + 22: "Conv1DNode", + 23: "Conv2DNode", + 24: "Conv3DNode", + 25: "GeluNode", + 26: "ARangeNode", + 27: "SiluNode", + 28: "SigmoidNode", + 29: "TanhNode", + 30: "SqueezeNode", + 31: "SplitNode", + 32: "RsqrtNode", + 33: "MaximumNode", + 34: "MinimumNode", + 35: "LogNode", + 36: "SoftmaxNode", + 37: "BroadcastToNode", + 38: "PadNode", + 39: "WhereNode", + 40: "ReshapeNode", + 41: "TransposeNode", + 42: "AsStridedNode", + 43: "ContiguousNode", + 44: "GatherNode", + 45: "SliceNode", + 46: "AsTypeNode", + 47: "ConcatenateNode", + 48: "FullNode", + 49: "FullLikeNode", + 50: "ArgmaxNode", + 51: "SliceUpdateNode", + 52: "IndexCopyNode", + 53: "DequantizeNode", + 54: "LessNode", + 55: "LessEqualNode", + 56: "GreaterNode", + 57: "GreaterEqualNode", + 58: "EqualNode", + 59: "NotEqualNode", + 60: "LogicalNotNode", + 61: "LogicalAndNode", + 62: "LogicalOrNode", + 63: "TriNode", + 64: "TrilNode", + 65: "TriuNode", + 66: "FloorNode", + 67: "CeilNode", + 68: "SquareNode", + 69: "ExpNode", 
+ 70: "SinNode", + 71: "CosNode", + 72: "TanNode", + 73: "ArcsinNode", + 74: "ArccosNode", + 75: "ArctanNode", + 76: "SinhNode", + 77: "CoshNode", + 78: "ArcsinhNode", + 79: "ArccoshNode", + 80: "ArctanhNode", + 81: "Log2Node", + 82: "Log10Node", + 83: "Log1pNode", + 84: "ErfNode", + 85: "Expm1Node", + 86: "RoundNode", + 87: "ReciprocalNode", + 88: "SqrtNode", + 89: "AbsNode", + 90: "NegNode", + 91: "Atan2Node", + 92: "LogAddExpNode", + 93: "FloorDivideNode", + 94: "PowerNode", + 95: "LogSumExpNode", + 96: "SumNode", + 97: "MeanNode", + 98: "VarNode", + 99: "StdNode", + 100: "ProdNode", + 101: "MaxNode", + 102: "MinNode", + 103: "ArgminNode", + 104: "MedianNode", + 105: "ModIntNode", + 106: "RemainderNode", + 107: "ConvTranspose1DNode", + 108: "ConvTranspose2DNode", + 109: "ConvTranspose3DNode", + 110: "ClipNode", + 111: "CumsumNode", + 112: "StackNode", + 113: "SignNode", + 114: "AnyNode", + 115: "AllNode", + 116: "RepeatNode", + 117: "SortNode", + 118: "ArgsortNode", + 119: "PartitionNode", + 120: "ArgPartitionNode", + 121: "QuantizedMatmulNode", + 122: "ScatterAddNode", + 123: "GatherMmNode", + 124: "GatherQmmNode", + 125: "ScanNode", + 126: "MetalKernelNode", + 127: "BitwiseOrNode", +} + +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + NoopNode, + IdCopyNode, + AddmmNode, + ItemIntNode, + ExpandDimsNode, + TileNode, + TakeAlongAxisNode, + TakeNode, + RMSNormNode, + LayerNormNode, + RopeNode, + SdpaNode, + AddNode, + AddIntNode, + SubtractIntNode, + MultiplyIntNode, + FloorDivideIntNode, + ModIntNode, + SymSizeNode, + MultiplyNode, + DivideNode, + SubtractNode, + Conv1DNode, + Conv2DNode, + Conv3DNode, + ConvTranspose1DNode, + ConvTranspose2DNode, + ConvTranspose3DNode, + GeluNode, + ARangeNode, + SiluNode, + SigmoidNode, + TanhNode, + SqueezeNode, + SplitNode, + RsqrtNode, + MaximumNode, + MinimumNode, + LogNode, + SoftmaxNode, + BroadcastToNode, + PadNode, + WhereNode, + ReshapeNode, + TransposeNode, + AsStridedNode, + ContiguousNode, + 
GatherNode, + SliceNode, + AsTypeNode, + QuantizedMatmulNode, + ScatterAddNode, + ConcatenateNode, + FullNode, + FullLikeNode, + ArgmaxNode, + SliceUpdateNode, + IndexCopyNode, + DequantizeNode, + LessNode, + LessEqualNode, + GreaterNode, + GreaterEqualNode, + EqualNode, + NotEqualNode, + LogicalNotNode, + LogicalAndNode, + LogicalOrNode, + TriNode, + TrilNode, + TriuNode, + ClipNode, + CumsumNode, + StackNode, + SignNode, + AnyNode, + AllNode, + RepeatNode, + SortNode, + ArgsortNode, + PartitionNode, + ArgPartitionNode, + FloorNode, + CeilNode, + SquareNode, + ExpNode, + SinNode, + CosNode, + TanNode, + ArcsinNode, + ArccosNode, + ArctanNode, + SinhNode, + CoshNode, + ArcsinhNode, + ArccoshNode, + ArctanhNode, + Log2Node, + Log10Node, + Log1pNode, + ErfNode, + Expm1Node, + RoundNode, + ReciprocalNode, + SqrtNode, + AbsNode, + NegNode, + Atan2Node, + LogAddExpNode, + FloorDivideNode, + RemainderNode, + PowerNode, + LogSumExpNode, + SumNode, + MeanNode, + VarNode, + StdNode, + ProdNode, + MaxNode, + MinNode, + ArgminNode, + MedianNode, + GatherMmNode, + GatherQmmNode, + ScanNode, + IntOrVid, + FloatOrVid, + VidOrTid, + IntOrVidOrTid, + Tid, + Vid, +) + + +def _build_int_vector(builder: flatbuffers.Builder, vec: List[int]) -> int: + """Pre-build a vector of int32 values (must be called before table Start).""" + builder.StartVector(4, len(vec), 4) + for v in reversed(vec): + builder.PrependInt32(v) + return builder.EndVector() + + +def _build_int8_vector(builder: flatbuffers.Builder, vec: List[int]) -> int: + """Pre-build a vector of int8 values (must be called before table Start).""" + builder.StartVector(1, len(vec), 1) + for v in reversed(vec): + builder.PrependInt8(v) + return builder.EndVector() + + +def _build_uint8_vector(builder: flatbuffers.Builder, vec: List[int]) -> int: + """Pre-build a vector of uint8 values (must be called before table Start).""" + builder.StartVector(1, len(vec), 1) + for v in reversed(vec): + builder.PrependUint8(v) + return 
builder.EndVector() + + +class GeneratedOpBuilders: + """Mixin class with auto-generated op builder methods.""" + + def _build_int_or_vid(self, builder: flatbuffers.Builder, iov: IntOrVid) -> int: + """Build an IntOrVid table.""" + from executorch.backends.mlx.serialization._generated.mlx_delegate import IntOrVid as FBIntOrVidModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBIntOrVidModule.Start(builder) + FBIntOrVidModule.AddLiteral(builder, iov.literal) + FBIntOrVidModule.AddIsVid(builder, iov.is_vid) + if iov.vid is not None: + # Vid is an inline struct - must be added last for proper FlatBuffer layout + FBIntOrVidModule.AddVid(builder, CreateVid(builder, iov.vid.idx)) + return FBIntOrVidModule.End(builder) + + def _build_float_or_vid(self, builder: flatbuffers.Builder, fov: FloatOrVid) -> int: + """Build a FloatOrVid table.""" + from executorch.backends.mlx.serialization._generated.mlx_delegate import FloatOrVid as FBFloatOrVidModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBFloatOrVidModule.Start(builder) + FBFloatOrVidModule.AddLiteral(builder, fov.literal) + FBFloatOrVidModule.AddIsVid(builder, fov.is_vid) + if fov.vid is not None: + FBFloatOrVidModule.AddVid(builder, CreateVid(builder, fov.vid.idx)) + return FBFloatOrVidModule.End(builder) + + def _build_vid_or_tid(self, builder: flatbuffers.Builder, vot: VidOrTid) -> int: + """Build a TidOrVid table.""" + from executorch.backends.mlx.serialization._generated.mlx_delegate import VidOrTid as FBVidOrTidModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBVidOrTidModule.Start(builder) + FBVidOrTidModule.AddIsVid(builder, vot.is_vid) + if vot.tid is not None: + FBVidOrTidModule.AddTid(builder, CreateTid(builder, vot.tid.idx)) + if vot.vid is not None: + 
FBVidOrTidModule.AddVid(builder, CreateVid(builder, vot.vid.idx)) + return FBVidOrTidModule.End(builder) + + def _build_int_or_vid_or_tid(self, builder: flatbuffers.Builder, ivt: IntOrVidOrTid) -> int: + """Build an IntOrVidOrTid table.""" + from executorch.backends.mlx.serialization._generated.mlx_delegate import IntOrVidOrTid as FBIntOrVidOrTidModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBIntOrVidOrTidModule.Start(builder) + FBIntOrVidOrTidModule.AddLiteral(builder, ivt.literal) + FBIntOrVidOrTidModule.AddKind(builder, ivt.kind) + if ivt.tid is not None: + FBIntOrVidOrTidModule.AddTid(builder, CreateTid(builder, ivt.tid.idx)) + if ivt.vid is not None: + FBIntOrVidOrTidModule.AddVid(builder, CreateVid(builder, ivt.vid.idx)) + return FBIntOrVidOrTidModule.End(builder) + + def _build_int_or_vid_vector( + self, builder: flatbuffers.Builder, vec: List[IntOrVid] + ) -> int: + """Build a vector of IntOrVid tables.""" + offsets = [] + for iov in vec: + offsets.append(self._build_int_or_vid(builder, iov)) + builder.StartVector(4, len(offsets), 4) + for off in reversed(offsets): + builder.PrependUOffsetTRelative(off) + return builder.EndVector() + + def _build_tid_vector( + self, builder: flatbuffers.Builder, vec: List[Tid] + ) -> int: + """Build a vector of Tid structs.""" + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + + # For vectors of structs, we need to build the vector differently + # Each Tid struct is 4 bytes (uint32), so we manually write them + builder.StartVector(4, len(vec), 4) + for tid in reversed(vec): + builder.Prep(4, 0) # Align for struct + builder.PrependUint32(tid.idx) + return builder.EndVector() + + def _build_string_vector( + self, builder: flatbuffers.Builder, vec: List[str] + ) -> int: + """Pre-build a vector of strings (offsets must be created before table 
Start).""" + offsets = [builder.CreateString(s) for s in vec] + builder.StartVector(4, len(offsets), 4) + for off in reversed(offsets): + builder.PrependUOffsetTRelative(off) + return builder.EndVector() + + def _build_NoopNode( + self, builder: flatbuffers.Builder, op: NoopNode + ) -> Tuple[int, int]: + """Auto-generated builder for NoopNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import NoopNode as FBNoopNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBNoopNodeModule.Start(builder) + offset = FBNoopNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.NoopNode + + def _build_IdCopyNode( + self, builder: flatbuffers.Builder, op: IdCopyNode + ) -> Tuple[int, int]: + """Auto-generated builder for IdCopyNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import IdCopyNode as FBIdCopyNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBIdCopyNodeModule.Start(builder) + FBIdCopyNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBIdCopyNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBIdCopyNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.IdCopyNode + + def _build_AddmmNode( + self, builder: flatbuffers.Builder, op: AddmmNode + ) -> Tuple[int, int]: + """Auto-generated builder for AddmmNode.""" + # 
Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import AddmmNode as FBAddmmNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBAddmmNodeModule.Start(builder) + FBAddmmNodeModule.AddMat1(builder, CreateTid(builder, op.mat1.idx)) + FBAddmmNodeModule.AddMat2(builder, CreateTid(builder, op.mat2.idx)) + FBAddmmNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if op.bias is not None: + FBAddmmNodeModule.AddBias(builder, CreateTid(builder, op.bias.idx)) + FBAddmmNodeModule.AddAlpha(builder, op.alpha) + FBAddmmNodeModule.AddBeta(builder, op.beta) + offset = FBAddmmNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.AddmmNode + + def _build_ItemIntNode( + self, builder: flatbuffers.Builder, op: ItemIntNode + ) -> Tuple[int, int]: + """Auto-generated builder for ItemIntNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ItemIntNode as FBItemIntNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBItemIntNodeModule.Start(builder) + FBItemIntNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBItemIntNodeModule.AddOut(builder, CreateVid(builder, op.out.idx)) + offset = FBItemIntNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ItemIntNode + + def _build_ExpandDimsNode( + self, builder: flatbuffers.Builder, op: ExpandDimsNode + ) 
-> Tuple[int, int]: + """Auto-generated builder for ExpandDimsNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ExpandDimsNode as FBExpandDimsNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBExpandDimsNodeModule.Start(builder) + FBExpandDimsNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBExpandDimsNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBExpandDimsNodeModule.AddAxis(builder, op.axis) + offset = FBExpandDimsNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ExpandDimsNode + + def _build_TileNode( + self, builder: flatbuffers.Builder, op: TileNode + ) -> Tuple[int, int]: + """Auto-generated builder for TileNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import TileNode as FBTileNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + reps_vec = self._build_int_or_vid_vector(builder, op.reps) + + FBTileNodeModule.Start(builder) + FBTileNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBTileNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBTileNodeModule.AddReps(builder, reps_vec) + offset = FBTileNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.TileNode + + def _build_TakeAlongAxisNode( + self, builder: flatbuffers.Builder, op: TakeAlongAxisNode + ) -> Tuple[int, int]: + 
"""Auto-generated builder for TakeAlongAxisNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import TakeAlongAxisNode as FBTakeAlongAxisNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBTakeAlongAxisNodeModule.Start(builder) + FBTakeAlongAxisNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBTakeAlongAxisNodeModule.AddIndices(builder, CreateTid(builder, op.indices.idx)) + FBTakeAlongAxisNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBTakeAlongAxisNodeModule.AddAxis(builder, op.axis) + offset = FBTakeAlongAxisNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.TakeAlongAxisNode + + def _build_TakeNode( + self, builder: flatbuffers.Builder, op: TakeNode + ) -> Tuple[int, int]: + """Auto-generated builder for TakeNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import TakeNode as FBTakeNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + index_off = self._build_int_or_vid_or_tid(builder, op.index) + + FBTakeNodeModule.Start(builder) + FBTakeNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBTakeNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBTakeNodeModule.AddIndex(builder, index_off) + FBTakeNodeModule.AddAxis(builder, op.axis) + offset = FBTakeNodeModule.End(builder) + return offset, 
FBOpNodeModule.OpNode.TakeNode + + def _build_RMSNormNode( + self, builder: flatbuffers.Builder, op: RMSNormNode + ) -> Tuple[int, int]: + """Auto-generated builder for RMSNormNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import RMSNormNode as FBRMSNormNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBRMSNormNodeModule.Start(builder) + FBRMSNormNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + if op.weight is not None: + FBRMSNormNodeModule.AddWeight(builder, CreateTid(builder, op.weight.idx)) + FBRMSNormNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBRMSNormNodeModule.AddEps(builder, op.eps) + offset = FBRMSNormNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.RMSNormNode + + def _build_LayerNormNode( + self, builder: flatbuffers.Builder, op: LayerNormNode + ) -> Tuple[int, int]: + """Auto-generated builder for LayerNormNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import LayerNormNode as FBLayerNormNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBLayerNormNodeModule.Start(builder) + FBLayerNormNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBLayerNormNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if op.weight is not None: + FBLayerNormNodeModule.AddWeight(builder, 
CreateTid(builder, op.weight.idx)) + if op.bias is not None: + FBLayerNormNodeModule.AddBias(builder, CreateTid(builder, op.bias.idx)) + FBLayerNormNodeModule.AddEps(builder, op.eps) + offset = FBLayerNormNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.LayerNormNode + + def _build_RopeNode( + self, builder: flatbuffers.Builder, op: RopeNode + ) -> Tuple[int, int]: + """Auto-generated builder for RopeNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import RopeNode as FBRopeNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + offset_off = self._build_vid_or_tid(builder, op.offset) + + FBRopeNodeModule.Start(builder) + FBRopeNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBRopeNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBRopeNodeModule.AddDims(builder, op.dims) + FBRopeNodeModule.AddOffset(builder, offset_off) + if op.freqs is not None: + FBRopeNodeModule.AddFreqs(builder, CreateTid(builder, op.freqs.idx)) + FBRopeNodeModule.AddTraditional(builder, op.traditional) + FBRopeNodeModule.AddBase(builder, op.base) + FBRopeNodeModule.AddScale(builder, op.scale) + offset = FBRopeNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.RopeNode + + def _build_SdpaNode( + self, builder: flatbuffers.Builder, op: SdpaNode + ) -> Tuple[int, int]: + """Auto-generated builder for SdpaNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SdpaNode as FBSdpaNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as 
FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBSdpaNodeModule.Start(builder) + FBSdpaNodeModule.AddQ(builder, CreateTid(builder, op.q.idx)) + FBSdpaNodeModule.AddK(builder, CreateTid(builder, op.k.idx)) + FBSdpaNodeModule.AddV(builder, CreateTid(builder, op.v.idx)) + FBSdpaNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBSdpaNodeModule.AddScale(builder, op.scale) + if op.mask is not None: + FBSdpaNodeModule.AddMask(builder, CreateTid(builder, op.mask.idx)) + FBSdpaNodeModule.AddCausal(builder, op.causal) + offset = FBSdpaNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SdpaNode + + def _build_AddNode( + self, builder: flatbuffers.Builder, op: AddNode + ) -> Tuple[int, int]: + """Auto-generated builder for AddNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import AddNode as FBAddNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBAddNodeModule.Start(builder) + FBAddNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBAddNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBAddNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBAddNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.AddNode + + def _build_AddIntNode( + self, builder: flatbuffers.Builder, op: AddIntNode + ) -> Tuple[int, int]: + """Auto-generated builder for AddIntNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from 
executorch.backends.mlx.serialization._generated.mlx_delegate import AddIntNode as FBAddIntNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + a_off = self._build_int_or_vid(builder, op.a) + b_off = self._build_int_or_vid(builder, op.b) + + FBAddIntNodeModule.Start(builder) + FBAddIntNodeModule.AddA(builder, a_off) + FBAddIntNodeModule.AddB(builder, b_off) + FBAddIntNodeModule.AddOut(builder, CreateVid(builder, op.out.idx)) + offset = FBAddIntNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.AddIntNode + + def _build_SubtractIntNode( + self, builder: flatbuffers.Builder, op: SubtractIntNode + ) -> Tuple[int, int]: + """Auto-generated builder for SubtractIntNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SubtractIntNode as FBSubtractIntNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + a_off = self._build_int_or_vid(builder, op.a) + b_off = self._build_int_or_vid(builder, op.b) + + FBSubtractIntNodeModule.Start(builder) + FBSubtractIntNodeModule.AddA(builder, a_off) + FBSubtractIntNodeModule.AddB(builder, b_off) + FBSubtractIntNodeModule.AddOut(builder, CreateVid(builder, op.out.idx)) + offset = FBSubtractIntNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SubtractIntNode + + def _build_MultiplyIntNode( + self, builder: flatbuffers.Builder, op: MultiplyIntNode + ) -> Tuple[int, int]: + """Auto-generated builder for MultiplyIntNode.""" 
+ # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import MultiplyIntNode as FBMultiplyIntNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + a_off = self._build_int_or_vid(builder, op.a) + b_off = self._build_int_or_vid(builder, op.b) + + FBMultiplyIntNodeModule.Start(builder) + FBMultiplyIntNodeModule.AddA(builder, a_off) + FBMultiplyIntNodeModule.AddB(builder, b_off) + FBMultiplyIntNodeModule.AddOut(builder, CreateVid(builder, op.out.idx)) + offset = FBMultiplyIntNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.MultiplyIntNode + + def _build_FloorDivideIntNode( + self, builder: flatbuffers.Builder, op: FloorDivideIntNode + ) -> Tuple[int, int]: + """Auto-generated builder for FloorDivideIntNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import FloorDivideIntNode as FBFloorDivideIntNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + a_off = self._build_int_or_vid(builder, op.a) + b_off = self._build_int_or_vid(builder, op.b) + + FBFloorDivideIntNodeModule.Start(builder) + FBFloorDivideIntNodeModule.AddA(builder, a_off) + FBFloorDivideIntNodeModule.AddB(builder, b_off) + FBFloorDivideIntNodeModule.AddOut(builder, CreateVid(builder, op.out.idx)) + offset = FBFloorDivideIntNodeModule.End(builder) + return offset, 
FBOpNodeModule.OpNode.FloorDivideIntNode + + def _build_ModIntNode( + self, builder: flatbuffers.Builder, op: ModIntNode + ) -> Tuple[int, int]: + """Auto-generated builder for ModIntNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ModIntNode as FBModIntNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + a_off = self._build_int_or_vid(builder, op.a) + b_off = self._build_int_or_vid(builder, op.b) + + FBModIntNodeModule.Start(builder) + FBModIntNodeModule.AddA(builder, a_off) + FBModIntNodeModule.AddB(builder, b_off) + FBModIntNodeModule.AddOut(builder, CreateVid(builder, op.out.idx)) + offset = FBModIntNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ModIntNode + + def _build_SymSizeNode( + self, builder: flatbuffers.Builder, op: SymSizeNode + ) -> Tuple[int, int]: + """Auto-generated builder for SymSizeNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SymSizeNode as FBSymSizeNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBSymSizeNodeModule.Start(builder) + FBSymSizeNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBSymSizeNodeModule.AddDim(builder, op.dim) + FBSymSizeNodeModule.AddOut(builder, CreateVid(builder, op.out.idx)) + offset = FBSymSizeNodeModule.End(builder) + return offset, 
FBOpNodeModule.OpNode.SymSizeNode + + def _build_MultiplyNode( + self, builder: flatbuffers.Builder, op: MultiplyNode + ) -> Tuple[int, int]: + """Auto-generated builder for MultiplyNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import MultiplyNode as FBMultiplyNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBMultiplyNodeModule.Start(builder) + FBMultiplyNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBMultiplyNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBMultiplyNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBMultiplyNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.MultiplyNode + + def _build_DivideNode( + self, builder: flatbuffers.Builder, op: DivideNode + ) -> Tuple[int, int]: + """Auto-generated builder for DivideNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import DivideNode as FBDivideNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBDivideNodeModule.Start(builder) + FBDivideNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBDivideNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBDivideNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBDivideNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.DivideNode + + def 
_build_SubtractNode( + self, builder: flatbuffers.Builder, op: SubtractNode + ) -> Tuple[int, int]: + """Auto-generated builder for SubtractNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SubtractNode as FBSubtractNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBSubtractNodeModule.Start(builder) + FBSubtractNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBSubtractNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBSubtractNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBSubtractNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SubtractNode + + def _build_Conv1DNode( + self, builder: flatbuffers.Builder, op: Conv1DNode + ) -> Tuple[int, int]: + """Auto-generated builder for Conv1DNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import Conv1DNode as FBConv1DNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBConv1DNodeModule.Start(builder) + FBConv1DNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBConv1DNodeModule.AddW(builder, CreateTid(builder, op.w.idx)) + FBConv1DNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBConv1DNodeModule.AddStride(builder, op.stride) + FBConv1DNodeModule.AddPadding(builder, op.padding) + FBConv1DNodeModule.AddDilation(builder, 
op.dilation) + FBConv1DNodeModule.AddGroups(builder, op.groups) + offset = FBConv1DNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.Conv1DNode + + def _build_Conv2DNode( + self, builder: flatbuffers.Builder, op: Conv2DNode + ) -> Tuple[int, int]: + """Auto-generated builder for Conv2DNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import Conv2DNode as FBConv2DNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBConv2DNodeModule.Start(builder) + FBConv2DNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBConv2DNodeModule.AddW(builder, CreateTid(builder, op.w.idx)) + FBConv2DNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBConv2DNodeModule.AddStrideH(builder, op.stride_h) + FBConv2DNodeModule.AddStrideW(builder, op.stride_w) + FBConv2DNodeModule.AddPaddingH(builder, op.padding_h) + FBConv2DNodeModule.AddPaddingW(builder, op.padding_w) + FBConv2DNodeModule.AddDilationH(builder, op.dilation_h) + FBConv2DNodeModule.AddDilationW(builder, op.dilation_w) + FBConv2DNodeModule.AddGroups(builder, op.groups) + offset = FBConv2DNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.Conv2DNode + + def _build_Conv3DNode( + self, builder: flatbuffers.Builder, op: Conv3DNode + ) -> Tuple[int, int]: + """Auto-generated builder for Conv3DNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import Conv3DNode as FBConv3DNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from 
executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBConv3DNodeModule.Start(builder) + FBConv3DNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBConv3DNodeModule.AddW(builder, CreateTid(builder, op.w.idx)) + FBConv3DNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBConv3DNodeModule.AddStrideD(builder, op.stride_d) + FBConv3DNodeModule.AddStrideH(builder, op.stride_h) + FBConv3DNodeModule.AddStrideW(builder, op.stride_w) + FBConv3DNodeModule.AddPaddingD(builder, op.padding_d) + FBConv3DNodeModule.AddPaddingH(builder, op.padding_h) + FBConv3DNodeModule.AddPaddingW(builder, op.padding_w) + FBConv3DNodeModule.AddDilationD(builder, op.dilation_d) + FBConv3DNodeModule.AddDilationH(builder, op.dilation_h) + FBConv3DNodeModule.AddDilationW(builder, op.dilation_w) + FBConv3DNodeModule.AddGroups(builder, op.groups) + offset = FBConv3DNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.Conv3DNode + + def _build_ConvTranspose1DNode( + self, builder: flatbuffers.Builder, op: ConvTranspose1DNode + ) -> Tuple[int, int]: + """Auto-generated builder for ConvTranspose1DNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ConvTranspose1DNode as FBConvTranspose1DNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBConvTranspose1DNodeModule.Start(builder) + FBConvTranspose1DNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBConvTranspose1DNodeModule.AddW(builder, CreateTid(builder, op.w.idx)) + FBConvTranspose1DNodeModule.AddOut(builder, CreateTid(builder, 
op.out.idx)) + FBConvTranspose1DNodeModule.AddStride(builder, op.stride) + FBConvTranspose1DNodeModule.AddPadding(builder, op.padding) + FBConvTranspose1DNodeModule.AddDilation(builder, op.dilation) + FBConvTranspose1DNodeModule.AddOutputPadding(builder, op.output_padding) + FBConvTranspose1DNodeModule.AddGroups(builder, op.groups) + offset = FBConvTranspose1DNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ConvTranspose1DNode + + def _build_ConvTranspose2DNode( + self, builder: flatbuffers.Builder, op: ConvTranspose2DNode + ) -> Tuple[int, int]: + """Auto-generated builder for ConvTranspose2DNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ConvTranspose2DNode as FBConvTranspose2DNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBConvTranspose2DNodeModule.Start(builder) + FBConvTranspose2DNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBConvTranspose2DNodeModule.AddW(builder, CreateTid(builder, op.w.idx)) + FBConvTranspose2DNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBConvTranspose2DNodeModule.AddStrideH(builder, op.stride_h) + FBConvTranspose2DNodeModule.AddStrideW(builder, op.stride_w) + FBConvTranspose2DNodeModule.AddPaddingH(builder, op.padding_h) + FBConvTranspose2DNodeModule.AddPaddingW(builder, op.padding_w) + FBConvTranspose2DNodeModule.AddDilationH(builder, op.dilation_h) + FBConvTranspose2DNodeModule.AddDilationW(builder, op.dilation_w) + FBConvTranspose2DNodeModule.AddOutputPaddingH(builder, op.output_padding_h) + FBConvTranspose2DNodeModule.AddOutputPaddingW(builder, op.output_padding_w) + FBConvTranspose2DNodeModule.AddGroups(builder, op.groups) 
+ offset = FBConvTranspose2DNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ConvTranspose2DNode + + def _build_ConvTranspose3DNode( + self, builder: flatbuffers.Builder, op: ConvTranspose3DNode + ) -> Tuple[int, int]: + """Auto-generated builder for ConvTranspose3DNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ConvTranspose3DNode as FBConvTranspose3DNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBConvTranspose3DNodeModule.Start(builder) + FBConvTranspose3DNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBConvTranspose3DNodeModule.AddW(builder, CreateTid(builder, op.w.idx)) + FBConvTranspose3DNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBConvTranspose3DNodeModule.AddStrideD(builder, op.stride_d) + FBConvTranspose3DNodeModule.AddStrideH(builder, op.stride_h) + FBConvTranspose3DNodeModule.AddStrideW(builder, op.stride_w) + FBConvTranspose3DNodeModule.AddPaddingD(builder, op.padding_d) + FBConvTranspose3DNodeModule.AddPaddingH(builder, op.padding_h) + FBConvTranspose3DNodeModule.AddPaddingW(builder, op.padding_w) + FBConvTranspose3DNodeModule.AddDilationD(builder, op.dilation_d) + FBConvTranspose3DNodeModule.AddDilationH(builder, op.dilation_h) + FBConvTranspose3DNodeModule.AddDilationW(builder, op.dilation_w) + FBConvTranspose3DNodeModule.AddOutputPaddingD(builder, op.output_padding_d) + FBConvTranspose3DNodeModule.AddOutputPaddingH(builder, op.output_padding_h) + FBConvTranspose3DNodeModule.AddOutputPaddingW(builder, op.output_padding_w) + FBConvTranspose3DNodeModule.AddGroups(builder, op.groups) + offset = FBConvTranspose3DNodeModule.End(builder) + return 
offset, FBOpNodeModule.OpNode.ConvTranspose3DNode + + def _build_GeluNode( + self, builder: flatbuffers.Builder, op: GeluNode + ) -> Tuple[int, int]: + """Auto-generated builder for GeluNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import GeluNode as FBGeluNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + approximate_off = builder.CreateString(op.approximate) + + FBGeluNodeModule.Start(builder) + FBGeluNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBGeluNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBGeluNodeModule.AddApproximate(builder, approximate_off) + offset = FBGeluNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.GeluNode + + def _build_ARangeNode( + self, builder: flatbuffers.Builder, op: ARangeNode + ) -> Tuple[int, int]: + """Auto-generated builder for ARangeNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ARangeNode as FBARangeNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + start_off = self._build_int_or_vid(builder, op.start) + stop_off = self._build_int_or_vid(builder, op.stop) + step_off = self._build_int_or_vid(builder, op.step) + + FBARangeNodeModule.Start(builder) + FBARangeNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBARangeNodeModule.AddStart(builder, start_off) 
+ FBARangeNodeModule.AddStop(builder, stop_off) + FBARangeNodeModule.AddStep(builder, step_off) + if op.scalar_type is not None: + FBARangeNodeModule.AddScalarType(builder, op.scalar_type) + offset = FBARangeNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ARangeNode + + def _build_SiluNode( + self, builder: flatbuffers.Builder, op: SiluNode + ) -> Tuple[int, int]: + """Auto-generated builder for SiluNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SiluNode as FBSiluNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBSiluNodeModule.Start(builder) + FBSiluNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBSiluNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBSiluNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SiluNode + + def _build_SigmoidNode( + self, builder: flatbuffers.Builder, op: SigmoidNode + ) -> Tuple[int, int]: + """Auto-generated builder for SigmoidNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SigmoidNode as FBSigmoidNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBSigmoidNodeModule.Start(builder) + FBSigmoidNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBSigmoidNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = 
FBSigmoidNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SigmoidNode + + def _build_TanhNode( + self, builder: flatbuffers.Builder, op: TanhNode + ) -> Tuple[int, int]: + """Auto-generated builder for TanhNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import TanhNode as FBTanhNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBTanhNodeModule.Start(builder) + FBTanhNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBTanhNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBTanhNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.TanhNode + + def _build_SqueezeNode( + self, builder: flatbuffers.Builder, op: SqueezeNode + ) -> Tuple[int, int]: + """Auto-generated builder for SqueezeNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SqueezeNode as FBSqueezeNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + dims_vec = _build_int_vector(builder, op.dims) if op.dims is not None else None + + FBSqueezeNodeModule.Start(builder) + FBSqueezeNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBSqueezeNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if dims_vec is not None: + FBSqueezeNodeModule.AddDims(builder, dims_vec) + offset = FBSqueezeNodeModule.End(builder) + return offset, 
FBOpNodeModule.OpNode.SqueezeNode + + def _build_SplitNode( + self, builder: flatbuffers.Builder, op: SplitNode + ) -> Tuple[int, int]: + """Auto-generated builder for SplitNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SplitNode as FBSplitNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + outs_vec = self._build_tid_vector(builder, op.outs) + sizes_vec = self._build_int_or_vid_vector(builder, op.sizes) + + FBSplitNodeModule.Start(builder) + FBSplitNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBSplitNodeModule.AddOuts(builder, outs_vec) + FBSplitNodeModule.AddSizes(builder, sizes_vec) + FBSplitNodeModule.AddAxis(builder, op.axis) + offset = FBSplitNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SplitNode + + def _build_RsqrtNode( + self, builder: flatbuffers.Builder, op: RsqrtNode + ) -> Tuple[int, int]: + """Auto-generated builder for RsqrtNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import RsqrtNode as FBRsqrtNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBRsqrtNodeModule.Start(builder) + FBRsqrtNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBRsqrtNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBRsqrtNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.RsqrtNode 
+ + def _build_MaximumNode( + self, builder: flatbuffers.Builder, op: MaximumNode + ) -> Tuple[int, int]: + """Auto-generated builder for MaximumNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import MaximumNode as FBMaximumNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBMaximumNodeModule.Start(builder) + FBMaximumNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBMaximumNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBMaximumNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBMaximumNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.MaximumNode + + def _build_MinimumNode( + self, builder: flatbuffers.Builder, op: MinimumNode + ) -> Tuple[int, int]: + """Auto-generated builder for MinimumNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import MinimumNode as FBMinimumNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBMinimumNodeModule.Start(builder) + FBMinimumNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBMinimumNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBMinimumNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBMinimumNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.MinimumNode + + def _build_LogNode( + self, builder: 
flatbuffers.Builder, op: LogNode + ) -> Tuple[int, int]: + """Auto-generated builder for LogNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import LogNode as FBLogNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBLogNodeModule.Start(builder) + FBLogNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBLogNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBLogNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.LogNode + + def _build_SoftmaxNode( + self, builder: flatbuffers.Builder, op: SoftmaxNode + ) -> Tuple[int, int]: + """Auto-generated builder for SoftmaxNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SoftmaxNode as FBSoftmaxNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBSoftmaxNodeModule.Start(builder) + FBSoftmaxNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBSoftmaxNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBSoftmaxNodeModule.AddAxis(builder, op.axis) + FBSoftmaxNodeModule.AddPrecise(builder, op.precise) + offset = FBSoftmaxNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SoftmaxNode + + def _build_BroadcastToNode( + self, builder: flatbuffers.Builder, op: BroadcastToNode + ) -> Tuple[int, int]: + """Auto-generated builder for BroadcastToNode.""" 
+ # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import BroadcastToNode as FBBroadcastToNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + shape_vec = self._build_int_or_vid_vector(builder, op.shape) + + FBBroadcastToNodeModule.Start(builder) + FBBroadcastToNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBBroadcastToNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBBroadcastToNodeModule.AddShape(builder, shape_vec) + offset = FBBroadcastToNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.BroadcastToNode + + def _build_PadNode( + self, builder: flatbuffers.Builder, op: PadNode + ) -> Tuple[int, int]: + """Auto-generated builder for PadNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import PadNode as FBPadNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + pad_width_vec = self._build_int_or_vid_vector(builder, op.pad_width) + mode_off = builder.CreateString(op.mode) + + FBPadNodeModule.Start(builder) + FBPadNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBPadNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBPadNodeModule.AddPadWidth(builder, pad_width_vec) + FBPadNodeModule.AddMode(builder, mode_off) + FBPadNodeModule.AddConstantValue(builder, op.constant_value) + offset = FBPadNodeModule.End(builder) + 
return offset, FBOpNodeModule.OpNode.PadNode + + def _build_WhereNode( + self, builder: flatbuffers.Builder, op: WhereNode + ) -> Tuple[int, int]: + """Auto-generated builder for WhereNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import WhereNode as FBWhereNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBWhereNodeModule.Start(builder) + FBWhereNodeModule.AddCondition(builder, CreateTid(builder, op.condition.idx)) + FBWhereNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBWhereNodeModule.AddY(builder, CreateTid(builder, op.y.idx)) + FBWhereNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBWhereNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.WhereNode + + def _build_ReshapeNode( + self, builder: flatbuffers.Builder, op: ReshapeNode + ) -> Tuple[int, int]: + """Auto-generated builder for ReshapeNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ReshapeNode as FBReshapeNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + shape_vec = self._build_int_or_vid_vector(builder, op.shape) + + FBReshapeNodeModule.Start(builder) + FBReshapeNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBReshapeNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBReshapeNodeModule.AddShape(builder, 
shape_vec) + offset = FBReshapeNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ReshapeNode + + def _build_TransposeNode( + self, builder: flatbuffers.Builder, op: TransposeNode + ) -> Tuple[int, int]: + """Auto-generated builder for TransposeNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import TransposeNode as FBTransposeNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + perm_vec = _build_int_vector(builder, op.perm) + + FBTransposeNodeModule.Start(builder) + FBTransposeNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBTransposeNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBTransposeNodeModule.AddPerm(builder, perm_vec) + offset = FBTransposeNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.TransposeNode + + def _build_AsStridedNode( + self, builder: flatbuffers.Builder, op: AsStridedNode + ) -> Tuple[int, int]: + """Auto-generated builder for AsStridedNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import AsStridedNode as FBAsStridedNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + shape_vec = self._build_int_or_vid_vector(builder, op.shape) + strides_vec = self._build_int_or_vid_vector(builder, op.strides) + + FBAsStridedNodeModule.Start(builder) + FBAsStridedNodeModule.AddX(builder, 
CreateTid(builder, op.x.idx)) + FBAsStridedNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBAsStridedNodeModule.AddShape(builder, shape_vec) + FBAsStridedNodeModule.AddStrides(builder, strides_vec) + FBAsStridedNodeModule.AddOffset(builder, op.offset) + offset = FBAsStridedNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.AsStridedNode + + def _build_ContiguousNode( + self, builder: flatbuffers.Builder, op: ContiguousNode + ) -> Tuple[int, int]: + """Auto-generated builder for ContiguousNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ContiguousNode as FBContiguousNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBContiguousNodeModule.Start(builder) + FBContiguousNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBContiguousNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBContiguousNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ContiguousNode + + def _build_GatherNode( + self, builder: flatbuffers.Builder, op: GatherNode + ) -> Tuple[int, int]: + """Auto-generated builder for GatherNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import GatherNode as FBGatherNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + indices_vec = self._build_tid_vector(builder, op.indices) + 
axes_vec = _build_int_vector(builder, op.axes) + slice_sizes_vec = _build_int_vector(builder, op.slice_sizes) + + FBGatherNodeModule.Start(builder) + FBGatherNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBGatherNodeModule.AddIndices(builder, indices_vec) + FBGatherNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBGatherNodeModule.AddAxes(builder, axes_vec) + FBGatherNodeModule.AddSliceSizes(builder, slice_sizes_vec) + offset = FBGatherNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.GatherNode + + def _build_SliceNode( + self, builder: flatbuffers.Builder, op: SliceNode + ) -> Tuple[int, int]: + """Auto-generated builder for SliceNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SliceNode as FBSliceNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + axis_off = self._build_int_or_vid(builder, op.axis) + start_off = self._build_int_or_vid(builder, op.start) + stop_off = self._build_int_or_vid(builder, op.stop) + + FBSliceNodeModule.Start(builder) + FBSliceNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBSliceNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBSliceNodeModule.AddAxis(builder, axis_off) + FBSliceNodeModule.AddStart(builder, start_off) + FBSliceNodeModule.AddStop(builder, stop_off) + FBSliceNodeModule.AddStep(builder, op.step) + offset = FBSliceNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SliceNode + + def _build_AsTypeNode( + self, builder: flatbuffers.Builder, op: AsTypeNode + ) -> Tuple[int, int]: + """Auto-generated builder for AsTypeNode.""" + # Import the MODULE (not class) to access builder functions like 
Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import AsTypeNode as FBAsTypeNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBAsTypeNodeModule.Start(builder) + FBAsTypeNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBAsTypeNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBAsTypeNodeModule.AddScalarType(builder, op.scalar_type) + offset = FBAsTypeNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.AsTypeNode + + def _build_QuantizedMatmulNode( + self, builder: flatbuffers.Builder, op: QuantizedMatmulNode + ) -> Tuple[int, int]: + """Auto-generated builder for QuantizedMatmulNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import QuantizedMatmulNode as FBQuantizedMatmulNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + mode_off = builder.CreateString(op.mode) + + FBQuantizedMatmulNodeModule.Start(builder) + FBQuantizedMatmulNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBQuantizedMatmulNodeModule.AddW(builder, CreateTid(builder, op.w.idx)) + FBQuantizedMatmulNodeModule.AddScales(builder, CreateTid(builder, op.scales.idx)) + FBQuantizedMatmulNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if op.biases is not None: + FBQuantizedMatmulNodeModule.AddBiases(builder, CreateTid(builder, op.biases.idx)) + FBQuantizedMatmulNodeModule.AddGroupSize(builder, op.group_size) + 
FBQuantizedMatmulNodeModule.AddBits(builder, op.bits) + FBQuantizedMatmulNodeModule.AddMode(builder, mode_off) + FBQuantizedMatmulNodeModule.AddTranspose(builder, op.transpose) + offset = FBQuantizedMatmulNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.QuantizedMatmulNode + + def _build_ScatterAddNode( + self, builder: flatbuffers.Builder, op: ScatterAddNode + ) -> Tuple[int, int]: + """Auto-generated builder for ScatterAddNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ScatterAddNode as FBScatterAddNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBScatterAddNodeModule.Start(builder) + FBScatterAddNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBScatterAddNodeModule.AddIndices(builder, CreateTid(builder, op.indices.idx)) + FBScatterAddNodeModule.AddUpdates(builder, CreateTid(builder, op.updates.idx)) + FBScatterAddNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBScatterAddNodeModule.AddAxis(builder, op.axis) + offset = FBScatterAddNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ScatterAddNode + + def _build_ConcatenateNode( + self, builder: flatbuffers.Builder, op: ConcatenateNode + ) -> Tuple[int, int]: + """Auto-generated builder for ConcatenateNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ConcatenateNode as FBConcatenateNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import 
CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + tensors_vec = self._build_tid_vector(builder, op.tensors) + + FBConcatenateNodeModule.Start(builder) + FBConcatenateNodeModule.AddTensors(builder, tensors_vec) + FBConcatenateNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBConcatenateNodeModule.AddAxis(builder, op.axis) + offset = FBConcatenateNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ConcatenateNode + + def _build_FullNode( + self, builder: flatbuffers.Builder, op: FullNode + ) -> Tuple[int, int]: + """Auto-generated builder for FullNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import FullNode as FBFullNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + shape_vec = self._build_int_or_vid_vector(builder, op.shape) + v_off = self._build_float_or_vid(builder, op.v) + + FBFullNodeModule.Start(builder) + FBFullNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBFullNodeModule.AddShape(builder, shape_vec) + FBFullNodeModule.AddV(builder, v_off) + FBFullNodeModule.AddScalarType(builder, op.scalar_type) + offset = FBFullNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.FullNode + + def _build_FullLikeNode( + self, builder: flatbuffers.Builder, op: FullLikeNode + ) -> Tuple[int, int]: + """Auto-generated builder for FullLikeNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import FullLikeNode as FBFullLikeNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate 
import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + v_off = self._build_float_or_vid(builder, op.v) + + FBFullLikeNodeModule.Start(builder) + FBFullLikeNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBFullLikeNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBFullLikeNodeModule.AddV(builder, v_off) + if op.scalar_type is not None: + FBFullLikeNodeModule.AddScalarType(builder, op.scalar_type) + offset = FBFullLikeNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.FullLikeNode + + def _build_ArgmaxNode( + self, builder: flatbuffers.Builder, op: ArgmaxNode + ) -> Tuple[int, int]: + """Auto-generated builder for ArgmaxNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ArgmaxNode as FBArgmaxNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBArgmaxNodeModule.Start(builder) + FBArgmaxNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBArgmaxNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBArgmaxNodeModule.AddAxis(builder, op.axis) + FBArgmaxNodeModule.AddKeepdims(builder, op.keepdims) + offset = FBArgmaxNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ArgmaxNode + + def _build_SliceUpdateNode( + self, builder: flatbuffers.Builder, op: SliceUpdateNode + ) -> Tuple[int, int]: + """Auto-generated builder for SliceUpdateNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from 
executorch.backends.mlx.serialization._generated.mlx_delegate import SliceUpdateNode as FBSliceUpdateNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + axis_off = self._build_int_or_vid(builder, op.axis) + start_off = self._build_int_or_vid(builder, op.start) + stop_off = self._build_int_or_vid(builder, op.stop) + + FBSliceUpdateNodeModule.Start(builder) + FBSliceUpdateNodeModule.AddDst(builder, CreateTid(builder, op.dst.idx)) + FBSliceUpdateNodeModule.AddUpdate(builder, CreateTid(builder, op.update.idx)) + FBSliceUpdateNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBSliceUpdateNodeModule.AddAxis(builder, axis_off) + FBSliceUpdateNodeModule.AddStart(builder, start_off) + FBSliceUpdateNodeModule.AddStop(builder, stop_off) + FBSliceUpdateNodeModule.AddStep(builder, op.step) + offset = FBSliceUpdateNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SliceUpdateNode + + def _build_IndexCopyNode( + self, builder: flatbuffers.Builder, op: IndexCopyNode + ) -> Tuple[int, int]: + """Auto-generated builder for IndexCopyNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import IndexCopyNode as FBIndexCopyNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBIndexCopyNodeModule.Start(builder) + FBIndexCopyNodeModule.AddDst(builder, CreateTid(builder, op.dst.idx)) + FBIndexCopyNodeModule.AddUpdate(builder, CreateTid(builder, op.update.idx)) + 
FBIndexCopyNodeModule.AddIndices(builder, CreateTid(builder, op.indices.idx)) + FBIndexCopyNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBIndexCopyNodeModule.AddAxis(builder, op.axis) + offset = FBIndexCopyNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.IndexCopyNode + + def _build_DequantizeNode( + self, builder: flatbuffers.Builder, op: DequantizeNode + ) -> Tuple[int, int]: + """Auto-generated builder for DequantizeNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import DequantizeNode as FBDequantizeNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + mode_off = builder.CreateString(op.mode) + + FBDequantizeNodeModule.Start(builder) + FBDequantizeNodeModule.AddW(builder, CreateTid(builder, op.w.idx)) + FBDequantizeNodeModule.AddScales(builder, CreateTid(builder, op.scales.idx)) + FBDequantizeNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if op.biases is not None: + FBDequantizeNodeModule.AddBiases(builder, CreateTid(builder, op.biases.idx)) + FBDequantizeNodeModule.AddGroupSize(builder, op.group_size) + FBDequantizeNodeModule.AddBits(builder, op.bits) + FBDequantizeNodeModule.AddMode(builder, mode_off) + if op.global_scale is not None: + FBDequantizeNodeModule.AddGlobalScale(builder, CreateTid(builder, op.global_scale.idx)) + if op.dtype is not None: + FBDequantizeNodeModule.AddDtype(builder, op.dtype) + offset = FBDequantizeNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.DequantizeNode + + def _build_LessNode( + self, builder: flatbuffers.Builder, op: LessNode + ) -> Tuple[int, int]: + """Auto-generated builder for LessNode.""" + # Import the 
MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import LessNode as FBLessNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBLessNodeModule.Start(builder) + FBLessNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBLessNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBLessNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBLessNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.LessNode + + def _build_LessEqualNode( + self, builder: flatbuffers.Builder, op: LessEqualNode + ) -> Tuple[int, int]: + """Auto-generated builder for LessEqualNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import LessEqualNode as FBLessEqualNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBLessEqualNodeModule.Start(builder) + FBLessEqualNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBLessEqualNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBLessEqualNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBLessEqualNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.LessEqualNode + + def _build_GreaterNode( + self, builder: flatbuffers.Builder, op: GreaterNode + ) -> Tuple[int, int]: + """Auto-generated builder for GreaterNode.""" + # Import the MODULE (not class) to access builder functions like 
Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import GreaterNode as FBGreaterNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBGreaterNodeModule.Start(builder) + FBGreaterNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBGreaterNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBGreaterNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBGreaterNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.GreaterNode + + def _build_GreaterEqualNode( + self, builder: flatbuffers.Builder, op: GreaterEqualNode + ) -> Tuple[int, int]: + """Auto-generated builder for GreaterEqualNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import GreaterEqualNode as FBGreaterEqualNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBGreaterEqualNodeModule.Start(builder) + FBGreaterEqualNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBGreaterEqualNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBGreaterEqualNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBGreaterEqualNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.GreaterEqualNode + + def _build_EqualNode( + self, builder: flatbuffers.Builder, op: EqualNode + ) -> Tuple[int, int]: + """Auto-generated builder for EqualNode.""" + # Import the MODULE (not class) to access builder functions like 
Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import EqualNode as FBEqualNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBEqualNodeModule.Start(builder) + FBEqualNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBEqualNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBEqualNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBEqualNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.EqualNode + + def _build_NotEqualNode( + self, builder: flatbuffers.Builder, op: NotEqualNode + ) -> Tuple[int, int]: + """Auto-generated builder for NotEqualNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import NotEqualNode as FBNotEqualNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBNotEqualNodeModule.Start(builder) + FBNotEqualNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBNotEqualNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBNotEqualNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBNotEqualNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.NotEqualNode + + def _build_LogicalNotNode( + self, builder: flatbuffers.Builder, op: LogicalNotNode + ) -> Tuple[int, int]: + """Auto-generated builder for LogicalNotNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from 
executorch.backends.mlx.serialization._generated.mlx_delegate import LogicalNotNode as FBLogicalNotNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBLogicalNotNodeModule.Start(builder) + FBLogicalNotNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBLogicalNotNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBLogicalNotNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.LogicalNotNode + + def _build_LogicalAndNode( + self, builder: flatbuffers.Builder, op: LogicalAndNode + ) -> Tuple[int, int]: + """Auto-generated builder for LogicalAndNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import LogicalAndNode as FBLogicalAndNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBLogicalAndNodeModule.Start(builder) + FBLogicalAndNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBLogicalAndNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBLogicalAndNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBLogicalAndNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.LogicalAndNode + + def _build_LogicalOrNode( + self, builder: flatbuffers.Builder, op: LogicalOrNode + ) -> Tuple[int, int]: + """Auto-generated builder for LogicalOrNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate 
import LogicalOrNode as FBLogicalOrNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBLogicalOrNodeModule.Start(builder) + FBLogicalOrNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBLogicalOrNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBLogicalOrNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBLogicalOrNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.LogicalOrNode + + def _build_TriNode( + self, builder: flatbuffers.Builder, op: TriNode + ) -> Tuple[int, int]: + """Auto-generated builder for TriNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import TriNode as FBTriNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + n_off = self._build_int_or_vid(builder, op.n) + m_off = self._build_int_or_vid(builder, op.m) + + FBTriNodeModule.Start(builder) + FBTriNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBTriNodeModule.AddN(builder, n_off) + FBTriNodeModule.AddM(builder, m_off) + FBTriNodeModule.AddK(builder, op.k) + FBTriNodeModule.AddScalarType(builder, op.scalar_type) + offset = FBTriNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.TriNode + + def _build_TrilNode( + self, builder: flatbuffers.Builder, op: TrilNode + ) -> Tuple[int, int]: + """Auto-generated builder for TrilNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from 
executorch.backends.mlx.serialization._generated.mlx_delegate import TrilNode as FBTrilNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBTrilNodeModule.Start(builder) + FBTrilNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBTrilNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBTrilNodeModule.AddK(builder, op.k) + offset = FBTrilNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.TrilNode + + def _build_TriuNode( + self, builder: flatbuffers.Builder, op: TriuNode + ) -> Tuple[int, int]: + """Auto-generated builder for TriuNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import TriuNode as FBTriuNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBTriuNodeModule.Start(builder) + FBTriuNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBTriuNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBTriuNodeModule.AddK(builder, op.k) + offset = FBTriuNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.TriuNode + + def _build_ClipNode( + self, builder: flatbuffers.Builder, op: ClipNode + ) -> Tuple[int, int]: + """Auto-generated builder for ClipNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ClipNode as FBClipNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate 
import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBClipNodeModule.Start(builder) + FBClipNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBClipNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if op.a_min is not None: + FBClipNodeModule.AddAMin(builder, CreateTid(builder, op.a_min.idx)) + if op.a_max is not None: + FBClipNodeModule.AddAMax(builder, CreateTid(builder, op.a_max.idx)) + offset = FBClipNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ClipNode + + def _build_CumsumNode( + self, builder: flatbuffers.Builder, op: CumsumNode + ) -> Tuple[int, int]: + """Auto-generated builder for CumsumNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import CumsumNode as FBCumsumNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBCumsumNodeModule.Start(builder) + FBCumsumNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBCumsumNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBCumsumNodeModule.AddAxis(builder, op.axis) + FBCumsumNodeModule.AddReverse(builder, op.reverse) + FBCumsumNodeModule.AddInclusive(builder, op.inclusive) + offset = FBCumsumNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.CumsumNode + + def _build_StackNode( + self, builder: flatbuffers.Builder, op: StackNode + ) -> Tuple[int, int]: + """Auto-generated builder for StackNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from 
executorch.backends.mlx.serialization._generated.mlx_delegate import StackNode as FBStackNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + tensors_vec = self._build_tid_vector(builder, op.tensors) + + FBStackNodeModule.Start(builder) + FBStackNodeModule.AddTensors(builder, tensors_vec) + FBStackNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBStackNodeModule.AddAxis(builder, op.axis) + offset = FBStackNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.StackNode + + def _build_SignNode( + self, builder: flatbuffers.Builder, op: SignNode + ) -> Tuple[int, int]: + """Auto-generated builder for SignNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SignNode as FBSignNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBSignNodeModule.Start(builder) + FBSignNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBSignNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBSignNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SignNode + + def _build_AnyNode( + self, builder: flatbuffers.Builder, op: AnyNode + ) -> Tuple[int, int]: + """Auto-generated builder for AnyNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import AnyNode as FBAnyNodeModule + from 
executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + axes_vec = _build_int_vector(builder, op.axes) if op.axes is not None else None + + FBAnyNodeModule.Start(builder) + FBAnyNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBAnyNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if axes_vec is not None: + FBAnyNodeModule.AddAxes(builder, axes_vec) + FBAnyNodeModule.AddKeepdims(builder, op.keepdims) + offset = FBAnyNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.AnyNode + + def _build_AllNode( + self, builder: flatbuffers.Builder, op: AllNode + ) -> Tuple[int, int]: + """Auto-generated builder for AllNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import AllNode as FBAllNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + axes_vec = _build_int_vector(builder, op.axes) if op.axes is not None else None + + FBAllNodeModule.Start(builder) + FBAllNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBAllNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if axes_vec is not None: + FBAllNodeModule.AddAxes(builder, axes_vec) + FBAllNodeModule.AddKeepdims(builder, op.keepdims) + offset = FBAllNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.AllNode + + def _build_RepeatNode( + self, builder: flatbuffers.Builder, op: RepeatNode + ) -> Tuple[int, int]: + """Auto-generated builder for RepeatNode.""" + # Import the MODULE (not class) to 
access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import RepeatNode as FBRepeatNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + repeats_off = self._build_int_or_vid(builder, op.repeats) + + FBRepeatNodeModule.Start(builder) + FBRepeatNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBRepeatNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBRepeatNodeModule.AddRepeats(builder, repeats_off) + FBRepeatNodeModule.AddAxis(builder, op.axis) + offset = FBRepeatNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.RepeatNode + + def _build_SortNode( + self, builder: flatbuffers.Builder, op: SortNode + ) -> Tuple[int, int]: + """Auto-generated builder for SortNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SortNode as FBSortNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBSortNodeModule.Start(builder) + FBSortNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBSortNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBSortNodeModule.AddAxis(builder, op.axis) + offset = FBSortNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SortNode + + def _build_ArgsortNode( + self, builder: flatbuffers.Builder, op: ArgsortNode + ) -> Tuple[int, int]: + """Auto-generated builder for ArgsortNode.""" + # Import the MODULE (not class) to access builder 
functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ArgsortNode as FBArgsortNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBArgsortNodeModule.Start(builder) + FBArgsortNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBArgsortNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBArgsortNodeModule.AddAxis(builder, op.axis) + offset = FBArgsortNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ArgsortNode + + def _build_PartitionNode( + self, builder: flatbuffers.Builder, op: PartitionNode + ) -> Tuple[int, int]: + """Auto-generated builder for PartitionNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import PartitionNode as FBPartitionNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + kth_off = self._build_int_or_vid(builder, op.kth) + + FBPartitionNodeModule.Start(builder) + FBPartitionNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBPartitionNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBPartitionNodeModule.AddKth(builder, kth_off) + FBPartitionNodeModule.AddAxis(builder, op.axis) + offset = FBPartitionNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.PartitionNode + + def _build_ArgPartitionNode( + self, builder: flatbuffers.Builder, op: ArgPartitionNode + ) -> Tuple[int, int]: + """Auto-generated builder for ArgPartitionNode.""" + # 
Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ArgPartitionNode as FBArgPartitionNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + kth_off = self._build_int_or_vid(builder, op.kth) + + FBArgPartitionNodeModule.Start(builder) + FBArgPartitionNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBArgPartitionNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBArgPartitionNodeModule.AddKth(builder, kth_off) + FBArgPartitionNodeModule.AddAxis(builder, op.axis) + offset = FBArgPartitionNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ArgPartitionNode + + def _build_FloorNode( + self, builder: flatbuffers.Builder, op: FloorNode + ) -> Tuple[int, int]: + """Auto-generated builder for FloorNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import FloorNode as FBFloorNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBFloorNodeModule.Start(builder) + FBFloorNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBFloorNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBFloorNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.FloorNode + + def _build_CeilNode( + self, builder: flatbuffers.Builder, op: CeilNode + ) -> Tuple[int, int]: + """Auto-generated builder for CeilNode.""" + # Import the MODULE (not 
class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import CeilNode as FBCeilNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBCeilNodeModule.Start(builder) + FBCeilNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBCeilNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBCeilNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.CeilNode + + def _build_SquareNode( + self, builder: flatbuffers.Builder, op: SquareNode + ) -> Tuple[int, int]: + """Auto-generated builder for SquareNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SquareNode as FBSquareNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBSquareNodeModule.Start(builder) + FBSquareNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBSquareNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBSquareNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SquareNode + + def _build_ExpNode( + self, builder: flatbuffers.Builder, op: ExpNode + ) -> Tuple[int, int]: + """Auto-generated builder for ExpNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ExpNode as FBExpNodeModule + from 
executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBExpNodeModule.Start(builder) + FBExpNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBExpNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBExpNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ExpNode + + def _build_SinNode( + self, builder: flatbuffers.Builder, op: SinNode + ) -> Tuple[int, int]: + """Auto-generated builder for SinNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SinNode as FBSinNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBSinNodeModule.Start(builder) + FBSinNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBSinNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBSinNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SinNode + + def _build_CosNode( + self, builder: flatbuffers.Builder, op: CosNode + ) -> Tuple[int, int]: + """Auto-generated builder for CosNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import CosNode as FBCosNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import 
CreateVid + + FBCosNodeModule.Start(builder) + FBCosNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBCosNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBCosNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.CosNode + + def _build_TanNode( + self, builder: flatbuffers.Builder, op: TanNode + ) -> Tuple[int, int]: + """Auto-generated builder for TanNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import TanNode as FBTanNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBTanNodeModule.Start(builder) + FBTanNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBTanNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBTanNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.TanNode + + def _build_ArcsinNode( + self, builder: flatbuffers.Builder, op: ArcsinNode + ) -> Tuple[int, int]: + """Auto-generated builder for ArcsinNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ArcsinNode as FBArcsinNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBArcsinNodeModule.Start(builder) + FBArcsinNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBArcsinNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBArcsinNodeModule.End(builder) + return offset, 
FBOpNodeModule.OpNode.ArcsinNode + + def _build_ArccosNode( + self, builder: flatbuffers.Builder, op: ArccosNode + ) -> Tuple[int, int]: + """Auto-generated builder for ArccosNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ArccosNode as FBArccosNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBArccosNodeModule.Start(builder) + FBArccosNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBArccosNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBArccosNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ArccosNode + + def _build_ArctanNode( + self, builder: flatbuffers.Builder, op: ArctanNode + ) -> Tuple[int, int]: + """Auto-generated builder for ArctanNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ArctanNode as FBArctanNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBArctanNodeModule.Start(builder) + FBArctanNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBArctanNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBArctanNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ArctanNode + + def _build_SinhNode( + self, builder: flatbuffers.Builder, op: SinhNode + ) -> Tuple[int, int]: + """Auto-generated builder for SinhNode.""" + # Import the MODULE (not 
class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SinhNode as FBSinhNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBSinhNodeModule.Start(builder) + FBSinhNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBSinhNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBSinhNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SinhNode + + def _build_CoshNode( + self, builder: flatbuffers.Builder, op: CoshNode + ) -> Tuple[int, int]: + """Auto-generated builder for CoshNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import CoshNode as FBCoshNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBCoshNodeModule.Start(builder) + FBCoshNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBCoshNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBCoshNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.CoshNode + + def _build_ArcsinhNode( + self, builder: flatbuffers.Builder, op: ArcsinhNode + ) -> Tuple[int, int]: + """Auto-generated builder for ArcsinhNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ArcsinhNode as FBArcsinhNodeModule + from 
executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBArcsinhNodeModule.Start(builder) + FBArcsinhNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBArcsinhNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBArcsinhNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ArcsinhNode + + def _build_ArccoshNode( + self, builder: flatbuffers.Builder, op: ArccoshNode + ) -> Tuple[int, int]: + """Auto-generated builder for ArccoshNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ArccoshNode as FBArccoshNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBArccoshNodeModule.Start(builder) + FBArccoshNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBArccoshNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBArccoshNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ArccoshNode + + def _build_ArctanhNode( + self, builder: flatbuffers.Builder, op: ArctanhNode + ) -> Tuple[int, int]: + """Auto-generated builder for ArctanhNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ArctanhNode as FBArctanhNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + 
from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBArctanhNodeModule.Start(builder) + FBArctanhNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBArctanhNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBArctanhNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ArctanhNode + + def _build_Log2Node( + self, builder: flatbuffers.Builder, op: Log2Node + ) -> Tuple[int, int]: + """Auto-generated builder for Log2Node.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import Log2Node as FBLog2NodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBLog2NodeModule.Start(builder) + FBLog2NodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBLog2NodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBLog2NodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.Log2Node + + def _build_Log10Node( + self, builder: flatbuffers.Builder, op: Log10Node + ) -> Tuple[int, int]: + """Auto-generated builder for Log10Node.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import Log10Node as FBLog10NodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBLog10NodeModule.Start(builder) + FBLog10NodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBLog10NodeModule.AddOut(builder, 
CreateTid(builder, op.out.idx)) + offset = FBLog10NodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.Log10Node + + def _build_Log1pNode( + self, builder: flatbuffers.Builder, op: Log1pNode + ) -> Tuple[int, int]: + """Auto-generated builder for Log1pNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import Log1pNode as FBLog1pNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBLog1pNodeModule.Start(builder) + FBLog1pNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBLog1pNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBLog1pNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.Log1pNode + + def _build_ErfNode( + self, builder: flatbuffers.Builder, op: ErfNode + ) -> Tuple[int, int]: + """Auto-generated builder for ErfNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ErfNode as FBErfNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBErfNodeModule.Start(builder) + FBErfNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBErfNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBErfNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ErfNode + + def _build_Expm1Node( + self, builder: flatbuffers.Builder, op: Expm1Node + ) -> Tuple[int, int]: + """Auto-generated 
builder for Expm1Node.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import Expm1Node as FBExpm1NodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBExpm1NodeModule.Start(builder) + FBExpm1NodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBExpm1NodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBExpm1NodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.Expm1Node + + def _build_RoundNode( + self, builder: flatbuffers.Builder, op: RoundNode + ) -> Tuple[int, int]: + """Auto-generated builder for RoundNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import RoundNode as FBRoundNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBRoundNodeModule.Start(builder) + FBRoundNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBRoundNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBRoundNodeModule.AddDecimals(builder, op.decimals) + offset = FBRoundNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.RoundNode + + def _build_ReciprocalNode( + self, builder: flatbuffers.Builder, op: ReciprocalNode + ) -> Tuple[int, int]: + """Auto-generated builder for ReciprocalNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from 
executorch.backends.mlx.serialization._generated.mlx_delegate import ReciprocalNode as FBReciprocalNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBReciprocalNodeModule.Start(builder) + FBReciprocalNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBReciprocalNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBReciprocalNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ReciprocalNode + + def _build_SqrtNode( + self, builder: flatbuffers.Builder, op: SqrtNode + ) -> Tuple[int, int]: + """Auto-generated builder for SqrtNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SqrtNode as FBSqrtNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBSqrtNodeModule.Start(builder) + FBSqrtNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBSqrtNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBSqrtNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SqrtNode + + def _build_AbsNode( + self, builder: flatbuffers.Builder, op: AbsNode + ) -> Tuple[int, int]: + """Auto-generated builder for AbsNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import AbsNode as FBAbsNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from 
executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBAbsNodeModule.Start(builder) + FBAbsNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBAbsNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBAbsNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.AbsNode + + def _build_NegNode( + self, builder: flatbuffers.Builder, op: NegNode + ) -> Tuple[int, int]: + """Auto-generated builder for NegNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import NegNode as FBNegNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBNegNodeModule.Start(builder) + FBNegNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBNegNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBNegNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.NegNode + + def _build_Atan2Node( + self, builder: flatbuffers.Builder, op: Atan2Node + ) -> Tuple[int, int]: + """Auto-generated builder for Atan2Node.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import Atan2Node as FBAtan2NodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBAtan2NodeModule.Start(builder) + FBAtan2NodeModule.AddA(builder, 
CreateTid(builder, op.a.idx)) + FBAtan2NodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBAtan2NodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBAtan2NodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.Atan2Node + + def _build_LogAddExpNode( + self, builder: flatbuffers.Builder, op: LogAddExpNode + ) -> Tuple[int, int]: + """Auto-generated builder for LogAddExpNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import LogAddExpNode as FBLogAddExpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBLogAddExpNodeModule.Start(builder) + FBLogAddExpNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBLogAddExpNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBLogAddExpNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBLogAddExpNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.LogAddExpNode + + def _build_FloorDivideNode( + self, builder: flatbuffers.Builder, op: FloorDivideNode + ) -> Tuple[int, int]: + """Auto-generated builder for FloorDivideNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import FloorDivideNode as FBFloorDivideNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBFloorDivideNodeModule.Start(builder) + FBFloorDivideNodeModule.AddA(builder, 
CreateTid(builder, op.a.idx)) + FBFloorDivideNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBFloorDivideNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBFloorDivideNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.FloorDivideNode + + def _build_RemainderNode( + self, builder: flatbuffers.Builder, op: RemainderNode + ) -> Tuple[int, int]: + """Auto-generated builder for RemainderNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import RemainderNode as FBRemainderNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBRemainderNodeModule.Start(builder) + FBRemainderNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBRemainderNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBRemainderNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBRemainderNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.RemainderNode + + def _build_PowerNode( + self, builder: flatbuffers.Builder, op: PowerNode + ) -> Tuple[int, int]: + """Auto-generated builder for PowerNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import PowerNode as FBPowerNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBPowerNodeModule.Start(builder) + FBPowerNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + 
FBPowerNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBPowerNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + offset = FBPowerNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.PowerNode + + def _build_LogSumExpNode( + self, builder: flatbuffers.Builder, op: LogSumExpNode + ) -> Tuple[int, int]: + """Auto-generated builder for LogSumExpNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import LogSumExpNode as FBLogSumExpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + axes_vec = _build_int_vector(builder, op.axes) if op.axes is not None else None + + FBLogSumExpNodeModule.Start(builder) + FBLogSumExpNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBLogSumExpNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if axes_vec is not None: + FBLogSumExpNodeModule.AddAxes(builder, axes_vec) + FBLogSumExpNodeModule.AddKeepdims(builder, op.keepdims) + offset = FBLogSumExpNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.LogSumExpNode + + def _build_SumNode( + self, builder: flatbuffers.Builder, op: SumNode + ) -> Tuple[int, int]: + """Auto-generated builder for SumNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import SumNode as FBSumNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + axes_vec = 
_build_int_vector(builder, op.axes) if op.axes is not None else None + + FBSumNodeModule.Start(builder) + FBSumNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBSumNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if axes_vec is not None: + FBSumNodeModule.AddAxes(builder, axes_vec) + FBSumNodeModule.AddKeepdims(builder, op.keepdims) + offset = FBSumNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.SumNode + + def _build_MeanNode( + self, builder: flatbuffers.Builder, op: MeanNode + ) -> Tuple[int, int]: + """Auto-generated builder for MeanNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import MeanNode as FBMeanNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + axes_vec = _build_int_vector(builder, op.axes) if op.axes is not None else None + + FBMeanNodeModule.Start(builder) + FBMeanNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBMeanNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if axes_vec is not None: + FBMeanNodeModule.AddAxes(builder, axes_vec) + FBMeanNodeModule.AddKeepdims(builder, op.keepdims) + offset = FBMeanNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.MeanNode + + def _build_VarNode( + self, builder: flatbuffers.Builder, op: VarNode + ) -> Tuple[int, int]: + """Auto-generated builder for VarNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import VarNode as FBVarNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from 
executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + axes_vec = _build_int_vector(builder, op.axes) if op.axes is not None else None + + FBVarNodeModule.Start(builder) + FBVarNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBVarNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if axes_vec is not None: + FBVarNodeModule.AddAxes(builder, axes_vec) + FBVarNodeModule.AddKeepdims(builder, op.keepdims) + FBVarNodeModule.AddDdof(builder, op.ddof) + offset = FBVarNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.VarNode + + def _build_StdNode( + self, builder: flatbuffers.Builder, op: StdNode + ) -> Tuple[int, int]: + """Auto-generated builder for StdNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import StdNode as FBStdNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + axes_vec = _build_int_vector(builder, op.axes) if op.axes is not None else None + + FBStdNodeModule.Start(builder) + FBStdNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBStdNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if axes_vec is not None: + FBStdNodeModule.AddAxes(builder, axes_vec) + FBStdNodeModule.AddKeepdims(builder, op.keepdims) + FBStdNodeModule.AddDdof(builder, op.ddof) + offset = FBStdNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.StdNode + + def _build_ProdNode( + self, builder: flatbuffers.Builder, op: ProdNode + ) -> Tuple[int, int]: + """Auto-generated builder for ProdNode.""" + # Import the MODULE (not class) to access builder functions 
like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ProdNode as FBProdNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + axes_vec = _build_int_vector(builder, op.axes) if op.axes is not None else None + + FBProdNodeModule.Start(builder) + FBProdNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBProdNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if axes_vec is not None: + FBProdNodeModule.AddAxes(builder, axes_vec) + FBProdNodeModule.AddKeepdims(builder, op.keepdims) + offset = FBProdNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ProdNode + + def _build_MaxNode( + self, builder: flatbuffers.Builder, op: MaxNode + ) -> Tuple[int, int]: + """Auto-generated builder for MaxNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import MaxNode as FBMaxNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + axes_vec = _build_int_vector(builder, op.axes) if op.axes is not None else None + + FBMaxNodeModule.Start(builder) + FBMaxNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBMaxNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if axes_vec is not None: + FBMaxNodeModule.AddAxes(builder, axes_vec) + FBMaxNodeModule.AddKeepdims(builder, op.keepdims) + offset = FBMaxNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.MaxNode + + def _build_MinNode( + self, builder: 
flatbuffers.Builder, op: MinNode + ) -> Tuple[int, int]: + """Auto-generated builder for MinNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import MinNode as FBMinNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + axes_vec = _build_int_vector(builder, op.axes) if op.axes is not None else None + + FBMinNodeModule.Start(builder) + FBMinNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBMinNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if axes_vec is not None: + FBMinNodeModule.AddAxes(builder, axes_vec) + FBMinNodeModule.AddKeepdims(builder, op.keepdims) + offset = FBMinNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.MinNode + + def _build_ArgminNode( + self, builder: flatbuffers.Builder, op: ArgminNode + ) -> Tuple[int, int]: + """Auto-generated builder for ArgminNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ArgminNode as FBArgminNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBArgminNodeModule.Start(builder) + FBArgminNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBArgminNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBArgminNodeModule.AddAxis(builder, op.axis) + FBArgminNodeModule.AddKeepdims(builder, op.keepdims) + offset = FBArgminNodeModule.End(builder) + return offset, 
FBOpNodeModule.OpNode.ArgminNode + + def _build_MedianNode( + self, builder: flatbuffers.Builder, op: MedianNode + ) -> Tuple[int, int]: + """Auto-generated builder for MedianNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import MedianNode as FBMedianNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + axes_vec = _build_int_vector(builder, op.axes) if op.axes is not None else None + + FBMedianNodeModule.Start(builder) + FBMedianNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBMedianNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + if axes_vec is not None: + FBMedianNodeModule.AddAxes(builder, axes_vec) + FBMedianNodeModule.AddKeepdims(builder, op.keepdims) + offset = FBMedianNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.MedianNode + + def _build_GatherMmNode( + self, builder: flatbuffers.Builder, op: GatherMmNode + ) -> Tuple[int, int]: + """Auto-generated builder for GatherMmNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import GatherMmNode as FBGatherMmNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + FBGatherMmNodeModule.Start(builder) + FBGatherMmNodeModule.AddA(builder, CreateTid(builder, op.a.idx)) + FBGatherMmNodeModule.AddB(builder, CreateTid(builder, op.b.idx)) + FBGatherMmNodeModule.AddOut(builder, 
CreateTid(builder, op.out.idx)) + if op.lhs_indices is not None: + FBGatherMmNodeModule.AddLhsIndices(builder, CreateTid(builder, op.lhs_indices.idx)) + if op.rhs_indices is not None: + FBGatherMmNodeModule.AddRhsIndices(builder, CreateTid(builder, op.rhs_indices.idx)) + FBGatherMmNodeModule.AddSortedIndices(builder, op.sorted_indices) + offset = FBGatherMmNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.GatherMmNode + + def _build_GatherQmmNode( + self, builder: flatbuffers.Builder, op: GatherQmmNode + ) -> Tuple[int, int]: + """Auto-generated builder for GatherQmmNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import GatherQmmNode as FBGatherQmmNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + mode_off = builder.CreateString(op.mode) + + FBGatherQmmNodeModule.Start(builder) + FBGatherQmmNodeModule.AddX(builder, CreateTid(builder, op.x.idx)) + FBGatherQmmNodeModule.AddW(builder, CreateTid(builder, op.w.idx)) + FBGatherQmmNodeModule.AddScales(builder, CreateTid(builder, op.scales.idx)) + FBGatherQmmNodeModule.AddOut(builder, CreateTid(builder, op.out.idx)) + FBGatherQmmNodeModule.AddMode(builder, mode_off) + if op.biases is not None: + FBGatherQmmNodeModule.AddBiases(builder, CreateTid(builder, op.biases.idx)) + if op.lhs_indices is not None: + FBGatherQmmNodeModule.AddLhsIndices(builder, CreateTid(builder, op.lhs_indices.idx)) + if op.rhs_indices is not None: + FBGatherQmmNodeModule.AddRhsIndices(builder, CreateTid(builder, op.rhs_indices.idx)) + FBGatherQmmNodeModule.AddTranspose(builder, op.transpose) + FBGatherQmmNodeModule.AddGroupSize(builder, op.group_size) + 
FBGatherQmmNodeModule.AddBits(builder, op.bits) + FBGatherQmmNodeModule.AddSortedIndices(builder, op.sorted_indices) + offset = FBGatherQmmNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.GatherQmmNode + + def _build_ScanNode( + self, builder: flatbuffers.Builder, op: ScanNode + ) -> Tuple[int, int]: + """Auto-generated builder for ScanNode.""" + # Import the MODULE (not class) to access builder functions like Start(), Add*(), End() + from executorch.backends.mlx.serialization._generated.mlx_delegate import ScanNode as FBScanNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule + from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid + + originals_vec = self._build_tid_vector(builder, op.originals) + sliced_vec = self._build_tid_vector(builder, op.sliced) + outputs_vec = self._build_tid_vector(builder, op.outputs) + carry_vec = self._build_tid_vector(builder, op.carry) + + FBScanNodeModule.Start(builder) + FBScanNodeModule.AddOriginals(builder, originals_vec) + FBScanNodeModule.AddSliced(builder, sliced_vec) + FBScanNodeModule.AddOutputs(builder, outputs_vec) + FBScanNodeModule.AddCarry(builder, carry_vec) + FBScanNodeModule.AddBodyChainIdx(builder, op.body_chain_idx) + FBScanNodeModule.AddScanAxis(builder, op.scan_axis) + offset = FBScanNodeModule.End(builder) + return offset, FBOpNodeModule.OpNode.ScanNode diff --git a/backends/mlx/serialization/mlx_graph_schema.py b/backends/mlx/serialization/mlx_graph_schema.py new file mode 100644 index 00000000000..530c68d9cd2 --- /dev/null +++ b/backends/mlx/serialization/mlx_graph_schema.py @@ -0,0 +1,1304 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+# +# ============================================================================ +# AUTO-GENERATED FILE - DO NOT EDIT MANUALLY +# ============================================================================ +# +# This file was generated from schema.fbs by the MLX delegate code generator. +# +# Source: backends/mlx/serialization/schema.fbs +# Generator: backends/mlx/serialization/generate.py +# +# To regenerate, run from the executorch root: +# python backends/mlx/serialization/generate.py +# +# ============================================================================ + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import IntEnum +from typing import List, Optional, Union + + +# ============================================================================ +# Enums +# ============================================================================ + +class SlotType(IntEnum): + TensorSlot = 0 + IntValueSlot = 1 + FloatValueSlot = 2 + BoolValueSlot = 3 + + +# ============================================================================ +# Core types +# ============================================================================ + +@dataclass +class Tid: + idx: Optional[int] + + +@dataclass +class Vid: + idx: Optional[int] + + +@dataclass +class FloatOrVid: + """Represents either a literal float or a runtime Vid reference.""" + literal: float = 0.0 + vid: Optional[Vid] = None + is_vid: bool = False + + @classmethod + def from_literal(cls, value: float) -> "FloatOrVid": + """Create a FloatOrVid from a literal float.""" + return cls(literal=value, is_vid=False) + + @classmethod + def from_vid(cls, vid: Vid) -> "FloatOrVid": + """Create a FloatOrVid from a Vid reference.""" + return cls(vid=vid, is_vid=True) + + +@dataclass +class IntOrVid: + """Represents either a literal integer or a runtime Vid reference.""" + literal: int = 0 + vid: Optional[Vid] = None + is_vid: bool = False + + @classmethod + def from_literal(cls, value: 
int) -> "IntOrVid": + """Create a IntOrVid from a literal integer.""" + return cls(literal=value, is_vid=False) + + @classmethod + def from_vid(cls, vid: Vid) -> "IntOrVid": + """Create a IntOrVid from a Vid reference.""" + return cls(vid=vid, is_vid=True) + + +@dataclass +class IntOrVidOrTid: + """Represents either a literal integer or a runtime Vid reference.""" + literal: int = 0 + vid: Optional[Vid] = None + tid: Optional[Tid] = None + kind: int = 0 + + @classmethod + def from_literal(cls, value: int) -> "IntOrVidOrTid": + """Create a IntOrVidOrTid from a literal integer.""" + return cls(literal=value, kind=0) + + @classmethod + def from_vid(cls, vid: Vid) -> "IntOrVidOrTid": + """Create a IntOrVidOrTid from a Vid reference.""" + return cls(vid=vid, kind=1) + + @classmethod + def from_tid(cls, tid: Tid) -> "IntOrVidOrTid": + """Create a IntOrVidOrTid from a Tid tensor reference.""" + return cls(tid=tid, kind=2) + + +@dataclass +class VidOrTid: + """Represents either a tensor reference or a runtime Vid reference.""" + vid: Optional[Vid] = None + tid: Optional[Tid] = None + is_vid: bool = False + + @classmethod + def from_tid(cls, value: Tid) -> "VidOrTid": + """Create a VidOrTid from a tensor reference.""" + return cls(tid=value, is_vid=False) + + @classmethod + def from_vid(cls, vid: Vid) -> "VidOrTid": + """Create a VidOrTid from a Vid reference.""" + return cls(vid=vid, is_vid=True) + + @classmethod + def from_tid(cls, tid: Tid) -> "VidOrTid": + """Create a VidOrTid from a Tid tensor reference.""" + return cls(tid=tid, is_vid=False) + + +@dataclass +class ShapeDim: + value: int = -1 + min_value: int = 0 + max_value: int = -1 + + +@dataclass +class SlotVariant: + slot_type: SlotType = SlotType.TensorSlot + idx: Optional[int] = None + + +@dataclass +class NamedSlot: + name: str + slot: SlotVariant + + +@dataclass +class TensorMeta: + shape: List[ShapeDim] + scalar_type: Optional[int] = None + dim_order: Optional[List[int]] = None + + +# 
============================================================================ +# Op nodes +# ============================================================================ + +@dataclass +class NoopNode: + pass + + +@dataclass +class IdCopyNode: + x: Tid + out: Tid + + +@dataclass +class AddmmNode: + mat1: Tid + mat2: Tid + out: Tid + alpha: float = 1.0 + beta: float = 1.0 + bias: Optional[Tid] = None + + +@dataclass +class ItemIntNode: + x: Tid + out: Vid + + +@dataclass +class ExpandDimsNode: + x: Tid + out: Tid + axis: Optional[int] = None + + +@dataclass +class TileNode: + x: Tid + out: Tid + reps: List[IntOrVid] + + +@dataclass +class TakeAlongAxisNode: + x: Tid + indices: Tid + out: Tid + axis: Optional[int] = None + + +@dataclass +class TakeNode: + x: Tid + out: Tid + index: IntOrVidOrTid + axis: Optional[int] = None + + +@dataclass +class RMSNormNode: + x: Tid + out: Tid + weight: Optional[Tid] = None + eps: Optional[float] = None + + +@dataclass +class LayerNormNode: + x: Tid + out: Tid + weight: Optional[Tid] = None + bias: Optional[Tid] = None + eps: Optional[float] = None + + +@dataclass +class RopeNode: + x: Tid + out: Tid + offset: VidOrTid + traditional: bool = False + base: float = 500000.0 + scale: float = 1.0 + dims: Optional[int] = None + freqs: Optional[Tid] = None + + +@dataclass +class SdpaNode: + q: Tid + k: Tid + v: Tid + out: Tid + causal: bool = False + scale: Optional[float] = None + mask: Optional[Tid] = None + + +@dataclass +class AddNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class AddIntNode: + a: IntOrVid + b: IntOrVid + out: Vid + + +@dataclass +class SubtractIntNode: + a: IntOrVid + b: IntOrVid + out: Vid + + +@dataclass +class MultiplyIntNode: + a: IntOrVid + b: IntOrVid + out: Vid + + +@dataclass +class FloorDivideIntNode: + a: IntOrVid + b: IntOrVid + out: Vid + + +@dataclass +class ModIntNode: + a: IntOrVid + b: IntOrVid + out: Vid + + +@dataclass +class SymSizeNode: + a: Tid + out: Vid + dim: Optional[int] = None + + 
+@dataclass +class MultiplyNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class DivideNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class SubtractNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class Conv1DNode: + x: Tid + w: Tid + out: Tid + stride: int = 1 + padding: int = 0 + dilation: int = 1 + groups: int = 1 + + +@dataclass +class Conv2DNode: + x: Tid + w: Tid + out: Tid + stride_h: int = 1 + stride_w: int = 1 + padding_h: int = 0 + padding_w: int = 0 + dilation_h: int = 1 + dilation_w: int = 1 + groups: int = 1 + + +@dataclass +class Conv3DNode: + x: Tid + w: Tid + out: Tid + stride_d: int = 1 + stride_h: int = 1 + stride_w: int = 1 + padding_d: int = 0 + padding_h: int = 0 + padding_w: int = 0 + dilation_d: int = 1 + dilation_h: int = 1 + dilation_w: int = 1 + groups: int = 1 + + +@dataclass +class ConvTranspose1DNode: + x: Tid + w: Tid + out: Tid + stride: int = 1 + padding: int = 0 + dilation: int = 1 + output_padding: int = 0 + groups: int = 1 + + +@dataclass +class ConvTranspose2DNode: + x: Tid + w: Tid + out: Tid + stride_h: int = 1 + stride_w: int = 1 + padding_h: int = 0 + padding_w: int = 0 + dilation_h: int = 1 + dilation_w: int = 1 + output_padding_h: int = 0 + output_padding_w: int = 0 + groups: int = 1 + + +@dataclass +class ConvTranspose3DNode: + x: Tid + w: Tid + out: Tid + stride_d: int = 1 + stride_h: int = 1 + stride_w: int = 1 + padding_d: int = 0 + padding_h: int = 0 + padding_w: int = 0 + dilation_d: int = 1 + dilation_h: int = 1 + dilation_w: int = 1 + output_padding_d: int = 0 + output_padding_h: int = 0 + output_padding_w: int = 0 + groups: int = 1 + + +@dataclass +class GeluNode: + x: Tid + out: Tid + approximate: str + + +@dataclass +class ARangeNode: + out: Tid + start: IntOrVid + stop: IntOrVid + step: IntOrVid + scalar_type: int = None + + +@dataclass +class SiluNode: + x: Tid + out: Tid + + +@dataclass +class SigmoidNode: + x: Tid + out: Tid + + +@dataclass +class TanhNode: + x: Tid + out: Tid + + +@dataclass 
+class SqueezeNode: + x: Tid + out: Tid + dims: Optional[List[int]] = None + + +@dataclass +class SplitNode: + x: Tid + outs: List[Tid] + sizes: List[IntOrVid] + axis: Optional[int] = None + + +@dataclass +class RsqrtNode: + x: Tid + out: Tid + + +@dataclass +class MaximumNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class MinimumNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class LogNode: + x: Tid + out: Tid + + +@dataclass +class SoftmaxNode: + x: Tid + out: Tid + precise: bool = False + axis: Optional[int] = None + + +@dataclass +class BroadcastToNode: + x: Tid + out: Tid + shape: List[IntOrVid] + + +@dataclass +class PadNode: + x: Tid + out: Tid + pad_width: List[IntOrVid] + mode: str + constant_value: float = 0.0 + + +@dataclass +class WhereNode: + condition: Tid + x: Tid + y: Tid + out: Tid + + +@dataclass +class ReshapeNode: + x: Tid + out: Tid + shape: List[IntOrVid] + + +@dataclass +class TransposeNode: + x: Tid + out: Tid + perm: List[int] + + +@dataclass +class AsStridedNode: + x: Tid + out: Tid + shape: List[IntOrVid] + strides: List[IntOrVid] + offset: int = 0 + + +@dataclass +class ContiguousNode: + x: Tid + out: Tid + + +@dataclass +class GatherNode: + x: Tid + indices: List[Tid] + out: Tid + axes: List[int] + slice_sizes: List[int] + + +@dataclass +class SliceNode: + x: Tid + out: Tid + axis: IntOrVid + start: IntOrVid + stop: IntOrVid + step: int = 1 + + +@dataclass +class AsTypeNode: + x: Tid + out: Tid + scalar_type: Optional[int] = None + + +@dataclass +class QuantizedMatmulNode: + x: Tid + w: Tid + scales: Tid + out: Tid + mode: str + transpose: bool = True + biases: Optional[Tid] = None + group_size: Optional[int] = None + bits: Optional[int] = None + + +@dataclass +class ScatterAddNode: + x: Tid + indices: Tid + updates: Tid + out: Tid + axis: Optional[int] = None + + +@dataclass +class ConcatenateNode: + tensors: List[Tid] + out: Tid + axis: Optional[int] = None + + +@dataclass +class FullNode: + out: Tid + shape: List[IntOrVid] 
+ v: FloatOrVid + scalar_type: Optional[int] = None + + +@dataclass +class FullLikeNode: + x: Tid + out: Tid + v: FloatOrVid + scalar_type: int = None + + +@dataclass +class ArgmaxNode: + x: Tid + out: Tid + keepdims: bool = False + axis: Optional[int] = None + + +@dataclass +class SliceUpdateNode: + dst: Tid + update: Tid + out: Tid + axis: IntOrVid + start: IntOrVid + stop: IntOrVid + step: int = 1 + + +@dataclass +class IndexCopyNode: + dst: Tid + update: Tid + indices: Tid + out: Tid + axis: Optional[int] = None + + +@dataclass +class DequantizeNode: + w: Tid + scales: Tid + out: Tid + mode: str + dtype: int = None + biases: Optional[Tid] = None + group_size: Optional[int] = None + bits: Optional[int] = None + global_scale: Optional[Tid] = None + + +@dataclass +class LessNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class LessEqualNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class GreaterNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class GreaterEqualNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class EqualNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class NotEqualNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class LogicalNotNode: + x: Tid + out: Tid + + +@dataclass +class LogicalAndNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class LogicalOrNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class TriNode: + out: Tid + n: IntOrVid + m: IntOrVid + k: int = 0 + scalar_type: Optional[int] = None + + +@dataclass +class TrilNode: + x: Tid + out: Tid + k: int = 0 + + +@dataclass +class TriuNode: + x: Tid + out: Tid + k: int = 0 + + +@dataclass +class ClipNode: + x: Tid + out: Tid + a_min: Optional[Tid] = None + a_max: Optional[Tid] = None + + +@dataclass +class CumsumNode: + x: Tid + out: Tid + reverse: bool = False + inclusive: bool = True + axis: Optional[int] = None + + +@dataclass +class StackNode: + tensors: List[Tid] + out: Tid + axis: int = 0 + + +@dataclass +class SignNode: + x: Tid + out: Tid + + +@dataclass 
+class AnyNode: + x: Tid + out: Tid + keepdims: bool = False + axes: Optional[List[int]] = None + + +@dataclass +class AllNode: + x: Tid + out: Tid + keepdims: bool = False + axes: Optional[List[int]] = None + + +@dataclass +class RepeatNode: + x: Tid + out: Tid + repeats: IntOrVid + axis: Optional[int] = None + + +@dataclass +class SortNode: + x: Tid + out: Tid + axis: Optional[int] = None + + +@dataclass +class ArgsortNode: + x: Tid + out: Tid + axis: Optional[int] = None + + +@dataclass +class PartitionNode: + x: Tid + out: Tid + kth: IntOrVid + axis: Optional[int] = None + + +@dataclass +class ArgPartitionNode: + x: Tid + out: Tid + kth: IntOrVid + axis: Optional[int] = None + + +@dataclass +class FloorNode: + x: Tid + out: Tid + + +@dataclass +class CeilNode: + x: Tid + out: Tid + + +@dataclass +class SquareNode: + x: Tid + out: Tid + + +@dataclass +class ExpNode: + x: Tid + out: Tid + + +@dataclass +class SinNode: + x: Tid + out: Tid + + +@dataclass +class CosNode: + x: Tid + out: Tid + + +@dataclass +class TanNode: + x: Tid + out: Tid + + +@dataclass +class ArcsinNode: + x: Tid + out: Tid + + +@dataclass +class ArccosNode: + x: Tid + out: Tid + + +@dataclass +class ArctanNode: + x: Tid + out: Tid + + +@dataclass +class SinhNode: + x: Tid + out: Tid + + +@dataclass +class CoshNode: + x: Tid + out: Tid + + +@dataclass +class ArcsinhNode: + x: Tid + out: Tid + + +@dataclass +class ArccoshNode: + x: Tid + out: Tid + + +@dataclass +class ArctanhNode: + x: Tid + out: Tid + + +@dataclass +class Log2Node: + x: Tid + out: Tid + + +@dataclass +class Log10Node: + x: Tid + out: Tid + + +@dataclass +class Log1pNode: + x: Tid + out: Tid + + +@dataclass +class ErfNode: + x: Tid + out: Tid + + +@dataclass +class Expm1Node: + x: Tid + out: Tid + + +@dataclass +class RoundNode: + x: Tid + out: Tid + decimals: int = 0 + + +@dataclass +class ReciprocalNode: + x: Tid + out: Tid + + +@dataclass +class SqrtNode: + x: Tid + out: Tid + + +@dataclass +class AbsNode: + x: Tid + out: 
Tid + + +@dataclass +class NegNode: + x: Tid + out: Tid + + +@dataclass +class Atan2Node: + a: Tid + b: Tid + out: Tid + + +@dataclass +class LogAddExpNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class FloorDivideNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class RemainderNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class PowerNode: + a: Tid + b: Tid + out: Tid + + +@dataclass +class LogSumExpNode: + x: Tid + out: Tid + keepdims: bool = False + axes: Optional[List[int]] = None + + +@dataclass +class SumNode: + x: Tid + out: Tid + keepdims: bool = False + axes: Optional[List[int]] = None + + +@dataclass +class MeanNode: + x: Tid + out: Tid + keepdims: bool = False + axes: Optional[List[int]] = None + + +@dataclass +class VarNode: + x: Tid + out: Tid + keepdims: bool = False + ddof: int = 0 + axes: Optional[List[int]] = None + + +@dataclass +class StdNode: + x: Tid + out: Tid + keepdims: bool = False + ddof: int = 0 + axes: Optional[List[int]] = None + + +@dataclass +class ProdNode: + x: Tid + out: Tid + keepdims: bool = False + axes: Optional[List[int]] = None + + +@dataclass +class MaxNode: + x: Tid + out: Tid + keepdims: bool = False + axes: Optional[List[int]] = None + + +@dataclass +class MinNode: + x: Tid + out: Tid + keepdims: bool = False + axes: Optional[List[int]] = None + + +@dataclass +class ArgminNode: + x: Tid + out: Tid + keepdims: bool = False + axis: Optional[int] = None + + +@dataclass +class MedianNode: + x: Tid + out: Tid + keepdims: bool = False + axes: Optional[List[int]] = None + + +@dataclass +class GatherMmNode: + a: Tid + b: Tid + out: Tid + sorted_indices: bool = False + lhs_indices: Optional[Tid] = None + rhs_indices: Optional[Tid] = None + + +@dataclass +class GatherQmmNode: + x: Tid + w: Tid + scales: Tid + out: Tid + mode: str + transpose: bool = True + sorted_indices: bool = False + biases: Optional[Tid] = None + lhs_indices: Optional[Tid] = None + rhs_indices: Optional[Tid] = None + group_size: Optional[int] = 
None + bits: Optional[int] = None + + +@dataclass +class ScanNode: + originals: List[Tid] + sliced: List[Tid] + outputs: List[Tid] + carry: List[Tid] + scan_axis: int = 1 + body_chain_idx: Optional[int] = None + + +# Union of all op types +OpNodeUnion = Union[ + NoopNode, + IdCopyNode, + AddmmNode, + ItemIntNode, + ExpandDimsNode, + TileNode, + TakeAlongAxisNode, + TakeNode, + RMSNormNode, + LayerNormNode, + RopeNode, + SdpaNode, + AddNode, + AddIntNode, + SubtractIntNode, + MultiplyIntNode, + FloorDivideIntNode, + ModIntNode, + SymSizeNode, + MultiplyNode, + DivideNode, + SubtractNode, + Conv1DNode, + Conv2DNode, + Conv3DNode, + ConvTranspose1DNode, + ConvTranspose2DNode, + ConvTranspose3DNode, + GeluNode, + ARangeNode, + SiluNode, + SigmoidNode, + TanhNode, + SqueezeNode, + SplitNode, + RsqrtNode, + MaximumNode, + MinimumNode, + LogNode, + SoftmaxNode, + BroadcastToNode, + PadNode, + WhereNode, + ReshapeNode, + TransposeNode, + AsStridedNode, + ContiguousNode, + GatherNode, + SliceNode, + AsTypeNode, + QuantizedMatmulNode, + ScatterAddNode, + ConcatenateNode, + FullNode, + FullLikeNode, + ArgmaxNode, + SliceUpdateNode, + IndexCopyNode, + DequantizeNode, + LessNode, + LessEqualNode, + GreaterNode, + GreaterEqualNode, + EqualNode, + NotEqualNode, + LogicalNotNode, + LogicalAndNode, + LogicalOrNode, + TriNode, + TrilNode, + TriuNode, + ClipNode, + CumsumNode, + StackNode, + SignNode, + AnyNode, + AllNode, + RepeatNode, + SortNode, + ArgsortNode, + PartitionNode, + ArgPartitionNode, + FloorNode, + CeilNode, + SquareNode, + ExpNode, + SinNode, + CosNode, + TanNode, + ArcsinNode, + ArccosNode, + ArctanNode, + SinhNode, + CoshNode, + ArcsinhNode, + ArccoshNode, + ArctanhNode, + Log2Node, + Log10Node, + Log1pNode, + ErfNode, + Expm1Node, + RoundNode, + ReciprocalNode, + SqrtNode, + AbsNode, + NegNode, + Atan2Node, + LogAddExpNode, + FloorDivideNode, + RemainderNode, + PowerNode, + LogSumExpNode, + SumNode, + MeanNode, + VarNode, + StdNode, + ProdNode, + MaxNode, + 
MinNode, + ArgminNode, + MedianNode, + GatherMmNode, + GatherQmmNode, + ScanNode, +] + +# ============================================================================ +# Container types (reference OpNodeUnion) +# ============================================================================ + +@dataclass +class Instruction: + op: OpNodeUnion + + +@dataclass +class InstructionChain: + instructions: List[Instruction] + + +@dataclass +class MLXGraph: + instruction_chains: List[InstructionChain] + version: Optional[str] = None + num_constant_tensors: int = 0 + num_input_tensors: int = 0 + num_output_tensors: int = 0 + num_mutable_buffer_tensors: int = 0 + num_temp_tensors: int = 0 + num_values: int = 0 + main_chain_idx: int = 0 + init_chain_idx: int = -1 + input_map: Optional[List[SlotVariant]] = None + output_map: Optional[List[SlotVariant]] = None + mutable_buffer_map: Optional[List[SlotVariant]] = None + named_slots: Optional[List[NamedSlot]] = None + tensor_meta: Optional[List[TensorMeta]] = None diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index 6e8d6f47db8..9d13f4fa7ac 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -815,6 +815,12 @@ table LogAddExpNode { out: Tid (required); } +table BitwiseOrNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + table FloorDivideNode { a: Tid (required); b: Tid (required); @@ -1114,6 +1120,7 @@ union OpNode { GatherQmmNode, ScanNode, -MetalKernelNode +MetalKernelNode, + BitwiseOrNode, // BC: Add new op nodes here (append only) } diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index e5ece4931b9..346056f9f5d 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -4205,6 +4205,7 @@ def create_model(self) -> nn.Module: # logical {"op_name": "logical_and", "op_fn": torch.logical_and, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), 
"input_fn_b": _bool_input_fn()}, {"op_name": "logical_or", "op_fn": torch.logical_or, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()}, + {"op_name": "bitwise_or", "op_fn": torch.bitwise_or, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.int32, torch.int64], "input_fn_a": _int_input_fn(-100, 100), "input_fn_b": _int_input_fn(-100, 100)}, ] # fmt: on