From 67e6a6077d40fb64637ebcf5453ae8be375d7e51 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Fri, 18 Jan 2019 18:46:52 +0300 Subject: [PATCH] Create text graphs for Faster-RCNN from TensorFlow with dilated convolutions --- samples/dnn/tf_text_graph_faster_rcnn.py | 52 ++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/samples/dnn/tf_text_graph_faster_rcnn.py b/samples/dnn/tf_text_graph_faster_rcnn.py index e3d0ad0127..e1dfba9fee 100644 --- a/samples/dnn/tf_text_graph_faster_rcnn.py +++ b/samples/dnn/tf_text_graph_faster_rcnn.py @@ -48,10 +48,42 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath): removeIdentity(graph_def) + nodesToKeep = [] def to_remove(name, op): + if name in nodesToKeep: + return False return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \ (name.startswith('CropAndResize') and op != 'CropAndResize') + # Fuse atrous convolutions (with dilations). + nodesMap = {node.name: node for node in graph_def.node} + for node in reversed(graph_def.node): + if node.op == 'BatchToSpaceND': + del node.input[2] + conv = nodesMap[node.input[0]] + spaceToBatchND = nodesMap[conv.input[0]] + + # Extract paddings + stridedSlice = nodesMap[spaceToBatchND.input[2]] + assert(stridedSlice.op == 'StridedSlice') + pack = nodesMap[stridedSlice.input[0]] + assert(pack.op == 'Pack') + + padNodeH = nodesMap[nodesMap[pack.input[0]].input[0]] + padNodeW = nodesMap[nodesMap[pack.input[1]].input[0]] + padH = int(padNodeH.attr['value']['tensor'][0]['int_val'][0]) + padW = int(padNodeW.attr['value']['tensor'][0]['int_val'][0]) + + paddingsNode = NodeDef() + paddingsNode.name = conv.name + '/paddings' + paddingsNode.op = 'Const' + paddingsNode.addAttr('value', [padH, padH, padW, padW]) + graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode) + nodesToKeep.append(paddingsNode.name) + + spaceToBatchND.input[2] = paddingsNode.name + + removeUnusedNodesAndAttrs(to_remove, graph_def) @@ -225,6 +257,26 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath): detectionOut.addAttr('variance_encoded_in_target', True) graph_def.node.extend([detectionOut]) + def getUnconnectedNodes(): + unconnected = [node.name for node in graph_def.node] + for node in graph_def.node: + for inp in node.input: + if inp in unconnected: + unconnected.remove(inp) + return unconnected + + while True: + unconnectedNodes = getUnconnectedNodes() + unconnectedNodes.remove(detectionOut.name) + if not unconnectedNodes: + break + + for name in unconnectedNodes: + for i in range(len(graph_def.node)): + if graph_def.node[i].name == name: + del graph_def.node[i] + break + # Save as text. graph_def.save(outputPath)