reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 
#ifndef CAFFE2_OPT_FUSION_H_
#define CAFFE2_OPT_FUSION_H_
 
#include "caffe2/core/workspace.h"
#include "nomnigraph/Representations/NeuralNet.h"
 
namespace caffe2 {
namespace opt {
 
using namespace nom;
 
CAFFE2_API void fuseConvBN(repr::NNModule* nn, caffe2::Workspace* ws);
 
// Generic activation fusion helper.
//
// \tparam OperationT The operator to be fused.
// \tparam ActivationT The activation to be fused.
// \param nn Neural network module to be modified in place
// \param should_fuse Given a conv op, check whether we want to fuse it with
// subsequent relu or not
// \param postprocess Functor to postprocess the conv node,
// attaching additional attributes if necessary
template <typename OperationT, typename ActivationT>
C10_EXPORT void fuseActivation(
    repr::NNModule* nn,
    std::function<bool(const OperationT& conv)> should_fuse,
    std::function<void(repr::NNGraph::NodeRef conv_node)> postprocess) {
  for (auto node_pair : repr::nn::dataIterator<OperationT>(nn->dataFlow)) {
    repr::NNGraph::NodeRef conv_node;
    OperationT* conv;
    std::tie(conv, conv_node) = node_pair;
 
    // Check topological feasibility
    auto conv_outputs = repr::nn::getOutputs(conv_node);
    if (conv_outputs.size() != 1) {
      continue;
    }
    auto conv_output = conv_outputs.front();
 
    auto consumers = repr::nn::getConsumers(conv_output);
    if (consumers.size() != 1) {
      continue;
    }
    if (!repr::nn::is<ActivationT>(consumers.front())) {
      continue;
    }
    auto relu_node = consumers.front();
 
    auto relu_outputs = repr::nn::getOutputs(relu_node);
    if (relu_outputs.size() != 1) {
      continue;
    }
 
    // Check feasibility with application specific logic
    if (!should_fuse(*conv)) {
      continue;
    }
 
    // Ready to fuse
    auto relu_output = relu_outputs.front();
    auto output_tensor = repr::nn::get<repr::Tensor>(relu_output);
    auto output_node = relu_output;
    auto input_tensor =
        repr::nn::get<repr::Tensor>(repr::nn::getInputs(conv_node).front());
 
    // Conv cannot be in-place
    if (output_tensor->getName() != input_tensor->getName()) {
      nn->dataFlow.replaceNode(conv_output, relu_output);
      nn->dataFlow.deleteNode(relu_node);
      nn->dataFlow.deleteNode(conv_output);
    } else {
      nn->dataFlow.replaceNode(relu_output, conv_output);
      output_tensor = repr::nn::get<repr::Tensor>(conv_output);
      output_node = conv_output;
      nn->dataFlow.deleteNode(relu_node);
      nn->dataFlow.deleteNode(relu_output);
    }
 
    // We may have accidentally made the next op in-place
    // In future iterations of transformations this won't be an issue,
    // but current caffe2 predictor usage requires things like
    // external_input and output to be unchanged.
    bool rectify_inplace = false;
    for (auto& consumer : repr::nn::getConsumers(output_node)) {
      for (auto& consumer_output : repr::nn::getOutputs(consumer)) {
        auto co_name = repr::nn::get<repr::Tensor>(consumer_output)->getName();
        if (co_name == output_tensor->getName()) {
          rectify_inplace = true;
        }
      }
    }
    if (rectify_inplace) {
      auto new_output = nn->dataFlow.createNode(
          make_unique<repr::Tensor>(output_tensor->getName() + "_fusion_fix"));
      nn->dataFlow.replaceNode(output_node, new_output);
    }
 
    // Application specific logic for postprocessing the conv node
    postprocess(conv_node);
  }
}
 
} // namespace opt
} // namespace caffe2
 
#endif // CAFFE2_OPT_FUSION_H_