importonnx from onnximporthelper, checker, TensorProto, numpy_helper from collectionsimportdequeimportwarningsimportnumpy as np def fix_split_attributes(model): 将模型中所有 Split 节点的split属性转换为常量输入兼容 opset13。 graphmodel.graph nodeslist(graph.node)modifiedFalse# 收集所有初始化的常量名称initializer_namesset(init.nameforinitingraph.initializer)new_nodes[]fornodeinnodes:ifnode.op_typeSplit:# 检查是否有 split 属性attrs{attr.name: attrforattrinnode.attribute}ifsplitinattrs: split_attrattrs[split]split_valslist(split_attr.ints)# split 属性是 ints 列表iflen(split_vals)0:# 没有指定 split则保持原样这是合法的表示均匀分割new_nodes.append(node)continue# 创建一个常量节点作为 splits 输入splits_tensor_namenode.name _splits_constifnode.nameelsesplit_const_ str(id(node))# 确保名称唯一counter0whilesplits_tensor_nameininitializer_names: splits_tensor_namef{splits_tensor_name}_{counter}counter1# 将 split_vals 转为 numpy 数组再转为 initializersplits_arraynp.array(split_vals,dtypenp.int64)splits_initializernumpy_helper.from_array(splits_array,namesplits_tensor_name)graph.initializer.append(splits_initializer)initializer_names.add(splits_tensor_name)# 修改节点添加 splits 输入删除 split 属性# 原有输入保持不变新输入加在末尾new_inputslist(node.input)[splits_tensor_name]new_nodehelper.make_node(op_typeSplit,inputsnew_inputs,outputslist(node.output),namenode.name,domainnode.domain)# 复制其他属性排除 splitforattrinnode.attribute:ifattr.name!split:new_node.attribute.append(attr)new_nodes.append(new_node)modifiedTrue print(f修复 Split 节点: {node.name or (unnamed)}, splits{split_vals})continue# 非 Split 或无 split 属性的节点直接保留new_nodes.append(node)ifmodified: graph.ClearField(node)graph.node.extend(new_nodes)# 可选对 value_info 进行轻微清理不必须print(Split 属性转换完成。)else: print(没有发现需要修复的 Split 节点。)returnmodel def prune_model_to_outputs(input_model_path, output_model_path, target_output_names,target_opset21): 裁剪模型并修复可能的 Split 属性不兼容问题。# 1. 加载模型modelonnx.load(input_model_path)graphmodel.graph# 2. 构建节点索引映射node_listlist(graph.node)nodes_by_output{}foridx,nodeinenumerate(node_list):foroutinnode.output: nodes_by_output[out]idx# 3. 反向收集必需节点required_indicesset()queuedeque()foroutput_nameintarget_output_names:ifoutput_nameinnodes_by_output: idxnodes_by_output[output_name]ifidx notinrequired_indices: required_indices.add(idx)queue.append(idx)else: warnings.warn(f输出张量 {output_name} 没有找到生产者节点可能已是模型输入。)whilequeue: node_idxqueue.popleft()nodenode_list[node_idx]forinput_nameinnode.input:ifnot input_name:continueifinput_nameinnodes_by_output: producer_idxnodes_by_output[input_name]ifproducer_idx notinrequired_indices: required_indices.add(producer_idx)queue.append(producer_idx)# 4. 构建新节点列表new_nodes[node_list[i]foriinsorted(required_indices)]# 5. 收集使用到的张量名used_tensor_namesset()fornodeinnew_nodes: used_tensor_names.update(node.input)used_tensor_names.update(node.output)used_tensor_names.update(target_output_names)forinpingraph.input: used_tensor_names.add(inp.name)# 6. 保留相关 initializer 和 value_infonew_initializers[initforinitingraph.initializerifinit.nameinused_tensor_names]new_value_info[viforviingraph.value_infoifvi.nameinused_tensor_names]# 7. 构建新输出 value_infodef get_tensor_info(name):forviingraph.value_info:ifvi.namename:returnviforoutingraph.output:ifout.namename:returnoutreturnhelper.make_tensor_value_info(name, TensorProto.FLOAT,[None, None, None, None])new_outputs[get_tensor_info(name)fornameintarget_output_names]# 8. 创建新图new_graphhelper.make_graph(nodesnew_nodes,namegraph.name _pruned,inputslist(graph.input),outputsnew_outputs,initializernew_initializers,value_infonew_value_info)# 9. 创建新模型new_modelhelper.make_model(new_graph,producer_nameonnx_prune_tool)new_model.opset_import.extend(model.opset_import)# 10. 修复 Split 属性print(正在修复 Split 节点的 split 属性...)new_modelfix_split_attributes(new_model)# 11. 降级 opset如果需要current_opsetNoneforimpinnew_model.opset_import:ifimp.domainor imp.domainai.onnx:current_opsetimp.versionbreakifcurrent_opset is None: new_model.opset_import.append(helper.make_opsetid(, target_opset))current_opsettarget_opset print(f当前模型默认 opset: {current_opset})ifcurrent_opsettarget_opset: print(f尝试将 opset 从 {current_opset} 降级到 {target_opset} ...)try:# 使用 version_converter 降级from onnximportversion_converter converted_modelversion_converter.convert_version(new_model, target_opset)new_modelconverted_model print(f成功降级到 opset {target_opset})except Exception as e: warnings.warn(fopset 自动降级失败: {e}\n将强制修改 opset 版本号可能有风险)forimpinnew_model.opset_import:ifimp.domainor imp.domainai.onnx:imp.versiontarget_opset else: print(f当前 opset {current_opset} {target_opset}无需降级。)# 12. 最终验证try: checker.check_model(new_model)print(模型检查通过。)except Exception as e: print(f模型检查失败: {e})print(精简后模型可能无效但将尝试保存。)# 13. 保存onnx.save(new_model, output_model_path)print(f模型已保存至: {output_model_path})print(f原始节点数: {len(node_list)}, 保留节点数: {len(new_nodes)})if__name____main__:input_onnxyolo11s-obb.onnx# 原始模型output_onnxyolo11s-obb_fixed.onnxtarget_outputs[/model.23/Sigmoid_output_0,/model.23/Concat_3_output_0,/model.23/Concat_2_output_0,/model.23/Concat_1_output_0]prune_model_to_outputs(input_onnx, output_onnx, target_outputs,target_opset21)原始导出的onnx只有一个输出包括一些后处理rk需要搞成4个输出输出的节点在代码中后处理放到外面去不在npu做。坑人的是这个搞成4个输出的网上没有一个人说怎么弄。