/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef CAFFE2_OPT_FUSION_H_
#define CAFFE2_OPT_FUSION_H_

#include "caffe2/core/workspace.h"
#include "nomnigraph/Representations/NeuralNet.h"

namespace caffe2 {
namespace opt {

using namespace nom;
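
// Fuses Conv + SpatialBN pairs into a single Conv by folding the batch norm
// parameters into the convolution's weight and bias blobs; the Workspace is
// needed to read and rewrite those blobs. (Summary comment only; see the
// implementation for the exact pattern handled.)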
CAFFE2_API void fuseConvBN(repr::NNModule* nn, caffe2::Workspace* ws);

// Generic activation fusion helper.
//
// \tparam OperationT The operator to be fused.
// \tparam ActivationT The activation to be fused.
// \param nn Neural network module to be modified in place.
// \param should_fuse Given the candidate op (e.g. a conv), decide whether it
// should be fused with the subsequent activation (e.g. a relu).
// \param postprocess Functor to postprocess the fused operator node,
// attaching additional attributes if necessary.
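//
// A minimal usage sketch (illustrative only; it assumes repr::Conv and
// repr::Relu are the nomnigraph representation types for the ops being fused):
//
//   fuseActivation<repr::Conv, repr::Relu>(
//       nn,
//       /* should_fuse */ [](const repr::Conv& conv) { return true; },
//       /* postprocess */ [](repr::NNGraph::NodeRef conv_node) {});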
template <typename OperationT, typename ActivationT>
C10_EXPORT void fuseActivation(
    repr::NNModule* nn,
    std::function<bool(const OperationT& conv)> should_fuse,
    std::function<void(repr::NNGraph::NodeRef conv_node)> postprocess) {
  for (auto node_pair : repr::nn::dataIterator<OperationT>(nn->dataFlow)) {
    repr::NNGraph::NodeRef conv_node;
    OperationT* conv;
    std::tie(conv, conv_node) = node_pair;

    // Check topological feasibility: the op must have exactly one output,
    // consumed only by the activation, and the activation itself must have
    // exactly one output.
    auto conv_outputs = repr::nn::getOutputs(conv_node);
    if (conv_outputs.size() != 1) {
      continue;
    }
    auto conv_output = conv_outputs.front();

    auto consumers = repr::nn::getConsumers(conv_output);
    if (consumers.size() != 1) {
      continue;
    }
    if (!repr::nn::is<ActivationT>(consumers.front())) {
      continue;
    }
    auto relu_node = consumers.front();

    auto relu_outputs = repr::nn::getOutputs(relu_node);
    if (relu_outputs.size() != 1) {
      continue;
    }

    // Check feasibility with application-specific logic
    if (!should_fuse(*conv)) {
      continue;
    }
    // Ready to fuse
    auto relu_output = relu_outputs.front();
    auto output_tensor = repr::nn::get<repr::Tensor>(relu_output);
    auto output_node = relu_output;
    auto input_tensor =
        repr::nn::get<repr::Tensor>(repr::nn::getInputs(conv_node).front());

    // Conv cannot be in-place: when the activation's output tensor differs
    // from the conv's input, rewire the conv to produce the activation's
    // output and drop the intermediate tensor; otherwise keep the conv's own
    // output tensor.
    if (output_tensor->getName() != input_tensor->getName()) {
      nn->dataFlow.replaceNode(conv_output, relu_output);
      nn->dataFlow.deleteNode(relu_node);
      nn->dataFlow.deleteNode(conv_output);
    } else {
      nn->dataFlow.replaceNode(relu_output, conv_output);
      output_tensor = repr::nn::get<repr::Tensor>(conv_output);
      output_node = conv_output;
      nn->dataFlow.deleteNode(relu_node);
      nn->dataFlow.deleteNode(relu_output);
    }

    // We may have accidentally made the next op in-place
    // In future iterations of transformations this won't be an issue,
    // but current caffe2 predictor usage requires things like
    // external_input and output to be unchanged.
    bool rectify_inplace = false;
    for (auto& consumer : repr::nn::getConsumers(output_node)) {
      for (auto& consumer_output : repr::nn::getOutputs(consumer)) {
        auto co_name = repr::nn::get<repr::Tensor>(consumer_output)->getName();
        if (co_name == output_tensor->getName()) {
          rectify_inplace = true;
        }
      }
    }
    if (rectify_inplace) {
      auto new_output = nn->dataFlow.createNode(
          make_unique<repr::Tensor>(output_tensor->getName() + "_fusion_fix"));
      nn->dataFlow.replaceNode(output_node, new_output);
    }

    // Application specific logic for postprocessing the conv node
    postprocess(conv_node);
  }
}

} // namespace opt
} // namespace caffe2

#endif // CAFFE2_OPT_FUSION_H_