From c40fbad12e6d307a694bd55726478e7c4bf10994 Mon Sep 17 00:00:00 2001 From: Lorenzo Lucignano Date: Wed, 20 Nov 2019 10:45:57 +0100 Subject: [PATCH] Samples DNN: tf_text_graph_sd.py loads box coder variance and box NMS params from config file --- samples/dnn/tf_text_graph_ssd.py | 33 +++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/samples/dnn/tf_text_graph_ssd.py b/samples/dnn/tf_text_graph_ssd.py index e6017b227e..905f751557 100644 --- a/samples/dnn/tf_text_graph_ssd.py +++ b/samples/dnn/tf_text_graph_ssd.py @@ -283,6 +283,9 @@ def createSSDGraph(modelPath, configPath, outputPath): # Add layers that generate anchors (bounding boxes proposals). priorBoxes = [] + boxCoder = config['box_coder'][0] + fasterRcnnBoxCoder = boxCoder['faster_rcnn_box_coder'][0] + boxCoderVariance = [1.0/float(fasterRcnnBoxCoder['x_scale'][0]), 1.0/float(fasterRcnnBoxCoder['y_scale'][0]), 1.0/float(fasterRcnnBoxCoder['width_scale'][0]), 1.0/float(fasterRcnnBoxCoder['height_scale'][0])] for i in range(num_layers): priorBox = NodeDef() priorBox.name = 'PriorBox_%d' % i @@ -303,7 +306,7 @@ def createSSDGraph(modelPath, configPath, outputPath): priorBox.addAttr('width', widths) priorBox.addAttr('height', heights) - priorBox.addAttr('variance', [0.1, 0.1, 0.2, 0.2]) + priorBox.addAttr('variance', boxCoderVariance) graph_def.node.extend([priorBox]) priorBoxes.append(priorBox.name) @@ -336,11 +339,31 @@ def createSSDGraph(modelPath, configPath, outputPath): detectionOut.addAttr('num_classes', num_classes + 1) detectionOut.addAttr('share_location', True) detectionOut.addAttr('background_label_id', 0) - detectionOut.addAttr('nms_threshold', 0.6) - detectionOut.addAttr('top_k', 100) + + postProcessing = config['post_processing'][0] + batchNMS = postProcessing['batch_non_max_suppression'][0] + + if 'iou_threshold' in batchNMS: + detectionOut.addAttr('nms_threshold', float(batchNMS['iou_threshold'][0])) + else: + detectionOut.addAttr('nms_threshold', 0.6) + + if 'score_threshold' in batchNMS: + detectionOut.addAttr('confidence_threshold', float(batchNMS['score_threshold'][0])) + else: + detectionOut.addAttr('confidence_threshold', 0.01) + + if 'max_detections_per_class' in batchNMS: + detectionOut.addAttr('top_k', int(batchNMS['max_detections_per_class'][0])) + else: + detectionOut.addAttr('top_k', 100) + + if 'max_total_detections' in batchNMS: + detectionOut.addAttr('keep_top_k', int(batchNMS['max_total_detections'][0])) + else: + detectionOut.addAttr('keep_top_k', 100) + detectionOut.addAttr('code_type', "CENTER_SIZE") - detectionOut.addAttr('keep_top_k', 100) - detectionOut.addAttr('confidence_threshold', 0.01) graph_def.node.extend([detectionOut])