reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#pragma once
 
#include "caffe2/core/common_omp.h"
#include "caffe2/core/operator.h"
 
namespace caffe2 {
 
template <typename Context>
void rmsprop_update(
    int N,
    const float* g,
    const float* ms,
    const float* mom,
    float* ng,
    float* nms,
    float* nmom,
    float decay,
    float momentum,
    float epsilon,
    const float* lr,
    Context* context);
 
template <typename T, class Context>
class RmsPropOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  RmsPropOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        decay_(this->template GetSingleArgument<float>("decay", 0.9f)),
        momentum_(this->template GetSingleArgument<float>("momentum", 0.0f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}
  bool RunOnDevice() override {
    CAFFE_ENFORCE(Input(LR).numel() == 1);
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MEAN_SQUARES).numel());
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(OUTPUT_MOMENTUM).numel());
    Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
    Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
    Output(OUTPUT_MEAN_SQUARES)->ResizeLike(Input(MEAN_SQUARES));
    Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
    rmsprop_update<Context>(
        Input(GRAD).numel(),
        Input(GRAD).template data<T>(),
        Input(MEAN_SQUARES).template data<T>(),
        Input(MOMENTUM).template data<T>(),
        Output(OUTPUT_GRAD)->template mutable_data<T>(),
        Output(OUTPUT_MEAN_SQUARES)->template mutable_data<T>(),
        Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
        decay_,
        momentum_,
        epsilon_,
        Input(LR).template data<T>(),
        &context_);
    return true;
  }
 
 protected:
  T decay_{0.9};
  T momentum_{0.0};
  T epsilon_{1e-8};
  INPUT_TAGS(GRAD, MEAN_SQUARES, MOMENTUM, LR);
  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MEAN_SQUARES, OUTPUT_MOMENTUM);
};
}