reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#pragma once
 
#include <cfloat>
#include <cmath>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
 
namespace caffe2 {
 
template <typename Context>
void lr_update(
    int n,
    const float* grad,
    const float* effgrad,
    const float* lr,
    float* nlr,
    float lr_alpha,
    bool normalized_lr_adaption,
    Context* /*context*/) {
  float x = 0;
  float y = 0, z = 0;
  const float kEps = 1e-12f;
  for (auto i = 0; i < n; i++) {
    x += grad[i] * effgrad[i];
    if (normalized_lr_adaption) {
      y += grad[i] * grad[i];
      z += effgrad[i] * effgrad[i];
    }
  }
  if (normalized_lr_adaption) {
    y = fmax(std::sqrt(y), kEps);
    z = fmax(std::sqrt(z), kEps);
    nlr[0] = lr[0] * (1 - lr_alpha * x / (y * z));
  } else {
    nlr[0] = lr[0] - lr_alpha * x;
  }
}
 
template <typename T, class Context>
class LearningRateAdaptionOp final : public Operator<Context> {
 public:
  LearningRateAdaptionOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        lr_alpha_(this->template GetSingleArgument<float>("lr_alpha", 0.01f)),
        normalized_lr_adaption_(this->template GetSingleArgument<bool>(
            "normalized_lr_adaption",
            true)) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;
 
  bool RunOnDevice() override {
    CAFFE_ENFORCE(Input(LR).numel() == 1);
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(EFFGRAD).numel());
    Output(OUTPUT_LR)->ResizeLike(Input(LR));
    lr_update<Context>(
        Input(GRAD).numel(),
        Input(GRAD).template data<T>(),
        Input(EFFGRAD).template data<T>(),
        Input(LR).template data<T>(),
        Output(OUTPUT_LR)->template mutable_data<T>(),
        lr_alpha_,
        normalized_lr_adaption_,
        &context_);
    return true;
  }
 
 protected:
  T lr_alpha_{1e-2};
  bool normalized_lr_adaption_{true};
  INPUT_TAGS(LR, GRAD, EFFGRAD);
  OUTPUT_TAGS(OUTPUT_LR);
};
 
} // namespace caffe2