diff --git a/pyproject.toml b/pyproject.toml index b01e461ea32..00bd8a159cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,9 +125,7 @@ exclude = [ "src/nncf/quantization/algorithms/accuracy_control/rank_functions.py", "src/nncf/quantization/algorithms/accuracy_control/ranker.py", "src/nncf/quantization/algorithms/accuracy_control/subset_selection.py", - "src/nncf/quantization/algorithms/algorithm.py", "src/nncf/quantization/algorithms/bias_correction/algorithm.py", - "src/nncf/quantization/algorithms/bias_correction/backend.py", "src/nncf/quantization/algorithms/bias_correction/onnx_backend.py", "src/nncf/quantization/algorithms/bias_correction/openvino_backend.py", "src/nncf/quantization/algorithms/bias_correction/torch_fx_backend.py", diff --git a/src/nncf/common/utils/os.py b/src/nncf/common/utils/os.py index 59d9e40dd5b..1641ef0d140 100644 --- a/src/nncf/common/utils/os.py +++ b/src/nncf/common/utils/os.py @@ -67,7 +67,7 @@ def get_available_cpu_count(logical: bool = True) -> int: return 1 -def get_available_memory_amount() -> float: +def get_available_memory_amount() -> int: """ :return: Available memory amount (bytes) """ diff --git a/src/nncf/quantization/algorithms/accuracy_control/algorithm.py b/src/nncf/quantization/algorithms/accuracy_control/algorithm.py index 19b5091418e..ce209a18401 100644 --- a/src/nncf/quantization/algorithms/accuracy_control/algorithm.py +++ b/src/nncf/quantization/algorithms/accuracy_control/algorithm.py @@ -29,6 +29,7 @@ from nncf.quantization.algorithms.accuracy_control.backend import AccuracyControlAlgoBackend from nncf.quantization.algorithms.accuracy_control.evaluator import Evaluator from nncf.quantization.algorithms.accuracy_control.evaluator import MetricResults +from nncf.quantization.algorithms.accuracy_control.ranker import GroupToRank from nncf.quantization.algorithms.accuracy_control.ranker import Ranker TModel = TypeVar("TModel") @@ -110,8 +111,8 @@ class QuantizationAccuracyRestorerReport: :param num_iterations: Number of iterations performed. """ - def __init__(self): - self.removed_groups = [] + def __init__(self) -> None: + self.removed_groups: list[GroupToRank] = [] self.removed_all = False self.reached_required_drop = False self.num_quantized_operations = None @@ -355,7 +356,7 @@ def _apply( break nncf_logger.info( - f"Accuracy drop with the new quantization scope is {float(current_accuracy_drop)} ({self.drop_type})" + f"Accuracy drop with the new quantization scope is {current_accuracy_drop} ({self.drop_type})" ) # Continue greedy quantizer remove @@ -479,9 +480,9 @@ def _print_report(report: QuantizationAccuracyRestorerReport, max_num_iterations ) @staticmethod - def _print_completion_message(accuracy_drop: float, drop_type: DropType) -> None: + def _print_completion_message(accuracy_drop: float | None, drop_type: DropType) -> None: if accuracy_drop is None or accuracy_drop < 0: reason = "metric of the quantized model is greater than the metric of the initial model" else: - reason = f"achieved required accuracy drop {float(accuracy_drop)} ({drop_type})" + reason = f"achieved required accuracy drop {accuracy_drop} ({drop_type})" nncf_logger.info(f"Algorithm completed: {reason}") diff --git a/src/nncf/quantization/algorithms/accuracy_control/backend.py b/src/nncf/quantization/algorithms/accuracy_control/backend.py index 9a122bcc272..d862e1c1caa 100644 --- a/src/nncf/quantization/algorithms/accuracy_control/backend.py +++ b/src/nncf/quantization/algorithms/accuracy_control/backend.py @@ -56,7 +56,7 @@ class AccuracyControlAlgoBackend(ABC): @staticmethod @abstractmethod - def get_op_with_weights_metatypes() -> list[OperatorMetatype]: + def get_op_with_weights_metatypes() -> list[type[OperatorMetatype]]: """ Returns a list of operation metatypes that can be reverted to representation with int8 weights. @@ -67,7 +67,7 @@ def get_op_with_weights_metatypes() -> list[OperatorMetatype]: @staticmethod @abstractmethod - def get_quantizer_metatypes() -> list[OperatorMetatype]: + def get_quantizer_metatypes() -> list[type[OperatorMetatype]]: """ Returns a list of quantizer metatypes. @@ -76,7 +76,7 @@ def get_quantizer_metatypes() -> list[OperatorMetatype]: @staticmethod @abstractmethod - def get_const_metatypes() -> list[OperatorMetatype]: + def get_const_metatypes() -> list[type[OperatorMetatype]]: """ Returns a list of constant metatypes. @@ -85,7 +85,7 @@ def get_const_metatypes() -> list[OperatorMetatype]: @staticmethod @abstractmethod - def get_quantizable_metatypes() -> list[OperatorMetatype]: + def get_quantizable_metatypes() -> list[type[OperatorMetatype]]: """ Returns a list of metatypes for operations that may be quantized. @@ -104,7 +104,7 @@ def get_start_nodes_for_activation_path_tracing(nncf_graph: NNCFGraph) -> list[N @staticmethod @abstractmethod - def get_quantize_agnostic_metatypes() -> list[OperatorMetatype]: + def get_quantize_agnostic_metatypes() -> list[type[OperatorMetatype]]: """ Returns a list of quantize agnostic metatypes. @@ -113,7 +113,7 @@ def get_quantize_agnostic_metatypes() -> list[OperatorMetatype]: @staticmethod @abstractmethod - def get_shapeof_metatypes() -> list[OperatorMetatype]: + def get_shapeof_metatypes() -> list[type[OperatorMetatype]]: """ Returns a list of shape of metatypes. @@ -171,7 +171,7 @@ def get_weight_value(node_with_weight: NNCFNode, model: TModel, port_id: int) -> @staticmethod @abstractmethod - def get_weight_tensor_port_ids(node: NNCFNode) -> list[int | None]: + def get_weight_tensor_port_ids(node: NNCFNode) -> list[int]: """ Returns node's input port indices with weights tensors. diff --git a/src/nncf/quantization/algorithms/accuracy_control/onnx_backend.py b/src/nncf/quantization/algorithms/accuracy_control/onnx_backend.py index 4bfc2dcad38..bc92146876a 100644 --- a/src/nncf/quantization/algorithms/accuracy_control/onnx_backend.py +++ b/src/nncf/quantization/algorithms/accuracy_control/onnx_backend.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any import numpy as np import onnx @@ -58,27 +59,27 @@ class ONNXAccuracyControlAlgoBackend(AccuracyControlAlgoBackend): # Metatypes @staticmethod - def get_op_with_weights_metatypes() -> list[ONNXOpMetatype]: + def get_op_with_weights_metatypes() -> list[type[ONNXOpMetatype]]: return OPERATIONS_WITH_WEIGHTS @staticmethod - def get_quantizer_metatypes() -> list[ONNXOpMetatype]: + def get_quantizer_metatypes() -> list[type[ONNXOpMetatype]]: return QUANTIZE_DEQUANTIZE_OPERATIONS @staticmethod - def get_const_metatypes() -> list[ONNXOpMetatype]: + def get_const_metatypes() -> list[type[ONNXOpMetatype]]: return [onnx_metatypes.ONNXConstantMetatype] @staticmethod - def get_quantizable_metatypes() -> list[ONNXOpMetatype]: + def get_quantizable_metatypes() -> list[type[ONNXOpMetatype]]: return INPUTS_QUANTIZABLE_OPERATIONS @staticmethod - def get_quantize_agnostic_metatypes() -> list[ONNXOpMetatype]: + def get_quantize_agnostic_metatypes() -> list[type[ONNXOpMetatype]]: return QUANTIZE_AGNOSTIC_OPERATIONS + [onnx_metatypes.ONNXConcatMetatype] @staticmethod - def get_shapeof_metatypes() -> list[ONNXOpMetatype]: + def get_shapeof_metatypes() -> list[type[ONNXOpMetatype]]: return [onnx_metatypes.ONNXShapeMetatype] @staticmethod @@ -96,17 +97,17 @@ def is_node_with_weight(node: NNCFNode) -> bool: return node.metatype in OPERATIONS_WITH_WEIGHTS and node.layer_attributes.has_weight() @staticmethod - def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: onnx.ModelProto) -> np.ndarray: + def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: onnx.ModelProto) -> np.ndarray[Any, Any]: return get_bias_value(node_with_bias, model) @staticmethod - def get_weight_value(node_with_weight: NNCFNode, model: onnx.ModelProto, port_id: int) -> np.ndarray: + def get_weight_value(node_with_weight: NNCFNode, model: onnx.ModelProto, port_id: int) -> np.ndarray[Any, Any]: assert node_with_weight.layer_attributes.has_weight() weight_name = node_with_weight.layer_attributes.weight_attrs[port_id]["name"] return get_tensor_value(model, weight_name) @staticmethod - def get_weight_tensor_port_ids(node: NNCFNode) -> list[int | None]: + def get_weight_tensor_port_ids(node: NNCFNode) -> list[int]: return list(node.layer_attributes.weight_attrs.keys()) @staticmethod diff --git a/src/nncf/quantization/algorithms/accuracy_control/openvino_backend.py b/src/nncf/quantization/algorithms/accuracy_control/openvino_backend.py index 56b6512b4eb..06696f8430b 100644 --- a/src/nncf/quantization/algorithms/accuracy_control/openvino_backend.py +++ b/src/nncf/quantization/algorithms/accuracy_control/openvino_backend.py @@ -10,6 +10,8 @@ # limitations under the License. +from typing import Any + import numpy as np import openvino as ov from openvino import Type @@ -47,7 +49,7 @@ def __init__(self, model: ov.Model, use_fp32_precision: bool = True): if use_fp32_precision: config = {inference_precision: Type.f32} self._compiled_model = ov.compile_model(model, device_name="CPU", config=config) - self._engine = None + self._engine: OVCompiledModelEngine | None = None @property def model_for_inference(self) -> ov.CompiledModel: @@ -68,27 +70,27 @@ class OVAccuracyControlAlgoBackend(AccuracyControlAlgoBackend): # Metatypes @staticmethod - def get_op_with_weights_metatypes() -> list[OVOpMetatype]: + def get_op_with_weights_metatypes() -> list[type[OVOpMetatype]]: return OPERATIONS_WITH_WEIGHTS @staticmethod - def get_quantizer_metatypes() -> list[OVOpMetatype]: + def get_quantizer_metatypes() -> list[type[OVOpMetatype]]: return FAKE_QUANTIZE_OPERATIONS @staticmethod - def get_const_metatypes() -> list[OVOpMetatype]: + def get_const_metatypes() -> list[type[OVOpMetatype]]: return CONSTANT_OPERATIONS @staticmethod - def get_quantizable_metatypes() -> list[OVOpMetatype]: + def get_quantizable_metatypes() -> list[type[OVOpMetatype]]: return INPUTS_QUANTIZABLE_OPERATIONS @staticmethod - def get_quantize_agnostic_metatypes() -> list[OVOpMetatype]: + def get_quantize_agnostic_metatypes() -> list[type[OVOpMetatype]]: return QUANTIZE_AGNOSTIC_OPERATIONS + [OVConcatMetatype] @staticmethod - def get_shapeof_metatypes() -> list[OVOpMetatype]: + def get_shapeof_metatypes() -> list[type[OVOpMetatype]]: return SHAPEOF_OPERATIONS @staticmethod @@ -106,15 +108,15 @@ def is_node_with_weight(node: NNCFNode) -> bool: return node.metatype in OPERATIONS_WITH_WEIGHTS and isinstance(node.layer_attributes, OVLayerAttributes) @staticmethod - def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> np.ndarray: + def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> np.ndarray[Any, Any]: return get_bias_value(node_with_bias, nncf_graph, model) @staticmethod - def get_weight_value(node_with_weight: NNCFNode, model: ov.Model, port_id: int) -> np.ndarray: + def get_weight_value(node_with_weight: NNCFNode, model: ov.Model, port_id: int) -> np.ndarray[Any, Any]: return get_weight_value(node_with_weight, model, port_id) @staticmethod - def get_weight_tensor_port_ids(node: NNCFNode) -> list[int | None]: + def get_weight_tensor_port_ids(node: NNCFNode) -> list[int]: return node.layer_attributes.get_const_port_ids() @staticmethod diff --git a/src/nncf/quantization/algorithms/bias_correction/algorithm.py b/src/nncf/quantization/algorithms/bias_correction/algorithm.py index a1ed225b972..0cada27be86 100644 --- a/src/nncf/quantization/algorithms/bias_correction/algorithm.py +++ b/src/nncf/quantization/algorithms/bias_correction/algorithm.py @@ -10,6 +10,7 @@ # limitations under the License. from collections import defaultdict +from dataclasses import dataclass from typing import Any, TypeVar import nncf @@ -21,6 +22,7 @@ from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.graph.graph import NNCFNodeName from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.commands import TransformationCommand from nncf.common.graph.transformations.layout import TransformationLayout @@ -33,6 +35,7 @@ from nncf.common.utils.backend import copy_model from nncf.common.utils.backend import get_backend from nncf.quantization.algorithms.algorithm import Algorithm +from nncf.quantization.algorithms.bias_correction.backend import BiasCorrectionAlgoBackend from nncf.tensor import Tensor from nncf.tensor import functions as fns @@ -42,6 +45,16 @@ OUTPUT_PORT_OF_NODE = 0 +@dataclass() +class SubgraphData: + """ + Describes the subgraph data for the bias correction algorithm. + """ + + input_ids: set[tuple[NNCFNodeName, int]] + output_ids: set[tuple[NNCFNodeName, int]] + + class BiasCorrection(Algorithm): """ Post-training BiasCorrection algorithm implementation @@ -96,9 +109,9 @@ def __init__( self.inplace_statistics = inplace_statistics self.backend_params = backend_params self.nncf_graph = None - self._backend_entity = None - self._collected_stat_inputs_map = {} - self._fp_inputs = defaultdict(list) + self._backend_entity: BiasCorrectionAlgoBackend | None = None + self._collected_stat_inputs_map: dict[tuple[NNCFNodeName, int], tuple[NNCFNodeName, int]] = {} + self._fp_inputs: dict[tuple[NNCFNodeName, int], list[Tensor]] = defaultdict(list) self._algorithm_key = f"BC_{hash(self)}" if self.apply_for_all_nodes: @@ -148,7 +161,7 @@ def apply( model_copy = self._backend_entity.remove_fq_from_inputs(model_copy, graph_copy) nncf_graph = build_graph(model_copy) - nodes_with_bias = [] + nodes_with_bias: list[NNCFNode] = [] for node in nncf_graph.topological_sort(): if self._is_node_correctable(node, nncf_graph): nodes_with_bias.append(node) @@ -181,7 +194,7 @@ def apply( channel_axis = node.metatype.output_channel_axis if current_bias.ndim > 1: channel_axis = range(current_bias.ndim)[channel_axis] - axes = [i for i in range(current_bias.ndim) if i != channel_axis] + axes = tuple(i for i in range(current_bias.ndim) if i != channel_axis) bias_shift = fns.expand_dims(bias_shift, axes) updated_bias = current_bias + bias_shift @@ -220,7 +233,7 @@ def _is_node_correctable(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool: node, nncf_graph ) - def _find_subgraph_output_ids(self, nncf_graph: NNCFGraph, main_node: NNCFNode) -> list[tuple[str, int]]: + def _find_subgraph_output_ids(self, nncf_graph: NNCFGraph, main_node: NNCFNode) -> list[tuple[NNCFNodeName, int]]: """ The essence of the method is to collect output points for the future subgraph. It is understood that the main_node is the one for which the correction will be performed. @@ -233,8 +246,8 @@ def _find_subgraph_output_ids(self, nncf_graph: NNCFGraph, main_node: NNCFNode) :param main_node: Correctable NNCFNode. :return: Collected output ids. """ - visited_nodes = set() - subgraph_output_ids = [] + visited_nodes: set[NNCFNode] = set() + subgraph_output_ids: list[tuple[NNCFNodeName, int]] = [] edges_queue = nncf_graph.get_output_edges(main_node) while edges_queue: edge = edges_queue.pop() @@ -309,7 +322,7 @@ def _find_subgraph_input_ids( edges_queue.extend(nncf_graph.get_input_edges(node)) return subgraph_input_ids - def _get_subgraph_data_for_node(self, node: NNCFNode, nncf_graph: NNCFGraph) -> dict[str, set[tuple[str, int]]]: + def _get_subgraph_data_for_node(self, node: NNCFNode, nncf_graph: NNCFGraph) -> SubgraphData: """ This method collects necessary data for the specified node and its subgraph. This data contains the nodes (NNCFNode) for the subgraph building @@ -317,9 +330,9 @@ def _get_subgraph_data_for_node(self, node: NNCFNode, nncf_graph: NNCFGraph) -> :param node: NNCFNode instance. This is the main node with bias that would be corrected (or not). :param nncf_graph: NNCFGraph instance for graph analysis. - :return: A dict with the list of the nodes for the subgraph input and statistics collection. + :return: SubgraphData instance containing input and output ids for the subgraph. """ - subgraph_input_ids = [] + subgraph_input_ids: list[tuple[NNCFNodeName, int]] = [] # First, we need to find out the nodes with bias that follow by main node. # To collect statistics for next nodes. @@ -342,26 +355,24 @@ def _get_subgraph_data_for_node(self, node: NNCFNode, nncf_graph: NNCFGraph) -> # In case the outputs were not found during the collection of statistics nodes, # we use the latter as the outputs of the subgraph. - subgraph_data = { - "subgraph_input_ids": set(subgraph_input_ids), - "subgraph_output_ids": set(subgraph_output_ids), - } - - return subgraph_data + return SubgraphData( + input_ids=set(subgraph_input_ids), + output_ids=set(subgraph_output_ids), + ) - def _prepare_subgraph(self, node: NNCFNode, model: TModel, nncf_graph: NNCFGraph, subgraph_data: dict) -> TModel: + def _prepare_subgraph( + self, node: NNCFNode, model: TModel, nncf_graph: NNCFGraph, subgraph_data: SubgraphData + ) -> TModel: """ This method prepares the subgraph from the model for the further inference. :param node: NNCFNode instance for the current layer. :param model: Backend-specific model instance. :param nncf_graph: Instance of NNCFGraph. - :param subgraph_data: A dictionary with the layers for the graph building. + :param subgraph_data: SubgraphData instance with the layers for the graph building. :return: Backend-specific subgraph extracted from the model. """ - extracted_model = self.extract_model( - model, subgraph_data["subgraph_input_ids"], subgraph_data["subgraph_output_ids"] - ) + extracted_model = self.extract_model(model, subgraph_data.input_ids, subgraph_data.output_ids) transformation_layout = TransformationLayout() model_transformer = ModelTransformerFactory.create(extracted_model) @@ -375,13 +386,13 @@ def _prepare_subgraph(self, node: NNCFNode, model: TModel, nncf_graph: NNCFGraph return model_transformer.transform(transformation_layout) def _create_feed_dicts( - self, model: TModel, subgraph_data: dict, statistic_points: StatisticPointsContainer + self, model: TModel, subgraph_data: SubgraphData, statistic_points: StatisticPointsContainer ) -> list[dict]: """ Creates the list of the dictionaries that contains the input data for the model execution. :param model: TModel instance. - :param subgraph_data: A dictionary with the necessary data for current node. + :param subgraph_data: SubgraphData instance with the necessary data for current node. :param statistic_points: StatisticPointsContainer instance. :return: List of the dictionaries with the input data. """ @@ -389,7 +400,7 @@ def _create_feed_dicts( statistics_size = self.subset_size statistics_per_input = {} - for input_node_name, input_port_id in subgraph_data["subgraph_input_ids"]: + for input_node_name, input_port_id in subgraph_data.input_ids: input_tensor_name = self._backend_entity.get_input_name(model, input_node_name, input_port_id) activation_name, output_port_id = self._collected_stat_inputs_map[(input_node_name, input_port_id)] input_fp = self._get_fp_inputs(statistic_points, node_name=activation_name, port_id=output_port_id) @@ -398,7 +409,7 @@ def _create_feed_dicts( for stat_id in range(statistics_size): feed_dict = {} - for input_node_name, input_port_id in subgraph_data["subgraph_input_ids"]: + for input_node_name, input_port_id in subgraph_data.input_ids: input_tensor_name = self._backend_entity.get_input_name(model, input_node_name, input_port_id) # Since we do not use as inputs the layers from which the statistics are gathered, # but those that follow them, we need to take this into account when creating feed dicts. @@ -423,7 +434,7 @@ def _compute_bias_shift( output_tensor_name = self._backend_entity.get_output_name(model, node.node_name, OUTPUT_PORT_OF_NODE) engine = EngineFactory.create(model) channel_axis = node.metatype.output_channel_axis - q_outputs = [] + q_outputs: list[Tensor] = [] for feed_dict in feed_dicts: q_output = engine.infer(feed_dict) q_output = self._backend_entity.process_model_output(q_output, output_tensor_name) @@ -461,37 +472,37 @@ def _correct_bias(self, model: TModel, bias_correction_command: TransformationCo transformation_layout.register(bias_correction_command) return model_transformer.transform(transformation_layout) - def _collect_new_stats(self, model: TModel, feed_dicts: list, subgraph_data: dict) -> None: + def _collect_new_stats(self, model: TModel, feed_dicts: list, subgraph_data: SubgraphData) -> None: """ Updates the self._fp_inputs with the new statistics for the next layers after the correction of the bias for the current. :param model: Backend-specific subgraph. :param feed_dicts: List of dictionaries with the input data for the subgraph. - :param subgraph_data: A dictionary with the needed list of the statistic nodes that will be updated. + :param subgraph_data: A SubgraphData with the needed list of the statistic nodes that will be updated. """ engine = EngineFactory.create(model) for feed_dict in feed_dicts: new_q_output = engine.infer(feed_dict) - for output_node_name, output_id in subgraph_data["subgraph_output_ids"]: + for output_node_name, output_id in subgraph_data.output_ids: output_tensor_name = self._backend_entity.get_output_name(model, output_node_name, output_id) self._fp_inputs[(output_node_name, output_id)].append(Tensor(new_q_output[output_tensor_name])) - def _remove_unnecessary_stats(self, position: int, subgraphs_data: dict[str, dict]) -> None: + def _remove_unnecessary_stats(self, position: int, subgraphs_data: list[SubgraphData]) -> None: """ Removes unnecessary statistics that were collected before to reduce the memory usage. :param position: Zero-based position of the current node that was corrected. - :param subgraphs_data: A dictionary of the data (input & statistic node names) that + :param subgraphs_data: A list of SubgraphData of the data (input & statistic node names) that uses for the sub-graphs creation. """ # Collects list of the statistics that needed for the future layers. needed_stats_list = [] for i in range(position + 1, len(subgraphs_data)): - input_ids = subgraphs_data[i]["subgraph_input_ids"] + input_ids = subgraphs_data[i].input_ids needed_stats_list.extend([self._collected_stat_inputs_map[input_id][0] for input_id in input_ids]) - node_inputs_ids = subgraphs_data[position]["subgraph_input_ids"] + node_inputs_ids = subgraphs_data[position].input_ids for node_input_id in node_inputs_ids: activation_name, port_id = self._collected_stat_inputs_map[node_input_id] input_id = (activation_name, port_id) @@ -531,7 +542,7 @@ def input_filter_func(point): self._fp_inputs[input_id] = input_fp return self._fp_inputs[input_id] - def _get_fp_outputs(self, statistic_points: StatisticPointsContainer, node_name: str) -> Tensor: + def _get_fp_outputs(self, statistic_points: StatisticPointsContainer, node_name: str) -> list[Tensor]: """ Makes out post-layer needed data from the floating-point collected statistics. @@ -540,13 +551,13 @@ def _get_fp_outputs(self, statistic_points: StatisticPointsContainer, node_name: :return: Collected mean tensor data for the further bias calculation. """ - def output_filter_func(point): + def output_filter_func(point: StatisticPoint) -> bool: return self._algorithm_key in point.algorithm_to_tensor_collectors and point.target_point.type in [ TargetType.POST_LAYER_OPERATION, TargetType.OPERATOR_POST_HOOK, ] - output_fp = [] + output_fp: list[Tensor] = [] for tensor_collector in statistic_points.get_algo_statistics_for_node( node_name, output_filter_func, self._algorithm_key ): @@ -582,7 +593,7 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin ) # We must collect the nodes with biases following the model inputs. - biased_after_input_nodes = self._get_biased_after_nodes(graph, model_inputs, model) + biased_after_input_nodes = self._get_biased_after_nodes(graph, model_inputs) for biased_after_input_node in biased_after_input_nodes: # We need to collect activation input to register it for the biased layer as the layer with statistics. @@ -630,17 +641,16 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin return statistic_container - def _get_biased_after_nodes(self, nncf_graph: NNCFGraph, nodes: list[NNCFNode], model: TModel) -> list[NNCFNode]: + def _get_biased_after_nodes(self, nncf_graph: NNCFGraph, nodes: list[NNCFNode]) -> list[NNCFNode]: """ This method finds and returns nodes with the bias in the model that follows after the input nodes. :param nncf_graph: NNCFGraph instance. :param nodes: List of the model inputs as NNCFNodes. - :param model: TModel instance. :return: List of the nodes with bias. """ - def traverse_to_biased(node, condition_container): + def traverse_to_biased(node: NNCFNode, condition_container: set[NNCFNode]) -> None: # A small hack to speed up graph traversal. if node in visited_nodes: return @@ -657,13 +667,13 @@ def traverse_to_biased(node, condition_container): for node_child in node_children: traverse_to_biased(node_child, condition_container) - biased_nodes = set() - visited_nodes = [] + biased_nodes: set[NNCFNode] = set() + visited_nodes: list[NNCFNode] = [] for node in nodes: nncf_logger.debug(f"Looking for biased nodes after {node.node_name} layer.") traverse_to_biased(node, condition_container=biased_nodes) - dependant_nodes = set() + dependant_nodes: set[NNCFNode] = set() # After finding the nodes following the provided layers, we need to make sure # that the found nodes really only depend on the main layers, and not on each other. for biased_node in biased_nodes: @@ -674,7 +684,9 @@ def traverse_to_biased(node, condition_container): return list(biased_nodes - dependant_nodes) - def extract_model(self, model: TModel, input_ids: set[tuple[str, int]], output_ids: set[tuple[str, int]]) -> TModel: + def extract_model( + self, model: TModel, input_ids: set[tuple[NNCFNodeName, int]], output_ids: set[tuple[NNCFNodeName, int]] + ) -> TModel: """ Returns the backend-specific model that bounded by the specified input & output layers. diff --git a/src/nncf/quantization/algorithms/bias_correction/backend.py b/src/nncf/quantization/algorithms/bias_correction/backend.py index 7b34c68e59d..97e04717223 100644 --- a/src/nncf/quantization/algorithms/bias_correction/backend.py +++ b/src/nncf/quantization/algorithms/bias_correction/backend.py @@ -13,8 +13,6 @@ from abc import abstractmethod from typing import TypeVar -import numpy as np - from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode from nncf.common.graph.transformations.commands import TargetPoint @@ -42,12 +40,15 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) - @staticmethod @abstractmethod - def create_bias_correction_command(node: NNCFNode, bias_value: Tensor) -> TransformationCommand: + def create_bias_correction_command( + node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph + ) -> TransformationCommand: """ Creates backend-specific command to update bias value. :param node: The node for which bias should be updated. :param bias_value: New value for the bias. + :param nncf_graph: NNCFGraph instance. :return: Backend-specific command to update bias value. """ @@ -121,7 +122,7 @@ def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int: @staticmethod @abstractmethod - def get_bias_value(node: NNCFNode, model: TModel, nncf_graph: NNCFGraph) -> np.ndarray: + def get_bias_value(node: NNCFNode, model: TModel, nncf_graph: NNCFGraph) -> Tensor: """ Returns bias value in the NumPy format of provided node. diff --git a/src/nncf/quantization/algorithms/bias_correction/onnx_backend.py b/src/nncf/quantization/algorithms/bias_correction/onnx_backend.py index f8e5551701f..3388cba1ea2 100644 --- a/src/nncf/quantization/algorithms/bias_correction/onnx_backend.py +++ b/src/nncf/quantization/algorithms/bias_correction/onnx_backend.py @@ -14,6 +14,7 @@ from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode +from nncf.common.graph.graph import NNCFNodeName from nncf.common.graph.transformations.commands import TargetType from nncf.common.tensor_statistics.builders import get_mean_statistic_collector from nncf.common.tensor_statistics.collectors import TensorCollector @@ -46,11 +47,11 @@ def create_bias_correction_command( def model_extraction_command( input_ids: set[tuple[str, int]], output_ids: set[tuple[str, int]] ) -> ONNXModelExtractionCommand: - return ONNXModelExtractionCommand(input_ids, output_ids) + return ONNXModelExtractionCommand(list(input_ids), list(output_ids)) @staticmethod def output_insertion_command(nncf_graph: NNCFGraph, target_point: ONNXTargetPoint) -> ONNXOutputInsertionCommand: - nncf_input_node_next_nodes = {} + nncf_input_node_next_nodes: dict[NNCFNodeName, list[NNCFNodeName]] = {} for input_node in nncf_graph.get_input_nodes(): next_nodes = nncf_graph.get_next_nodes(input_node) nncf_input_node_next_nodes[input_node.node_name] = [node.node_name for node in next_nodes] @@ -70,7 +71,7 @@ def process_model_output(raw_data: dict, output_name: str) -> Tensor: return Tensor(raw_data[output_name]) @staticmethod - def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> tuple[int, int]: + def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int: return 0 @staticmethod diff --git a/src/nncf/quantization/algorithms/bias_correction/openvino_backend.py b/src/nncf/quantization/algorithms/bias_correction/openvino_backend.py index 59fd7e5a272..77f71a2496f 100644 --- a/src/nncf/quantization/algorithms/bias_correction/openvino_backend.py +++ b/src/nncf/quantization/algorithms/bias_correction/openvino_backend.py @@ -47,7 +47,7 @@ def create_bias_correction_command( def model_extraction_command( input_ids: set[tuple[str, int]], output_ids: set[tuple[str, int]] ) -> OVModelExtractionCommand: - return OVModelExtractionCommand(input_ids, output_ids) + return OVModelExtractionCommand(list(input_ids), list(output_ids)) @staticmethod def output_insertion_command(nncf_graph: NNCFGraph, target_point: OVTargetPoint) -> OVOutputInsertionCommand: diff --git a/src/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/src/nncf/quantization/algorithms/fast_bias_correction/algorithm.py index 2c92ad5f44f..4b757c07124 100644 --- a/src/nncf/quantization/algorithms/fast_bias_correction/algorithm.py +++ b/src/nncf/quantization/algorithms/fast_bias_correction/algorithm.py @@ -17,6 +17,7 @@ from nncf.common.factory import EngineFactory from nncf.common.factory import ModelTransformerFactory from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode from nncf.common.graph.transformations.commands import TargetPoint from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout @@ -28,6 +29,7 @@ from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend from nncf.quantization.algorithms.algorithm import Algorithm +from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend from nncf.tensor import Tensor from nncf.tensor import functions as fns @@ -84,7 +86,7 @@ def __init__( self.apply_for_all_nodes = apply_for_all_nodes self.inplace_statistics = inplace_statistics self.backend_params = backend_params - self._backend_entity = None + self._backend_entity: FastBiasCorrectionAlgoBackend | None = None self._algorithm_key = f"FBC_{hash(self)}" if self.apply_for_all_nodes: @@ -145,7 +147,7 @@ def apply( # Fill `node_and_new_bias_value` list. It is a correspondence between nodes # for which we should update bias and new bias values. - node_and_new_bias_value = [] + node_and_new_bias_value: list[tuple[NNCFNode, Tensor]] = [] for node, bias_value in track(node_and_bias_value, description="Applying Fast Bias correction"): node_name = node.node_name diff --git a/src/nncf/quantization/algorithms/fast_bias_correction/backend.py b/src/nncf/quantization/algorithms/fast_bias_correction/backend.py index 950cbcae012..af96d6354b3 100644 --- a/src/nncf/quantization/algorithms/fast_bias_correction/backend.py +++ b/src/nncf/quantization/algorithms/fast_bias_correction/backend.py @@ -11,7 +11,7 @@ from abc import ABC from abc import abstractmethod -from typing import TypeVar +from typing import Any, TypeVar from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode @@ -198,7 +198,7 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple """ def extract_submodel( - self, model_transformer: ModelTransformer, input_id: tuple[str, int], output_id: tuple[str, int] + self, model_transformer: ModelTransformer[Any], input_id: tuple[str, int], output_id: tuple[str, int] ) -> TModel: """ Extracts sub-model using backend-specific ModelTransformer. diff --git a/tests/onnx/quantization/test_bias_correction.py b/tests/onnx/quantization/test_bias_correction.py index a937766fbce..496c1a02c5e 100644 --- a/tests/onnx/quantization/test_bias_correction.py +++ b/tests/onnx/quantization/test_bias_correction.py @@ -19,6 +19,7 @@ from nncf.onnx.graph.model_utils import remove_fq_from_inputs from nncf.onnx.graph.nncf_graph_builder import GraphConverter from nncf.onnx.graph.node_utils import get_bias_value +from nncf.quantization.algorithms.bias_correction.algorithm import SubgraphData from nncf.quantization.algorithms.bias_correction.onnx_backend import ONNXBiasCorrectionAlgoBackend from tests.cross_fw.test_templates.helpers import AddWithInput from tests.cross_fw.test_templates.helpers import ConcatWithInput @@ -103,10 +104,10 @@ def check_bias(model: onnx.ModelProto, ref_biases: dict) -> None: ("/Concat", 1): ("nncf_model_input_0", 0), ("/conv_1/Conv", 0): ("nncf_model_input_0", 0), }, - "subgraph_data": { - "subgraph_input_ids": {("/conv_1/Conv", 0)}, - "subgraph_output_ids": {("/Split", 0), ("/maxpool_1/MaxPool", 0), ("/Split", 1)}, - }, + "subgraph_data": SubgraphData( + input_ids={("/conv_1/Conv", 0)}, + output_ids={("/Split", 0), ("/maxpool_1/MaxPool", 0), ("/Split", 1)}, + ), }, ), ( @@ -118,10 +119,10 @@ def check_bias(model: onnx.ModelProto, ref_biases: dict) -> None: ("/conv_4/Conv", 0): ("/Split", 0), ("/conv_6/Conv", 0): ("/Split", 1), }, - "subgraph_data": { - "subgraph_input_ids": {("/conv_2/Conv", 0)}, - "subgraph_output_ids": {("/Relu_1", 0)}, - }, + "subgraph_data": SubgraphData( + input_ids={("/conv_2/Conv", 0)}, + output_ids={("/Relu_1", 0)}, + ), }, ), ( @@ -134,10 +135,10 @@ def check_bias(model: onnx.ModelProto, ref_biases: dict) -> None: ("/conv_4/Conv", 0): ("/Split", 0), ("/conv_6/Conv", 0): ("/Split", 1), }, - "subgraph_data": { - "subgraph_input_ids": {("/conv_1/Conv", 0), ("/conv_3/Conv", 0)}, - "subgraph_output_ids": {("/Split", 0), ("/Split", 1)}, - }, + "subgraph_data": SubgraphData( + input_ids={("/conv_1/Conv", 0), ("/conv_3/Conv", 0)}, + output_ids={("/Split", 0), ("/Split", 1)}, + ), }, ), ( @@ -147,10 +148,10 @@ def check_bias(model: onnx.ModelProto, ref_biases: dict) -> None: ("/conv_4/Conv", 0): ("/Split", 0), ("/conv_6/Conv", 0): ("/Split", 1), }, - "subgraph_data": { - "subgraph_input_ids": {("/conv_4/Conv", 0)}, - "subgraph_output_ids": {("/Relu_2", 0)}, - }, + "subgraph_data": SubgraphData( + input_ids={("/conv_4/Conv", 0)}, + output_ids={("/Relu_2", 0)}, + ), }, ), ( @@ -160,10 +161,10 @@ def check_bias(model: onnx.ModelProto, ref_biases: dict) -> None: ("/conv_5/Conv", 0): ("/Relu_2", 0), ("/conv_6/Conv", 0): ("/Split", 1), }, - "subgraph_data": { - "subgraph_input_ids": {("/conv_5/Conv", 0), ("/conv_6/Conv", 0)}, - "subgraph_output_ids": {("/Add_3", 0), ("/Concat_1", 0)}, - }, + "subgraph_data": SubgraphData( + input_ids={("/conv_5/Conv", 0), ("/conv_6/Conv", 0)}, + output_ids={("/Add_3", 0), ("/Concat_1", 0)}, + ), }, ), ( @@ -174,14 +175,14 @@ def check_bias(model: onnx.ModelProto, ref_biases: dict) -> None: ("/conv_9/Conv", 0): ("/Add_3", 0), ("/conv_10/Conv", 0): ("/Concat_1", 0), }, - "subgraph_data": { - "subgraph_input_ids": { + "subgraph_data": SubgraphData( + input_ids={ ("/conv_8/Conv", 0), ("/conv_9/Conv", 0), ("/conv_10/Conv", 0), }, - "subgraph_output_ids": {("/Concat_2", 0)}, - }, + output_ids={("/Concat_2", 0)}, + ), }, ), # Disabled, because ONNX backend doesn't support bias correction for MatMul diff --git a/tests/openvino/native/test_bias_correction.py b/tests/openvino/native/test_bias_correction.py index c7b1ae25410..068a880b095 100644 --- a/tests/openvino/native/test_bias_correction.py +++ b/tests/openvino/native/test_bias_correction.py @@ -20,6 +20,7 @@ from nncf.openvino.graph.model_utils import remove_fq_from_inputs from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.graph.node_utils import get_bias_value +from nncf.quantization.algorithms.bias_correction.algorithm import SubgraphData from nncf.quantization.algorithms.bias_correction.openvino_backend import OVBiasCorrectionAlgoBackend from tests.cross_fw.test_templates.helpers import ConvTestModel from tests.cross_fw.test_templates.helpers import DepthwiseConvTestModel @@ -103,10 +104,10 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: ("/Concat", 1): ("input.1", 0), ("/conv_1/Conv/WithoutBiases", 0): ("/Concat", 0), }, - "subgraph_data": { - "subgraph_input_ids": {("/conv_1/Conv/WithoutBiases", 0)}, - "subgraph_output_ids": {("/Split", 0), ("/maxpool_1/MaxPool", 0), ("/Split", 1)}, - }, + "subgraph_data": SubgraphData( + input_ids={("/conv_1/Conv/WithoutBiases", 0)}, + output_ids={("/Split", 0), ("/maxpool_1/MaxPool", 0), ("/Split", 1)}, + ), }, ), ( @@ -118,10 +119,10 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: ("/conv_4/Conv/WithoutBiases", 0): ("/Split", 0), ("/conv_6/Conv/WithoutBiases", 0): ("/Split", 1), }, - "subgraph_data": { - "subgraph_input_ids": {("/conv_2/Conv/WithoutBiases", 0)}, - "subgraph_output_ids": {("/Relu_1", 0)}, - }, + "subgraph_data": SubgraphData( + input_ids={("/conv_2/Conv/WithoutBiases", 0)}, + output_ids={("/Relu_1", 0)}, + ), }, ), ( @@ -134,10 +135,10 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: ("/conv_4/Conv/WithoutBiases", 0): ("/Split", 0), ("/conv_6/Conv/WithoutBiases", 0): ("/Split", 1), }, - "subgraph_data": { - "subgraph_input_ids": {("/conv_1/Conv/WithoutBiases", 0), ("/conv_3/Conv/WithoutBiases", 0)}, - "subgraph_output_ids": {("/Split", 0), ("/Split", 1)}, - }, + "subgraph_data": SubgraphData( + input_ids={("/conv_1/Conv/WithoutBiases", 0), ("/conv_3/Conv/WithoutBiases", 0)}, + output_ids={("/Split", 0), ("/Split", 1)}, + ), }, ), ( @@ -147,10 +148,10 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: ("/conv_4/Conv/WithoutBiases", 0): ("/Split", 0), ("/conv_6/Conv/WithoutBiases", 0): ("/Split", 1), }, - "subgraph_data": { - "subgraph_input_ids": {("/conv_4/Conv/WithoutBiases", 0)}, - "subgraph_output_ids": {("/Relu_2", 0)}, - }, + "subgraph_data": SubgraphData( + input_ids={("/conv_4/Conv/WithoutBiases", 0)}, + output_ids={("/Relu_2", 0)}, + ), }, ), ( @@ -160,10 +161,10 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: ("/conv_5/Conv/WithoutBiases", 0): ("/Relu_2", 0), ("/conv_6/Conv/WithoutBiases", 0): ("/Split", 1), }, - "subgraph_data": { - "subgraph_input_ids": {("/conv_5/Conv/WithoutBiases", 0), ("/conv_6/Conv/WithoutBiases", 0)}, - "subgraph_output_ids": {("/Add_3", 0), ("/Concat_1", 0)}, - }, + "subgraph_data": SubgraphData( + input_ids={("/conv_5/Conv/WithoutBiases", 0), ("/conv_6/Conv/WithoutBiases", 0)}, + output_ids={("/Add_3", 0), ("/Concat_1", 0)}, + ), }, ), ( @@ -174,14 +175,14 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: ("/conv_9/Conv/WithoutBiases", 0): ("/Add_3", 0), ("/conv_10/Conv/WithoutBiases", 0): ("/Concat", 0), }, - "subgraph_data": { - "subgraph_input_ids": { + "subgraph_data": SubgraphData( + input_ids={ ("/conv_8/Conv/WithoutBiases", 0), ("/conv_9/Conv/WithoutBiases", 0), ("/conv_10/Conv/WithoutBiases", 0), }, - "subgraph_output_ids": {("/Concat_2", 0)}, - }, + output_ids={("/Concat_2", 0)}, + ), }, ), ( @@ -190,10 +191,10 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: "collected_inputs": { ("/MatMul", 0): ("/Reshape", 0), }, - "subgraph_data": { - "subgraph_input_ids": {("/MatMul", 0)}, - "subgraph_output_ids": {("/Reshape_1", 0), ("/Add_4", 0)}, - }, + "subgraph_data": SubgraphData( + input_ids={("/MatMul", 0)}, + output_ids={("/Reshape_1", 0), ("/Add_4", 0)}, + ), }, ), ), diff --git a/tests/torch/fx/test_bias_correction.py b/tests/torch/fx/test_bias_correction.py index c8918d622d7..39edd8c074a 100644 --- a/tests/torch/fx/test_bias_correction.py +++ b/tests/torch/fx/test_bias_correction.py @@ -21,6 +21,7 @@ from nncf.experimental.torch.fx.model_utils import remove_fq_from_inputs from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter from nncf.experimental.torch.fx.node_utils import get_bias_value +from nncf.quantization.algorithms.bias_correction.algorithm import SubgraphData from nncf.quantization.algorithms.bias_correction.torch_fx_backend import FXBiasCorrectionAlgoBackend from tests.cross_fw.test_templates.helpers import ConvTestModel from tests.cross_fw.test_templates.helpers import DepthwiseConvTestModel @@ -98,10 +99,10 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: ("concat", 1): ("arg0_1", 0), ("conv2d", 0): ("concat", 0), }, - "subgraph_data": { - "subgraph_input_ids": {("conv2d", 0)}, - "subgraph_output_ids": {("getitem", 0), ("max_pool2d", 0), ("getitem_1", 0)}, - }, + "subgraph_data": SubgraphData( + input_ids={("conv2d", 0)}, + output_ids={("getitem", 0), ("max_pool2d", 0), ("getitem_1", 0)}, + ), }, ), ( @@ -113,10 +114,10 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: ("conv2d_3", 0): ("getitem", 0), ("conv2d_5", 0): ("getitem_1", 0), }, - "subgraph_data": { - "subgraph_input_ids": {("conv2d_1", 0)}, - "subgraph_output_ids": {("relu_1", 0)}, - }, + "subgraph_data": SubgraphData( + input_ids={("conv2d_1", 0)}, + output_ids={("relu_1", 0)}, + ), }, ), ( @@ -129,10 +130,10 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: ("conv2d_3", 0): ("getitem", 0), ("conv2d_5", 0): ("getitem_1", 0), }, - "subgraph_data": { - "subgraph_input_ids": {("conv2d", 0), ("conv2d_2", 0)}, - "subgraph_output_ids": {("getitem", 0), ("getitem_1", 0)}, - }, + "subgraph_data": SubgraphData( + input_ids={("conv2d", 0), ("conv2d_2", 0)}, + output_ids={("getitem", 0), ("getitem_1", 0)}, + ), }, ), ( @@ -142,10 +143,10 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: ("conv2d_3", 0): ("getitem", 0), ("conv2d_5", 0): ("getitem_1", 0), }, - "subgraph_data": { - "subgraph_input_ids": {("conv2d_3", 0)}, - "subgraph_output_ids": {("relu_2", 0)}, - }, + "subgraph_data": SubgraphData( + input_ids={("conv2d_3", 0)}, + output_ids={("relu_2", 0)}, + ), }, ), ( @@ -155,10 +156,10 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: ("conv2d_4", 0): ("relu_2", 0), ("conv2d_5", 0): ("getitem_1", 0), }, - "subgraph_data": { - "subgraph_input_ids": {("conv2d_4", 0), ("conv2d_5", 0)}, - "subgraph_output_ids": {("add_3", 0), ("concat_1", 0)}, - }, + "subgraph_data": SubgraphData( + input_ids={("conv2d_4", 0), ("conv2d_5", 0)}, + output_ids={("add_3", 0), ("concat_1", 0)}, + ), }, ), ( @@ -169,14 +170,14 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: ("conv2d_8", 0): ("add_2", 0), ("conv2d_9", 0): ("concat", 0), }, - "subgraph_data": { - "subgraph_input_ids": { + "subgraph_data": SubgraphData( + input_ids={ ("conv2d_7", 0), ("conv2d_8", 0), ("conv2d_9", 0), }, - "subgraph_output_ids": {("concat_2", 0)}, - }, + output_ids={("concat_2", 0)}, + ), }, ), ( @@ -185,10 +186,10 @@ def check_bias(model: ov.Model, ref_biases: dict) -> None: "collected_inputs": { ("matmul", 0): ("reshape", 0), }, - "subgraph_data": { - "subgraph_input_ids": {("matmul", 0)}, - "subgraph_output_ids": {("reshape_1", 0), ("add_4", 0)}, - }, + "subgraph_data": SubgraphData( + input_ids={("matmul", 0)}, + output_ids={("reshape_1", 0), ("add_4", 0)}, + ), }, ), ),