打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服务

开通VIP
ONNX模型如何增加或删除节点(node),即修改计算图的方法

有时候我们通过PyTorch导出ONNX模型后,还需要修改ONNX的图结构。怎么修改呢?

下面两个Python实例提供了修改思路:第一个通过删除节点并在原位置插入新节点来替换指定节点;第二个(onnx-utils)按给定的输入/输出名称截取子图,并删除不再被依赖的节点和初始化器(initializer)。
Changing the graph is easier than recreating it with `make_graph`: just use `append`, `remove` and `insert` on the graph's repeated fields (e.g. `graph.node`, `graph.initializer`). 参考:https://github.com/onnx/onnx/issues/2259

import numpy as np
import onnx

# Patch an exported model in place: replace node 584 with a Pad node and
# node 322 (an Equal, per the comment below) with a Not node.
# The node indices and tensor names ('675', '676', '412', '414') belong to the
# model this example was written for; look them up for your own model first.
onnxfile = "model.onnx"  # placeholder: model path (it is overwritten at the end)

onnx_model = onnx.load(onnxfile)
graph = onnx_model.graph

# Since opset 11, Pad takes the pad amounts as an input tensor, so register
# them as an initializer. Pad expects 2 * rank values (begins then ends), so
# 8 zeros corresponds to a 4-D input.
pads = onnx.helper.make_tensor(
    'avg_pads', onnx.TensorProto.INT64, [8], np.zeros(8, dtype=np.int64))
graph.initializer.append(pads)

# Replace node 584: remove it and insert the new node at the same index so the
# node list keeps its (topological) order.
node = graph.node[584]
new_node = onnx.helper.make_node(
    'Pad',
    name='__Pad_584_fixed',
    inputs=['675', 'avg_pads'],
    outputs=['676'],
    mode='constant',
)
graph.node.remove(node)
graph.node.insert(584, new_node)

# Fix Equals (replace with Not)
node = graph.node[322]
new_node = onnx.helper.make_node(
    'Not',
    name='__Not__Equal_322',
    inputs=['412'],
    outputs=['414'],
)
graph.node.remove(node)
graph.node.insert(322, new_node)

# check_model raises if the patched graph is invalid, so nothing is saved then.
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, onnxfile)

来源:https://github.com/saurabh-shandilya/onnx-utils

# ------------------------------------------------
# ONNX Model Editor and Graph Extractor
# License under The MIT License
# Written by Saurabh Shandilya
# -----------------------------------------------
import onnx
from onnx import helper, checker
from onnx import TensorProto
import re
import argparse


def createGraphMemberMap(graph_member_list):
    """Map each member's ``name`` to the member itself.

    Works for any sequence of protobuf messages that have a ``name`` field
    (``graph.node``, ``graph.input``, ``graph.output``, ``graph.initializer``).
    If several members share a name, the last one wins.
    """
    return {member.name: member for member in graph_member_list}


def split_io_list(io_list, new_names_all):
    """Classify existing graph inputs/outputs against the requested names.

    Args:
        io_list: existing ``graph.input`` or ``graph.output`` entries.
        new_names_all: names the edited graph should expose.

    Returns:
        ``[removed_names, retained_names, new_names]``:
        ``removed_names`` are existing entries that were not requested,
        ``retained_names`` are existing entries that were requested, and
        ``new_names`` are requested names that do not exist yet, in the order
        they were requested (duplicates dropped).
    """
    wanted = set(new_names_all)
    removed_names = []
    retained_names = []
    for entry in io_list:
        if entry.name in wanted:
            retained_names.append(entry.name)
        else:
            removed_names.append(entry.name)
    retained = set(retained_names)
    # Keep the caller's order so new graph inputs/outputs are added
    # deterministically rather than in set-iteration order.
    new_names = [n for n in dict.fromkeys(new_names_all) if n not in retained]
    return [removed_names, retained_names, new_names]


