Deconvolution layer from TensorFlow
This commit is contained in:
parent
89172c08a2
commit
54f0616a13
@ -863,6 +863,50 @@ void TFImporter::populateNet(Net dstNet)
|
||||
// one input only
|
||||
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
|
||||
}
|
||||
else if (type == "Conv2DBackpropInput")
|
||||
{
|
||||
// op: "Conv2DBackpropInput"
|
||||
// input: "conv2d_transpose/output_shape"
|
||||
// input: "weights"
|
||||
// input: "input"
|
||||
if (layer.input_size() != 3)
|
||||
CV_Error(Error::StsNotImplemented,
|
||||
"Expected output shape, weights and input nodes");
|
||||
|
||||
layerParams.set("bias_term", false);
|
||||
layerParams.blobs.resize(1);
|
||||
|
||||
StrIntVector next_layers = getNextLayers(net, name, "BiasAdd");
|
||||
if (next_layers.size() == 1)
|
||||
{
|
||||
layerParams.set("bias_term", true);
|
||||
layerParams.blobs.resize(2);
|
||||
|
||||
int weights_layer_index = next_layers[0].second;
|
||||
|
||||
blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]);
|
||||
ExcludeLayer(net, weights_layer_index, 0, false);
|
||||
layers_to_ignore[weights_layer_index] = next_layers[0].first;
|
||||
}
|
||||
|
||||
kernelFromTensor(getConstBlob(layer, value_id, 1), layerParams.blobs[0]);
|
||||
// Swap just numbers of input and output channels.
|
||||
std::swap(layerParams.blobs[0].size[0], layerParams.blobs[0].size[1]);
|
||||
|
||||
const int* kshape = layerParams.blobs[0].size.p;
|
||||
layerParams.set("kernel_h", kshape[2]);
|
||||
layerParams.set("kernel_w", kshape[3]);
|
||||
layerParams.set("num_output", kshape[0]);
|
||||
|
||||
setStrides(layerParams, layer);
|
||||
setPadding(layerParams, layer);
|
||||
|
||||
int id = dstNet.addLayer(name, "Deconvolution", layerParams);
|
||||
layer_id[name] = id;
|
||||
|
||||
// one input only
|
||||
connect(layer_id, dstNet, parsePin(layer.input(2)), id, 0);
|
||||
}
|
||||
else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
|
||||
type == "Relu" || type == "Elu" || type == "Softmax" ||
|
||||
type == "Identity")
|
||||
|
||||
@ -125,4 +125,9 @@ TEST(Test_TensorFlow, pooling)
|
||||
runTensorFlowNet("max_pool_odd_same");
|
||||
}
|
||||
|
||||
TEST(Test_TensorFlow, deconvolution)
|
||||
{
|
||||
runTensorFlowNet("deconvolution");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user