#pragma once

#include "caffe2/core/operator.h"

namespace caffe2 {

// Dense WNGrad update. WNGrad keeps a single scalar state b ("seq_b", h
// below) shared by the whole parameter blob and rescales the step by
// 1 / (b + epsilon):
//   w <- w + lr * g / (b + epsilon)
//   b <- b + ||g||^2 / (b + epsilon)
// Following the Caffe2 convention, lr is expected to be negative for
// descent, hence the addition in the parameter update.
template <typename Context>
void wngrad_update(
    int N,
    const float* w,
    const float* g,
    const float* h,
    float* nw,
    float* nh,
    float epsilon,
    const float* lr,
    Context* /*context*/) {
  // Update all N weights with the shared effective step size lr / (h + eps).
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    nw[i] = w[i] + lr[0] * gi / (h[0] + epsilon);
  }
  // Accumulate the squared gradient norm into the scalar state.
  float nhTmp = 0.0;
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    nhTmp += gi * gi;
  }
  nhTmp /= (h[0] + epsilon);
  nh[0] = h[0] + nhTmp;
}

// Same update, but also writes out the effective learning rate
// lr / (seq_b + epsilon) as a separate output.
template <typename Context>
void wngrad_update_output_effective_lr(
    int N,
    const float* paramIn,
    const float* gradIn,
    const float* seqBIn,
    float* paramOut,
    float* seqBOut,
    float* effectiveLROut,
    float epsilon,
    const float* lr,
    Context* /*context*/) {
  effectiveLROut[0] = lr[0] / (seqBIn[0] + epsilon);
  float seqBTmp = 0.0;
  for (auto i = 0; i < N; ++i) {
    float gi = gradIn[i];
    seqBTmp += gi * gi;
  }
  seqBTmp /= (seqBIn[0] + epsilon);
  seqBOut[0] = seqBIn[0] + seqBTmp;

  for (auto i = 0; i < N; ++i) {
    float grad = gradIn[i];
    paramOut[i] = paramIn[i] + effectiveLROut[0] * grad;
  }
}

// Same as above, but additionally writes out the per-element update
// effective_lr * grad.
template <typename Context>
void wngrad_update_output_effective_lr_and_update(
    int N,
    const float* paramIn,
    const float* gradIn,
    const float* seqBIn,
    float* paramOut,
    float* seqBOut,
    float* effectiveLROut,
    float* updateOut,
    float epsilon,
    const float* lr,
    Context* /*context*/) {
  effectiveLROut[0] = lr[0] / (seqBIn[0] + epsilon);
  float seqBTmp = 0.0;
  for (auto i = 0; i < N; ++i) {
    float gi = gradIn[i];
    seqBTmp += gi * gi;
  }
  seqBTmp /= (seqBIn[0] + epsilon);
  seqBOut[0] = seqBIn[0] + seqBTmp;

  for (auto i = 0; i < N; ++i) {
    float grad = gradIn[i];
    float update = updateOut[i] = effectiveLROut[0] * grad;
    paramOut[i] = paramIn[i] + update;
  }
}

// Dense WNGrad operator. Dispatches to one of the kernels above depending on
// how many outputs were requested: param/seq_b, plus optionally the
// effective LR and the raw update.
template <typename T, class Context>
class WngradOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  WngradOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<T>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    CAFFE_ENFORCE_EQ(
        Input(GRAD).numel(),
        Input(PARAM).numel(),
        "PARAM size: ",
        Input(PARAM).numel(),
        ", GRAD size: ",
        Input(GRAD).numel(),
        ", SEQ_B size: ",
        Input(SEQ_B).numel(),
        ", LR size: ",
        Input(LR).numel());
    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_SEQ_B)->ResizeLike(Input(SEQ_B));
    if (OutputSize() == 2) {
      wngrad_update<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(SEQ_B).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_SEQ_B)->template mutable_data<T>(),
          epsilon_,
          Input(LR).template data<T>(),
          &context_);
    } else if (OutputSize() == 3) {
      Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(Input(SEQ_B));
      wngrad_update_output_effective_lr<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(SEQ_B).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_SEQ_B)->template mutable_data<T>(),
          Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<T>(),
          epsilon_,
          Input(LR).template data<T>(),
          &context_);
    } else {
      Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(Input(SEQ_B));
      Output(OUTPUT_UPDATE)->ResizeLike(Input(GRAD));
      wngrad_update_output_effective_lr_and_update<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(SEQ_B).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_SEQ_B)->template mutable_data<T>(),
          Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<T>(),
          Output(OUTPUT_UPDATE)->template mutable_data<T>(),
          epsilon_,
          Input(LR).template data<T>(),
          &context_);
    }
    return true;
  }

 protected:
  T epsilon_;
  INPUT_TAGS(PARAM, SEQ_B, GRAD, LR);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_SEQ_B, OUTPUT_EFFECTIVE_LR, OUTPUT_UPDATE);
};

// Sparse WNGrad operator: applies the update only to the rows of PARAM
// selected by INDICES, but accumulates the squared norm of the entire GRAD
// blob into the scalar seq_b state.
template <typename T, class Context>
class SparseWngradOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseWngradOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<T>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(SEQ_B).numel(), 1);
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));

    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<T>();
    const auto* paramIn = Input(PARAM).template data<T>();
    const auto* seqBIn = Input(SEQ_B).template data<T>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
    auto* seqBOut = Output(OUTPUT_SEQ_B)->template mutable_data<T>();

    auto n = Input(INDICES).numel();
    if (n == 0) {
      return true;
    }

    // Each index selects one row of block_size contiguous elements.
    auto block_size = Input(GRAD).numel() / n;
    for (auto i = 0; i < n; ++i) {
      auto idx = indices[i];
      if (block_size == 1) {
        float gi = gradIn[i];
        paramOut[idx] = paramIn[idx] + lr[0] * gi / (seqBIn[0] + epsilon_);
      } else {
        auto offsetI = i * block_size;
        auto offsetIdx = idx * block_size;

#ifndef NDEBUG
        CAFFE_ENFORCE_GE(
            Input(PARAM).numel(),
            block_size + offsetIdx,
            this->debug_def().input(PARAM),
            ", out of bound, idx:",
            idx,
            " for input i:",
            i,
            " and block size:",
            block_size);
        CAFFE_ENFORCE_GE(
            Input(GRAD).numel(),
            block_size + offsetI,
            this->debug_def().input(GRAD),
            ", out of bound idx, idx:",
            idx,
            " for input i:",
            i);
#endif

        for (auto j = 0; j < block_size; ++j) {
          float gi = gradIn[offsetI + j];
          paramOut[offsetIdx + j] =
              paramIn[offsetIdx + j] + lr[0] * gi / (seqBIn[0] + epsilon_);
        }
      }
    }
    // Accumulate the squared norm of the full gradient into seq_b. Guard the
    // division with epsilon_ to match the dense kernels and avoid dividing
    // by zero when seq_b is initialized to zero.
    float seqBTmp = 0.0;
    for (auto i = 0; i < Input(GRAD).numel(); ++i) {
      float gi = gradIn[i];
      seqBTmp += gi * gi;
    }
    seqBTmp /= (seqBIn[0] + epsilon_);
    seqBOut[0] = seqBTmp + seqBIn[0];
    return true;
  }

 protected:
  T epsilon_;
  INPUT_TAGS(PARAM, SEQ_B, INDICES, GRAD, LR);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_SEQ_B);
};

} // namespace caffe2
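
// ---------------------------------------------------------------------------
// Illustrative sketch (editor's addition, not part of the upstream operator):
// the WNGrad recurrence that the kernels above implement, with the Caffe2
// plumbing stripped away. The `wngrad_sketch` namespace and function name are
// hypothetical and exist only for illustration. Note the assumed Caffe2
// convention of passing a *negative* learning rate for descent, which is why
// the scaled gradient is added rather than subtracted.
namespace wngrad_sketch {

// One dense WNGrad step on N weights sharing the scalar state b ("seq_b"):
//   w[i] <- w[i] + lr * g[i] / (b + eps)
//   b    <- b + (sum_i g[i]^2) / (b + eps)
inline void toy_wngrad_step(
    int N,
    float* w,
    const float* g,
    float* b,
    float lr,
    float eps = 1e-5f) {
  // Parameter update with the shared effective step size lr / (b + eps).
  for (int i = 0; i < N; ++i) {
    w[i] += lr * g[i] / (*b + eps);
  }
  // State update: grow b by the normalized squared gradient norm, so the
  // effective learning rate shrinks as gradients accumulate.
  float sq = 0.f;
  for (int i = 0; i < N; ++i) {
    sq += g[i] * g[i];
  }
  *b += sq / (*b + eps);
}

} // namespace wngrad_sketch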