reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#ifndef CAFFE2_SGD_ITER_OP_H_
#define CAFFE2_SGD_ITER_OP_H_
 
#include <limits>
#include <mutex>
 
#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/stats.h"
 
namespace caffe2 {
 
inline void IncrementIter(TensorCPU* output) {
  CAFFE_ENFORCE_EQ(
      output->numel(),
      1,
      "The output of IterOp exists, but not of the right size.");
  int64_t* iter = output->template mutable_data<int64_t>();
  CAFFE_ENFORCE(*iter >= 0, "Previous iteration number is negative.");
  CAFFE_ENFORCE(
      *iter < std::numeric_limits<int64_t>::max(), "Overflow will happen!");
  (*iter)++;
}
 
// IterOp runs an iteration counter. I cannot think of a case where we would
// need to access the iter variable on device, so this will always produce a
// tensor on the CPU side. If the blob already exists and is a tensor<int64_t>
// object, we will simply increment it (this emulates the case when we want to
// resume training). Otherwise we will have the iter starting with 0.
template <class Context>
class IterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
 
  IterOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}
 
  bool RunOnDevice() override {
    if (InputSize() == 0) {
      LOG(INFO) << "[Input size is zero]";
      if (!OperatorBase::OutputIsTensorType(0, CPU)) {
        // This is the first run; set the iter to start with 0.
        LOG(ERROR) << "You are using an old definition of IterOp that will "
                      "be deprecated soon. More specifically, IterOp now "
                      "requires an explicit in-place input and output.";
 
        VLOG(1) << "Initializing iter counter.";
        auto* output = OperatorBase::OutputTensor(
            0, {1}, at::dtype<int64_t>().device(CPU));
        output->template mutable_data<int64_t>()[0] = 0;
      }
    }
    IncrementIter(OperatorBase::Output<Tensor>(0, CPU));
    return true;
  }
};
 
template <class Context>
class AtomicIterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
 
  AtomicIterOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        stats_(std::string("atomic_iter/stats/") + operator_def.input(1)) {}
 
  bool RunOnDevice() override {
    auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(0);
    std::lock_guard<std::mutex> lg(*mutex);
    IncrementIter(OperatorBase::Output<Tensor>(0, CPU));
    CAFFE_EVENT(stats_, num_iter);
    return true;
  }
 
 private:
  struct AtomicIterOpStats {
    CAFFE_STAT_CTOR(AtomicIterOpStats);
    CAFFE_EXPORTED_STAT(num_iter);
  } stats_;
};
 
class MutexSerializer : public BlobSerializerBase {
 public:
  /**
   * Serializes a std::unique_ptr<std::mutex>. Note that this blob has to
   * contain std::unique_ptr<std::mutex>, otherwise this function produces a
   * fatal error.
   */
  void Serialize(
      const void* pointer,
      TypeMeta typeMeta,
      const string& name,
      BlobSerializerBase::SerializationAcceptor acceptor) override;
};
 
class MutexDeserializer : public BlobDeserializerBase {
 public:
  void Deserialize(const BlobProto& proto, Blob* blob) override;
};
 
} // namespace caffe2
 
#endif // CAFFE2_SGD_ITER_OP_H_