#pragma once
|
|
#include <cfloat>
|
#include <cmath>
|
#include "caffe2/core/context.h"
|
#include "caffe2/core/operator.h"
|
#include "caffe2/utils/math.h"
|
|
namespace caffe2 {
|
|
template <typename Context>
|
void lr_update(
|
int n,
|
const float* grad,
|
const float* effgrad,
|
const float* lr,
|
float* nlr,
|
float lr_alpha,
|
bool normalized_lr_adaption,
|
Context* /*context*/) {
|
float x = 0;
|
float y = 0, z = 0;
|
const float kEps = 1e-12f;
|
for (auto i = 0; i < n; i++) {
|
x += grad[i] * effgrad[i];
|
if (normalized_lr_adaption) {
|
y += grad[i] * grad[i];
|
z += effgrad[i] * effgrad[i];
|
}
|
}
|
if (normalized_lr_adaption) {
|
y = fmax(std::sqrt(y), kEps);
|
z = fmax(std::sqrt(z), kEps);
|
nlr[0] = lr[0] * (1 - lr_alpha * x / (y * z));
|
} else {
|
nlr[0] = lr[0] - lr_alpha * x;
|
}
|
}
|
|
template <typename T, class Context>
|
class LearningRateAdaptionOp final : public Operator<Context> {
|
public:
|
LearningRateAdaptionOp(const OperatorDef& operator_def, Workspace* ws)
|
: Operator<Context>(operator_def, ws),
|
lr_alpha_(this->template GetSingleArgument<float>("lr_alpha", 0.01f)),
|
normalized_lr_adaption_(this->template GetSingleArgument<bool>(
|
"normalized_lr_adaption",
|
true)) {}
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
bool RunOnDevice() override {
|
CAFFE_ENFORCE(Input(LR).numel() == 1);
|
CAFFE_ENFORCE(Input(GRAD).numel() == Input(EFFGRAD).numel());
|
Output(OUTPUT_LR)->ResizeLike(Input(LR));
|
lr_update<Context>(
|
Input(GRAD).numel(),
|
Input(GRAD).template data<T>(),
|
Input(EFFGRAD).template data<T>(),
|
Input(LR).template data<T>(),
|
Output(OUTPUT_LR)->template mutable_data<T>(),
|
lr_alpha_,
|
normalized_lr_adaption_,
|
&context_);
|
return true;
|
}
|
|
protected:
|
T lr_alpha_{1e-2};
|
bool normalized_lr_adaption_{true};
|
INPUT_TAGS(LR, GRAD, EFFGRAD);
|
OUTPUT_TAGS(OUTPUT_LR);
|
};
|
|
} // namespace caffe2
|