def traceDependentNodes(graph, name, node_input_names, node_map, initializer_map):
    """Collect the names of every node and initializer that tensor ``name`` depends on.

    Walks the graph backwards from ``name``. Each node that produces a visited
    tensor is recorded by its node name, and its inputs are visited in turn.
    Initializers reached along the way are recorded by their tensor name,
    because they can be terminal inputs on a path.

    Args:
        graph: the ``onnx.GraphProto`` to walk.
        name: tensor name to start from (typically a graph output).
        node_input_names: list of names already collected. It is extended in
            place, and nodes already listed in it are not walked again.
        node_map: node name -> node, as built by ``createGraphMemberMap``.
        initializer_map: initializer name -> initializer.

    Returns:
        ``node_input_names`` (the same list object) with the new names appended.

    NOTE(review): nodes are tracked by ``name``. Unnamed nodes (``name == ""``)
    all collapse into one entry, so only the first one reached is walked.
    """
    # Index producers once: tensor name -> nodes that output it, in graph
    # order. This replaces a full graph scan per visited tensor.
    producers = {}
    for node in graph.node:
        for out in node.output:
            producers.setdefault(out, []).append(node)

    seen = set(node_input_names)
    # Iterative DFS avoids RecursionError on deep graphs.
    stack = [name]
    while stack:
        tensor = stack.pop()
        for node in producers.get(tensor, ()):
            if node.name in seen:
                continue
            seen.add(node.name)
            node_input_names.append(node.name)
            if node.name in node_map:
                # Reversed so the inputs are popped in their declared order.
                stack.extend(reversed(node_map[node.name].input))
        if tensor in initializer_map and tensor not in seen:
            seen.add(tensor)
            node_input_names.append(tensor)
    return node_input_names


def onnx_edit(input_model, output_model, new_input_node_names, input_shape_map,
              new_output_node_names, output_shape_map, verify):
    """Edit an ONNX model to extract a subgraph bounded by new input/output names.

    Arguments:
        input_model: path of the input onnx model.
        output_model: path of the output onnx model.
        new_input_node_names: list of input names, including the original
            inputs that are to be retained. If the list is empty, the original
            inputs are kept.
        input_shape_map: dict mapping input names to shapes (``None`` is
            treated as empty). Shapes are needed for the model checker to pass.
        new_output_node_names: list of output names, including the original
            outputs that are to be retained. If the list is empty, the original
            outputs are kept.
        output_shape_map: dict mapping output names to shapes (``None`` is
            treated as empty). Shapes are needed for the model checker to pass.
        verify: if true, run ``onnx.checker.check_model`` on the input and
            output models (it raises on an invalid model).

    Every newly created input/output is declared as a FLOAT tensor.

    NOTE(review): a new input whose name equals a node name replaces that
    node, and outputs are traced by tensor name. This only lines up when node
    names are unique, non-empty and equal to their output tensor names;
    confirm this for models from your exporter.
    """
    # parse_nodename_and_shape returns None when no shapes were given.
    input_shape_map = input_shape_map or {}
    output_shape_map = output_shape_map or {}

    # LOAD MODEL AND PREP MAPS
    model = onnx.load(input_model)
    graph = model.graph
    if verify:
        print("input model Errors: ", onnx.checker.check_model(model))

    node_map = createGraphMemberMap(graph.node)
    input_map = createGraphMemberMap(graph.input)
    output_map = createGraphMemberMap(graph.output)
    initializer_map = createGraphMemberMap(graph.initializer)

    if not new_input_node_names:
        new_input_node_names = list(input_map)
    if not new_output_node_names:
        new_output_node_names = list(output_map)

    # MODIFY INPUTS
    # Break the graph based on the new input names.
    removed_names, _, new_names = split_io_list(graph.input, new_input_node_names)
    for name in removed_names:
        if name in input_map:
            graph.input.remove(input_map[name])
    for name in new_names:
        # A new input named like an existing node replaces that node; this is
        # exactly where the graph is cut.
        if name in node_map:
            graph.node.remove(node_map[name])
        new_nv = helper.make_tensor_value_info(
            name, TensorProto.FLOAT, input_shape_map.get(name))
        graph.input.extend([new_nv])
    node_map = createGraphMemberMap(graph.node)
    input_map = createGraphMemberMap(graph.input)

    # MODIFY OUTPUTS
    # Break the graph based on the new output names.
    removed_names, _, new_names = split_io_list(graph.output, new_output_node_names)
    for name in removed_names:
        if name in output_map:
            graph.output.remove(output_map[name])
    for name in new_names:
        new_nv = helper.make_tensor_value_info(
            name, TensorProto.FLOAT, output_shape_map.get(name))
        graph.output.extend([new_nv])

    # CLEANUP NODES
    # Everything the requested outputs depend on is kept; any other node or
    # initializer is dead and gets removed.
    valid_node_names = []
    for output_name in new_output_node_names:
        valid_node_names = traceDependentNodes(
            graph, output_name, valid_node_names, node_map, initializer_map)
    invalid_node_names = (set(node_map) | set(initializer_map)) - set(valid_node_names)
    for name in invalid_node_names:
        if name in node_map:
            graph.node.remove(node_map[name])
        if name in initializer_map:
            graph.initializer.remove(initializer_map[name])
        # Dead initializers may also be listed as graph inputs.
        if name in input_map:
            graph.input.remove(input_map[name])

    # SAVE MODEL
    if verify:
        print("output model Errors: ", onnx.checker.check_model(model))
    onnx.save(model, output_model)


