reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-10 c3765bd24fe73747688a0ec2a550f219c9acb384
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 
#pragma once
 
#include <memory>
#include <string>
#include <unordered_map>

#include "caffe2/core/common.h"
#include "caffe2/core/event.h"
#include "caffe2/core/net.h"
#include "caffe2/core/observer.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/timer.h"
#include "caffe2/observers/operator_attaching_net_observer.h"
 
namespace caffe2 {
 
/**
 * This observer displays a description of each operator executed in a network.
 * This includes input and output tensors (name, size, type), arguments, and
 * execution time. This can be used to analyze different performance
 * characteristics.
 * NOTE: Currently this observer only supports synchronous computation.
 **/
 
class ProfileObserver;
class ProfileCounter {
 public:
  explicit ProfileCounter() {}
 
 protected:
  Timer timer_;
  float start_time_ = 0.0f;
  float run_time_ = 0.0f;
};
 
class CAFFE2_API ProfileOperatorObserver final
    : public ProfileCounter,
      public ObserverBase<OperatorBase> {
 public:
  explicit ProfileOperatorObserver(OperatorBase* subject) = delete;
  explicit ProfileOperatorObserver(
      OperatorBase* subject,
      ProfileObserver* netObserver)
      : ObserverBase<OperatorBase>(subject), netObserver_(netObserver) {
    if (subject) {
      net_position_ = subject->net_position();
    }
  }
  explicit ProfileOperatorObserver(
      OperatorBase* subject,
      ProfileObserver* netObserver,
      int net_position,
      int rnn_order)
      : ProfileOperatorObserver(subject, netObserver) {
    net_position_ = net_position;
    rnn_order_ = rnn_order;
  }
 
  std::unique_ptr<ObserverBase<OperatorBase>> rnnCopy(
      OperatorBase* subject,
      int rnn_order) const override;
 
  void Dump() const;
 
  virtual std::string getId() const {
    std::stringstream ss;
    ss << net_position_;
    if (rnn_order_ != OperatorBase::kNoNetPositionSet) {
      ss << "-" << rnn_order_;
    }
    return ss.str();
  }
 
 protected:
  ProfileObserver* netObserver_;
  int net_position_; // Needed because this is not visible in RNN Executor
  int rnn_order_ = OperatorBase::kNoNetPositionSet;
 
 private:
  void Start() override;
  void Stop() override;
};
 
class CAFFE2_API ProfileObserver final : public OperatorAttachingNetObserver<
                                             ProfileOperatorObserver,
                                             ProfileObserver> {
 public:
  explicit ProfileObserver(NetBase* subject)
      : OperatorAttachingNetObserver<ProfileOperatorObserver, ProfileObserver>(
            subject,
            this) {}
 
  void Start() override{};
  void Stop() override{};
 
 private:
  vector<const ProfileOperatorObserver*> operator_observers_;
};
 
} // namespace caffe2