diff --git a/src/qonnx/transformation/batchnorm_to_affine.py b/src/qonnx/transformation/batchnorm_to_affine.py
index c89d2bdc..6190f867 100644
--- a/src/qonnx/transformation/batchnorm_to_affine.py
+++ b/src/qonnx/transformation/batchnorm_to_affine.py
@@ -32,7 +32,7 @@
 
 from qonnx.transformation.base import Transformation
 from qonnx.transformation.infer_shapes import InferShapes
-from qonnx.util.basic import get_by_name
+from qonnx.util.basic import copy_metadata_props, get_by_name
 
 
 class BatchNormToAffine(Transformation):
@@ -89,6 +89,9 @@ def apply(self, model):
                 # create Mul and Add nodes to replace the batchnorm
                 mul_node = oh.make_node("Mul", [bn_input, mul_const.name], [mul_output.name])
                 add_node = oh.make_node("Add", [mul_output.name, add_const.name], [bn_output])
+                # preserve metadata from original batchnorm node
+                copy_metadata_props(n, mul_node)
+                copy_metadata_props(n, add_node)
                 # insert where the batchnorm is to preserve topological ordering
                 graph.node.insert(node_ind, mul_node)
                 graph.node.insert(node_ind + 1, add_node)
diff --git a/src/qonnx/transformation/bipolar_to_xnor.py b/src/qonnx/transformation/bipolar_to_xnor.py
index 37f939a2..0b764ef8 100644
--- a/src/qonnx/transformation/bipolar_to_xnor.py
+++ b/src/qonnx/transformation/bipolar_to_xnor.py
@@ -36,7 +36,7 @@
 from qonnx.transformation.base import Transformation
 from qonnx.transformation.infer_datatypes import InferDataTypes
 from qonnx.transformation.infer_shapes import InferShapes
-from qonnx.util.basic import get_by_name
+from qonnx.util.basic import copy_metadata_props, get_by_name
 
 
 class ConvertBipolarMatMulToXnorPopcount(Transformation):
@@ -132,6 +132,9 @@ def find_prod_mt(x):
                     # create Mul and Add nodes to replace the batchnorm
                     mul_node = oh.make_node("Mul", [xnorpcout.name, mul_const.name], [mul_output.name])
                     add_node = oh.make_node("Add", [mul_output.name, add_const.name], [mm_output])
+                    # preserve metadata from original MatMul node
+                    copy_metadata_props(n, mul_node)
+                    copy_metadata_props(n, add_node)
                     # insert where the batchnorm is to preserve topological ordering
                     graph.node.insert(node_ind, mul_node)
                     graph.node.insert(node_ind + 1, add_node)
diff --git a/src/qonnx/transformation/change_datalayout.py b/src/qonnx/transformation/change_datalayout.py
index 7b73e4bf..62e6140b 100644
--- a/src/qonnx/transformation/change_datalayout.py
+++ b/src/qonnx/transformation/change_datalayout.py
@@ -30,7 +30,7 @@
 
 from qonnx.transformation.base import Transformation
 from qonnx.transformation.infer_shapes import InferShapes
-from qonnx.util.basic import get_by_name
+from qonnx.util.basic import copy_metadata_props, get_by_name
 
 
 class ChangeDataLayoutQuantAvgPool2d(Transformation):
@@ -78,6 +78,7 @@ def apply(self, model):
                 graph.value_info.append(quantavg_out)
                 quantavg_out = quantavg_out.name
                 inp_trans_node = helper.make_node("Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1])
+                copy_metadata_props(n, inp_trans_node)
                 quantavg_node = helper.make_node(
                     "QuantAvgPool2d",
                     [inp_trans_out],
@@ -90,8 +91,10 @@ def apply(self, model):
                     signed=signed,
                     data_layout="NHWC",
                 )
+                copy_metadata_props(n, quantavg_node)
                 # NHWC -> NCHW
                 out_trans_node = helper.make_node("Transpose", [quantavg_out], [node_output], perm=[0, 3, 1, 2])
+                copy_metadata_props(n, out_trans_node)
                 # insert nodes
                 graph.node.insert(node_ind, inp_trans_node)
                 graph.node.insert(node_ind + 1, quantavg_node)
diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py
index a00f8a9c..f9ca62bb 100644
--- a/src/qonnx/transformation/channels_last.py
+++ b/src/qonnx/transformation/channels_last.py
@@ -40,7 +40,7 @@
 from qonnx.transformation.infer_shapes import InferShapes
 from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast
 from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
-from qonnx.util.basic import get_by_name
+from qonnx.util.basic import copy_metadata_props, get_by_name
 from qonnx.util.onnx import is_eltwise_optype
 
 # Standard ONNX nodes which require a ChannelsLast data format to function properly
@@ -96,6 +96,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrapper):
         new_t_inp = model.make_new_valueinfo_name()
         inv_perm = np.argsort(perm)
         new_transpose_node = helper.make_node("Transpose", [eltwise_inp], [new_t_inp], perm=inv_perm)
+        copy_metadata_props(transpose_node, new_transpose_node)
         t_shape = np.transpose(np.empty(inp_shape), axes=inv_perm).shape
         model.set_tensor_shape(new_t_inp, t_shape)
         eltwise_node.input[ind] = new_t_inp
@@ -107,6 +108,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrapper):
             model.set_initializer(unsqueeze_param_name, np.asarray(list(range(ndim_inp - ndim)), dtype=np.int64))
             unsqueeze_out_name = model.make_new_valueinfo_name()
             new_unsqueeze_node = helper.make_node("Unsqueeze", [eltwise_inp, unsqueeze_param_name], [unsqueeze_out_name])
+            copy_metadata_props(eltwise_node, new_unsqueeze_node)
             unsqueeze_out_shape = np.expand_dims(np.empty(inp_shape), axis=tuple(range(ndim_inp - ndim))).shape
             model.set_tensor_shape(unsqueeze_out_name, unsqueeze_out_shape)
             model.graph.node.append(new_unsqueeze_node)
@@ -114,6 +116,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrapper):
             new_t_inp = model.make_new_valueinfo_name()
             inv_perm = np.argsort(perm)
             new_transpose_node = helper.make_node("Transpose", [unsqueeze_out_name], [new_t_inp], perm=inv_perm)
+            copy_metadata_props(transpose_node, new_transpose_node)
             t_shape = np.transpose(np.empty(unsqueeze_out_shape), axes=inv_perm).shape
             model.set_tensor_shape(new_t_inp, t_shape)
             eltwise_node.input[ind] = new_t_inp
@@ -239,6 +242,7 @@ def apply(self, model):
                 # channels last transpose
                 inp_trans_node = helper.make_node("Transpose", [inp], [inp_trans_out], perm=to_channels_last_args(ndim))
                 graph.node.insert(running_node_index, inp_trans_node)
+                copy_metadata_props(n, inp_trans_node)
                 running_node_index += 1
 
                 # Attach to original node
@@ -265,6 +269,7 @@ def apply(self, model):
                     "Transpose", [outp_trans_in], [outp], perm=to_channels_first_args(ndim)
                 )
                 graph.node.insert(running_node_index, outp_trans_node)
+                copy_metadata_props(n, outp_trans_node)
                 running_node_index += 1
 
                 # Attach to original node
@@ -567,7 +572,8 @@ def apply(self, model):
                     axis=1,
                 )
                 graph.node.insert(node_ind, flat_node)
-
+                copy_metadata_props(n, flat_node)
+
                 graph_modified = True
             else:
                 warnings.warn(
diff --git a/src/qonnx/transformation/extract_conv_bias.py b/src/qonnx/transformation/extract_conv_bias.py
index bf2cf8b4..34b017bd 100644
--- a/src/qonnx/transformation/extract_conv_bias.py
+++ b/src/qonnx/transformation/extract_conv_bias.py
@@ -30,6 +30,7 @@
 from onnx import helper
 
 from qonnx.transformation.base import Transformation
+from qonnx.util.basic import copy_metadata_props
 
 
 class ExtractBiasFromConv(Transformation):
@@ -75,6 +76,7 @@ def apply(self, model):
                     [act_add_tensor.name, n.input[2]],
                     [n.output[0]],
                 )
+                copy_metadata_props(n, add_node)
                 graph.node.insert(node_ind, add_node)
 
                 # Repoint Conv output and remove bias tensor
diff --git a/src/qonnx/transformation/extract_quant_scale_zeropt.py b/src/qonnx/transformation/extract_quant_scale_zeropt.py
index 58863f08..f76e5555 100644
--- a/src/qonnx/transformation/extract_quant_scale_zeropt.py
+++ b/src/qonnx/transformation/extract_quant_scale_zeropt.py
@@ -33,6 +33,7 @@
 from qonnx.transformation.base import Transformation
 from qonnx.transformation.general import GiveUniqueParameterTensors, SortGraph
 from qonnx.transformation.remove import RemoveIdentityOps
+from qonnx.util.basic import copy_metadata_props
 
 
 class ExtractQuantScaleZeroPt(Transformation):
@@ -69,6 +70,7 @@ def apply(self, model: ModelWrapper):
                 )
                 graph.value_info.append(inp_scaled)
                 inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm])
+                copy_metadata_props(node, inp_scale_node)
                 graph.node.append(inp_scale_node)
                 # create new Mul node
                 # remove scale from Quant node
@@ -87,6 +89,7 @@ def apply(self, model: ModelWrapper):
                 )
                 graph.value_info.append(inp_zeropt)
                 inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm])
+                copy_metadata_props(node, inp_zeropt_node)
                 graph.node.append(inp_zeropt_node)
                 # remove zeropt from Quant node
                 new_zeropt_nm = model.make_new_valueinfo_name()
@@ -108,6 +111,7 @@ def apply(self, model: ModelWrapper):
                 )
                 graph.value_info.append(out_zeropt)
                 out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output])
+                copy_metadata_props(node, out_zeropt_node)
                 last_node.output[0] = out_zeropt_nm
                 graph.node.append(out_zeropt_node)
                 # important: when tracking a pointer to newly added nodes,
@@ -127,6 +131,7 @@ def apply(self, model: ModelWrapper):
                 last_node.output[0] = out_scale_nm
                 graph.value_info.append(out_scale)
                 out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output])
+                copy_metadata_props(node, out_scale_node)
                 graph.node.append(out_scale_node)
 
             if extract_scale or extract_zeropt:
diff --git a/src/qonnx/transformation/gemm_to_matmul.py b/src/qonnx/transformation/gemm_to_matmul.py
index 5396a7d6..245a0a2a 100644
--- a/src/qonnx/transformation/gemm_to_matmul.py
+++ b/src/qonnx/transformation/gemm_to_matmul.py
@@ -32,7 +32,7 @@
 from qonnx.core.datatype import DataType
 from qonnx.transformation.base import Transformation
 from qonnx.transformation.remove import RemoveIdentityOps
-from qonnx.util.basic import get_by_name
+from qonnx.util.basic import copy_metadata_props, get_by_name
 
 
 class GemmToMatMul(Transformation):
@@ -76,6 +76,7 @@ def apply(self, model):
                 )
                 graph.value_info.append(inp_trans_out)
                 inp_trans_node = helper.make_node("Transpose", [n.input[0]], [inp_trans_out.name])
+                copy_metadata_props(n, inp_trans_node)
                 graph.node.insert(running_node_index, inp_trans_node)
                 running_node_index += 1
                 dt = model.get_tensor_datatype(n.input[0])
@@ -98,6 +99,7 @@ def apply(self, model):
                 )
                 graph.value_info.append(inp_trans_out)
                 inp_trans_node = helper.make_node("Transpose", [n.input[1]], [inp_trans_out.name])
+                copy_metadata_props(n, inp_trans_node)
                 graph.node.insert(running_node_index, inp_trans_node)
                 running_node_index += 1
                 # Copy over the datatype
@@ -109,6 +111,7 @@ def apply(self, model):
 
             # Insert MatMul: A * B
             matMul_node = helper.make_node("MatMul", [n.input[0], n.input[1]], [n.output[0]])
+            copy_metadata_props(n, matMul_node)
             graph.node.insert(running_node_index, matMul_node)
             matMul_node = graph.node[running_node_index]
             running_node_index += 1
@@ -144,6 +147,7 @@ def apply(self, model):
                     [act_mul_tensor.name, mul_tensor.name],
                     [n.output[0]],
                 )
+                copy_metadata_props(n, mul_node)
                 graph.node.insert(running_node_index, mul_node)
                 mul_node_main_branch = graph.node[running_node_index]
                 running_node_index += 1
@@ -175,6 +179,7 @@ def apply(self, model):
                     [n.input[2], mul_tensor.name],
                     [act_mul_tensor.name],
                 )
+                copy_metadata_props(n, mul_node)
                 graph.node.insert(running_node_index, mul_node)
                 running_node_index += 1
                 dt = model.get_tensor_datatype(n.input[2])
@@ -196,7 +201,7 @@ def apply(self, model):
                     [act_add_tensor.name, n.input[2]],
                     [n.output[0]],
                 )
-
+                copy_metadata_props(n, add_node)
                 graph.node.insert(running_node_index, add_node)
                 running_node_index += 1
diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py
index 81f0b713..f0981b34 100644
--- a/src/qonnx/transformation/lower_convs_to_matmul.py
+++ b/src/qonnx/transformation/lower_convs_to_matmul.py
@@ -32,7 +32,7 @@
 
 from qonnx.transformation.base import Transformation
 from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv
-from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name
+from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name
 
 
 class LowerConvsToMatMul(Transformation):
@@ -152,6 +152,7 @@ def apply(self, model):
             # create new nodes
             # NCHW -> NHWC
             inp_trans_node = helper.make_node("Transpose", [cnv_input], [inp_trans_out], perm=[0, 2, 3, 1])
+            copy_metadata_props(node, inp_trans_node)
             nodes_to_insert = [inp_trans_node]
 
             if need_im2col:
@@ -174,12 +175,15 @@ def apply(self, model):
                     dilations=dilation,
                 )
                 nodes_to_insert.append(im2col_node)
+                copy_metadata_props(node, im2col_node)
 
             matmul_input = im2col_out if need_im2col else inp_trans_out
             # do matmul
             matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out])
+            copy_metadata_props(node, matmul_node)
             # NHWC -> NCHW
             out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2])
+            copy_metadata_props(node, out_trans_node)
 
             nodes_to_insert.extend([matmul_node, out_trans_node])
diff --git a/src/qonnx/transformation/qcdq_to_qonnx.py b/src/qonnx/transformation/qcdq_to_qonnx.py
index b7e35c0d..b4e18f25 100644
--- a/src/qonnx/transformation/qcdq_to_qonnx.py
+++ b/src/qonnx/transformation/qcdq_to_qonnx.py
@@ -34,7 +34,7 @@
 
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.transformation.base import Transformation
-from qonnx.util.basic import get_by_name
+from qonnx.util.basic import copy_metadata_props, get_by_name
 
 
 def extract_elem_type(elem_type: int, clip_range=None) -> Tuple[int, int, bool]:
@@ -203,6 +203,8 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]:
                 rounding_mode="ROUND",  # round-to-even
                 signed=signed,
             )
+            # preserve metadata from all nodes being fused
+            copy_metadata_props(nodes_to_remove, fused_node)
             model.graph.node.insert(dequant_node_index, fused_node)
             for node_to_remove in nodes_to_remove:
                 model.graph.node.remove(node_to_remove)
diff --git a/src/qonnx/transformation/rebalance_conv.py b/src/qonnx/transformation/rebalance_conv.py
index ecb2b5e4..0107a62a 100644
--- a/src/qonnx/transformation/rebalance_conv.py
+++ b/src/qonnx/transformation/rebalance_conv.py
@@ -31,6 +31,7 @@
 
 from qonnx.custom_op.registry import getCustomOp
 from qonnx.transformation.base import Transformation
+from qonnx.util.basic import copy_metadata_props
 
 
 class RebalanceIm2Col(Transformation):
@@ -103,6 +104,7 @@ def apply(self, model):
                 inp_reshape_node = helper.make_node(
                     "Reshape", [node.input[0], inp_shapedata.name], [inp_reshape_out.name]
                 )
+                copy_metadata_props(node, inp_reshape_node)
                 graph.node.insert(running_node_index, inp_reshape_node)
                 # rewire Im2Col input
                 node.input[0] = inp_reshape_out.name
diff --git a/src/qonnx/transformation/resize_conv_to_deconv.py b/src/qonnx/transformation/resize_conv_to_deconv.py
index 0dd40972..7eda4fa7 100644
--- a/src/qonnx/transformation/resize_conv_to_deconv.py
+++ b/src/qonnx/transformation/resize_conv_to_deconv.py
@@ -33,7 +33,7 @@
 from qonnx.core.datatype import DataType
 from qonnx.custom_op.general.quant import quant, resolve_rounding_mode
 from qonnx.transformation.base import Transformation
-from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name
+from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name
 
 
 def _weight_convolution(cnv_weights: np.ndarray, scale: int) -> np.ndarray:
@@ -242,6 +242,7 @@ def apply(self, model):
                     group=group,
                     dilations=dilation,
                 )
+                copy_metadata_props(conv, deconv_node)
                 W_deconv_init = weight_name
                 if weight_prod is not None:
                     W_deconv_init = q_w_name
diff --git a/src/qonnx/transformation/subpixel_to_deconv.py b/src/qonnx/transformation/subpixel_to_deconv.py
index 3f330c99..73ef3f8f 100644
--- a/src/qonnx/transformation/subpixel_to_deconv.py
+++ b/src/qonnx/transformation/subpixel_to_deconv.py
@@ -31,7 +31,7 @@
 from onnx import helper
 
 from qonnx.transformation.base import Transformation
-from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name
+from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name
 
 
 def _weight_shuffle(cnv_weights: np.ndarray, block_size: int) -> np.ndarray:
@@ -197,6 +197,7 @@ def apply(self, model):
                     group=group,
                     dilations=dilation,
                 )
+                copy_metadata_props(n, deconv_node)
                 W_deconv_init = weight_name
                 if weight_prod is not None:
                     W_deconv_init = q_w_name
diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py
index cef4f67b..2752212b 100644
--- a/src/qonnx/util/basic.py
+++ b/src/qonnx/util/basic.py
@@ -360,3 +360,61 @@ def auto_pad_to_explicit_padding(autopad_str, idim_h, idim_w, k_h, k_w, stride_h
         return [pad_half_large_h, pad_half_large_w, pad_half_small_h, pad_half_small_w]
     else:
         raise Exception("Unsupported auto_pad: " + autopad_str)
+
+
+def copy_metadata_props(source_node, target_node, mode="overwrite"):
+    """Copy metadata properties from source node(s) to target node.
+
+    Parameters
+    ----------
+    source_node : onnx.NodeProto or list of onnx.NodeProto
+        Source node(s) from which to copy metadata_props. If a list is provided,
+        metadata from all nodes will be merged into the target node.
+    target_node : onnx.NodeProto
+        Target node to which metadata_props will be copied.
+    mode : str, optional
+        Mode for handling existing metadata properties in the target node.
+        Options are:
+        - "overwrite": Existing properties in the target node will be overwritten
+          by those from the source node(s) if they share the same key.
+        - "keep_existing": Existing properties in the target node will be kept,
+          and only new properties from the source node(s) will be added.
+        Default is "overwrite".
+
+    Returns
+    -------
+    None
+        Modifies target_node in place by extending its metadata_props.
+
+    Examples
+    --------
+    >>> # Copy from single node
+    >>> copy_metadata_props(old_node, new_node)
+    >>>
+    >>> # Copy from multiple nodes (e.g., when fusing)
+    >>> copy_metadata_props([quant_node, dequant_node], fused_node)
+    """
+    assert mode in ["overwrite", "keep_existing"], "mode must be either 'overwrite' or 'keep_existing'"
+
+    # handle both a single node and a list of nodes
+    source_nodes = source_node if isinstance(source_node, list) else [source_node]
+
+    for node in source_nodes:
+        if hasattr(node, "metadata_props"):
+            # collect keys already present on the target to detect collisions
+            if hasattr(target_node, "metadata_props"):
+                existing_keys = {prop.key for prop in target_node.metadata_props}
+            else:
+                existing_keys = set()
+
+            for prop in node.metadata_props:
+                if prop.key in existing_keys:
+                    if mode == "overwrite":
+                        # overwrite the existing property with the source value
+                        for existing_prop in target_node.metadata_props:
+                            if existing_prop.key == prop.key:
+                                existing_prop.value = prop.value
+                                break
+                else:
+                    target_node.metadata_props.append(prop)
\ No newline at end of file
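Usage sketch (reviewer illustration, not part of the patch): copy_metadata_props mutates the target node in place, and collisions on a key are resolved by mode. The nodes below are hypothetical; note that node-level metadata_props is a relatively recent ONNX addition, which this patch already relies on.

    import onnx
    from onnx import helper

    from qonnx.util.basic import copy_metadata_props

    src = helper.make_node("Relu", ["x"], ["y"], name="relu0")
    src.metadata_props.append(onnx.StringStringEntryProto(key="origin", value="layer0"))

    dst = helper.make_node("Mul", ["y", "s"], ["z"], name="mul0")
    dst.metadata_props.append(onnx.StringStringEntryProto(key="origin", value="layer1"))

    copy_metadata_props(src, dst)  # default "overwrite": dst's "origin" becomes "layer0"
    copy_metadata_props(src, dst, mode="keep_existing")  # colliding keys now keep dst's value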
diff --git a/tests/util/test_copy_metadata.py b/tests/util/test_copy_metadata.py
new file mode 100644
index 00000000..1cc913b9
--- /dev/null
+++ b/tests/util/test_copy_metadata.py
@@ -0,0 +1,116 @@
+import pytest
+
+import onnx
+import onnxscript
+from onnxscript import FLOAT
+from onnxscript import opset17 as op
+from onnxscript import script
+from onnxscript.ir.passes.common import LiftConstantsToInitializersPass
+
+from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.transformation.batchnorm_to_affine import BatchNormToAffine
+from qonnx.transformation.gemm_to_matmul import GemmToMatMul
+from qonnx.util.basic import copy_metadata_props
+
+
+def add_metadata(key, value):
+    return onnx.StringStringEntryProto(key=key, value=value)
+
+
+def test_copy_metadata_props():
+    # create a source node with metadata and an empty destination node
+    src_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")])
+    dst_node = onnx.NodeProto()
+
+    copy_metadata_props(src_node, dst_node)
+
+    assert len(dst_node.metadata_props) == 2
+    assert dst_node.metadata_props[0].key == "key1"
+    assert dst_node.metadata_props[0].value == "value1"
+    assert dst_node.metadata_props[1].key == "key2"
+    assert dst_node.metadata_props[1].value == "value2"
+
+
+@pytest.mark.parametrize("mode", ["keep_existing", "overwrite"])
+def test_copy_metadata_props_existing_target_md(mode):
+    # source node with metadata, destination node with a colliding key
+    src_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")])
+    dst_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value3")])
+
+    copy_metadata_props(src_node, dst_node, mode=mode)
+
+    assert len(dst_node.metadata_props) == 2
+    assert dst_node.metadata_props[0].key == "key1"
+
+    if mode == "keep_existing":
+        assert dst_node.metadata_props[0].value == "value3"  # existing value is kept
+    elif mode == "overwrite":
+        assert dst_node.metadata_props[0].value == "value1"  # existing value is overwritten
+
+    assert dst_node.metadata_props[1].key == "key2"
+    assert dst_node.metadata_props[1].value == "value2"
+
+
+def test_copy_metadata_props_bad_mode():
+    src_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value1")])
+    dst_node = onnx.NodeProto()
+
+    with pytest.raises(AssertionError):
+        copy_metadata_props(src_node, dst_node, mode="invalid_mode")
+
+
+def test_copy_metadata_props_gemm2matmul():
+    @script()
+    def MyGemm(A: FLOAT[4, 5], B: FLOAT[5, 4], C: FLOAT[4, 4]) -> FLOAT[4, 4]:
+        return op.Gemm(A, B, C)
+
+    model_proto = MyGemm.to_model_proto()
+    gemm_node = model_proto.graph.node[0]
+    gemm_node.metadata_props.extend([add_metadata("key1", "value1"), add_metadata("key2", "value2")])
+
+    mw = ModelWrapper(model_proto)
+    transformed_mw = mw.transform(GemmToMatMul())
+
+    # every node produced by the transformation should carry the Gemm metadata
+    for node in transformed_mw.graph.node:
+        assert node.metadata_props[0].key == "key1"
+        assert node.metadata_props[0].value == "value1"
+        assert node.metadata_props[1].key == "key2"
+        assert node.metadata_props[1].value == "value2"
+
+
+def test_copy_metadata_props_batchnorm2affine():
+    @script()
+    def MyBatchNorm(X: FLOAT[1, 3, 4, 4]) -> FLOAT[1, 3, 4, 4]:
+        scale = op.Constant(value=[[1.0, 1.0, 1.0]])
+        B = op.Constant(value=[[0.0, 0.0, 0.0]])
+        var = op.Constant(value=[[1.0, 1.0, 1.0]])
+        mean = op.Constant(value=[[0.0, 0.0, 0.0]])
+        return op.BatchNormalization(X, scale, B, mean, var, epsilon=1e-5, momentum=0.9)
+
+    # remove cast-like nodes
+    model_proto = onnxscript.optimizer.optimize(MyBatchNorm.to_model_proto())
+
+    # batchnorm_to_affine requires initializers for scale/mean/var/bias
+    model_ir = onnxscript.ir.serde.deserialize_model(model_proto)
+    pass_ = LiftConstantsToInitializersPass(lift_all_constants=True, size_limit=1)
+    pass_result = pass_.call(model_ir)
+    model_proto = onnxscript.ir.serde.serialize_model(pass_result.model)
+
+    # add metadata to the BatchNormalization node
+    bn_node = model_proto.graph.node[0]
+    bn_node.metadata_props.extend([add_metadata("key1", "value1"), add_metadata("key2", "value2")])
+
+    mw = ModelWrapper(model_proto)
+    transformed_mw = mw.transform(BatchNormToAffine())
+
+    # check that the metadata was copied to the replacement Mul/Add nodes
+    for node in transformed_mw.graph.node:
+        assert node.metadata_props[0].key == "key1"
+        assert node.metadata_props[0].value == "value1"
+        assert node.metadata_props[1].key == "key2"
+        assert node.metadata_props[1].value == "value2"
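A second reviewer sketch (not part of the patch), covering the list form that the qcdq_to_qonnx change uses when several nodes collapse into a single Quant node. Under the default "overwrite" mode the later entries in the list win on colliding keys; under "keep_existing" the first value seen is kept. All node names here are hypothetical.

    import onnx
    from onnx import helper

    from qonnx.util.basic import copy_metadata_props

    q = helper.make_node("QuantizeLinear", ["x", "s", "zp"], ["xq"])
    dq = helper.make_node("DequantizeLinear", ["xq", "s", "zp"], ["y"])
    q.metadata_props.append(onnx.StringStringEntryProto(key="stage", value="quant"))
    dq.metadata_props.append(onnx.StringStringEntryProto(key="stage", value="dequant"))

    fused = helper.make_node("Quant", ["x", "s", "zp", "bw"], ["y"])
    copy_metadata_props([q, dq], fused)  # "stage" ends up "dequant": the last source wins
    copy_metadata_props([q, dq], fused, mode="keep_existing")  # no-op: "stage" already set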