def parse_nodename_and_shape(name):
    """Parse ``name[d0,d1,...],name2,...`` into names and an optional shape map.

    Names usually look like ``name:0``, where ``0`` is the output number. A
    shape may be appended in brackets, e.g. ``name:0[1,28,28,3]``.

    Returns:
        ``(inputs, shapes)``: the list of names in order, and a dict mapping
        names to integer shape lists, or ``None`` when no shape was given.
    """
    inputs = []
    shapes = {}
    name_pattern = r"(?:([\w\d/\-\._:]+)(\[[\-\d,]+\])?),?"
    splits = re.split(name_pattern, name)
    # With two capture groups, re.split yields
    # [sep, name, shape_or_None, sep, name, shape_or_None, ..., sep].
    for i in range(1, len(splits), 3):
        inputs.append(splits[i])
        if splits[i + 1] is not None:
            shapes[splits[i]] = [int(n) for n in splits[i + 1][1:-1].split(",")]
    if not shapes:
        shapes = None
    return inputs, shapes


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="input onnx model")
    parser.add_argument("output", help="output onnx model")
    parser.add_argument("--inputs", help="comma separated model input names appended with shapes, e.g. --inputs <nodename>[1,2,3],<nodename1>[1,2,3] ")
    parser.add_argument("--outputs", help="comma separated model output names appended with shapes, e.g. --outputs <nodename>[1,2,3],<nodename1>[1,2,3] ")
    parser.add_argument('--skipverify', dest='skipverify', action='store_true',
                        help='skip verification of model. Useful if shapes are not known')
    args = parser.parse_args()

    if args.inputs:
        new_input_node_names, input_shape_map = parse_nodename_and_shape(args.inputs)
    else:
        new_input_node_names, input_shape_map = [], {}

    if args.outputs:
        new_output_node_names, output_shape_map = parse_nodename_and_shape(args.outputs)
    else:
        new_output_node_names, output_shape_map = [], {}

    onnx_edit(args.input, args.output, new_input_node_names, input_shape_map,
              new_output_node_names, output_shape_map, not args.skipverify)
来源:https://www.icode9.com/content-1-652701.html
本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
Functions and Relations
万字长文,一文搞懂Torch转换ONNX详细流程
Py之tf2onnx:tf2onnx库的简介、安装、使用方法之详细攻略
02select监听服务端
模型部署入门教程:ONNX 模型的修改与调试
Calculus on Computational Graphs: Backpropagation
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服