#include <cmath>
#include <cstdint>

#include "caffe2/core/operator.h"

namespace caffe2 {

namespace {

// Element-wise Adadelta step over N contiguous floats.
//
// For each i in [0, N):
//   nh[i] = decay * h[i] + (1 - decay) * g[i]^2
//   ng    = sqrt(d[i] + epsilon) / sqrt(nh[i] + epsilon) * g[i]
//   nw[i] = w[i] + lr[0] * ng
//   nd[i] = decay * d[i] + (1 - decay) * ng^2
//
// w/g/h/d are the parameter, gradient, gradient accumulator and delta
// accumulator inputs; nw/nh/nd receive the updated parameter and accumulators.
// Every input element is loaded before the matching output element is stored,
// so an output pointer may be the same pointer as its corresponding input.
// Only lr[0] is read. The step is *added* to the parameter.
// NOTE(review): that presumably means callers pass a negative learning rate
// to descend -- confirm against the operator schema.
// The context argument is unused.
template <typename Context>
void AdadeltaUpdate(
    int N,
    const float* w,
    const float* g,
    const float* h,
    const float* d,
    const float epsilon,
    const float decay,
    const float* lr,
    float* nw,
    float* nh,
    float* nd,
    Context* /*context*/) {
  for (int i = 0; i < N; ++i) {
    float gi = g[i];
    float di = d[i];
    float hi = nh[i] = decay * h[i] + (1.0f - decay) * gi * gi;
    float ng = (std::sqrt(di + epsilon) / std::sqrt(hi + epsilon)) * gi;
    nw[i] = w[i] + lr[0] * ng;
    nd[i] = decay * di + (1.0f - decay) * ng * ng;
  }
}

} // namespace

// Dense Adadelta operator: applies AdadeltaUpdate to every element of PARAM.
//
// Arguments: "epsilon" (default 1e-5, must be >= 0) and "decay" (default 0.95,
// must lie strictly between 0 and 1). Both are validated on every run.
template <class Context>
class AdadeltaOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  AdadeltaOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f),
        OP_SINGLE_ARG(float, "decay", decay_, 0.95f) {}

  // Validates input sizes and attributes, resizes the three outputs to match
  // their corresponding inputs, and runs the update over all elements.
  bool RunOnDevice() override {
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_GRAD).numel());
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_DELTA).numel());
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(PARAM).numel());
    // The kernel dereferences lr[0]; refuse an empty LR tensor instead of
    // reading out of bounds.
    CAFFE_ENFORCE_GE(Input(LR).numel(), 1);
    CAFFE_ENFORCE_GE(epsilon_, 0.0f);
    CAFFE_ENFORCE_GT(decay_, 0.0f);
    CAFFE_ENFORCE_LT(decay_, 1.0f);

    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_MOMENT_GRAD)->ResizeLike(Input(MOMENT_GRAD));
    Output(OUTPUT_MOMENT_DELTA)->ResizeLike(Input(MOMENT_DELTA));
    AdadeltaUpdate<Context>(
        Input(GRAD).numel(),
        Input(PARAM).template data<float>(),
        Input(GRAD).template data<float>(),
        Input(MOMENT_GRAD).template data<float>(),
        Input(MOMENT_DELTA).template data<float>(),
        epsilon_,
        decay_,
        Input(LR).template data<float>(),
        Output(OUTPUT_PARAM)->template mutable_data<float>(),
        Output(OUTPUT_MOMENT_GRAD)->template mutable_data<float>(),
        Output(OUTPUT_MOMENT_DELTA)->template mutable_data<float>(),
        &context_);
    return true;
  }

 protected:
  const float epsilon_;
  const float decay_;
  INPUT_TAGS(PARAM, MOMENT_GRAD, MOMENT_DELTA, GRAD, LR);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_GRAD, OUTPUT_MOMENT_DELTA);
};

// Sparse Adadelta operator: updates only the slices of PARAM (and of both
// accumulators) selected by INDICES, using one slice of GRAD per index.
//
// Takes the same "epsilon" / "decay" arguments as AdadeltaOp.
// NOTE(review): the outputs are never resized here and slices not named in
// INDICES are never written, so this presumably relies on the outputs
// aliasing the corresponding inputs -- confirm against the operator schema.
template <class Context>
class SparseAdadeltaOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseAdadeltaOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f),
        OP_SINGLE_ARG(float, "decay", decay_, 0.95f) {}

  // Validates shapes and attributes, then dispatches on the element type of
  // INDICES (32- or 64-bit integers) to DoRunWithType.
  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_GRAD).numel());
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_DELTA).numel());
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));

    // Enforce domain constraints for attributes
    CAFFE_ENFORCE_GE(epsilon_, 0.0f);
    CAFFE_ENFORCE_GT(decay_, 0.0f);
    CAFFE_ENFORCE_LT(decay_, 1.0f);

    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  // Performs the sparse update with INDICES interpreted as SIndex values.
  // Returns true on success, including when INDICES is empty (nothing to do).
  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<float>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<float>();
    const auto* paramIn = Input(PARAM).template data<float>();
    const auto* momentIn = Input(MOMENT_GRAD).template data<float>();
    const auto* momentDeltaIn = Input(MOMENT_DELTA).template data<float>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<float>();
    auto* momentOut =
        Output(OUTPUT_MOMENT_GRAD)->template mutable_data<float>();
    auto* momentDeltaOut =
        Output(OUTPUT_MOMENT_DELTA)->template mutable_data<float>();

    const auto n = Input(INDICES).numel();

    // Bail out early: the division below would otherwise divide by zero.
    if (n == 0) {
      return true;
    }

    // Number of gradient elements supplied per index.
    const auto block_size = Input(GRAD).numel() / n;
    // 64-bit counter so it cannot overflow before reaching n.
    for (int64_t i = 0; i < n; ++i) {
      auto idx = indices[i];
      if (block_size == 1) {
        // Scalar fast path: same arithmetic as AdadeltaUpdate with N == 1.
        // NOTE(review): idx is not range-checked on this path.
        float gi = gradIn[i];
        float di = momentDeltaIn[idx];
        float hi = momentOut[idx] =
            decay_ * momentIn[idx] + (1.0f - decay_) * gi * gi;
        float ng = (std::sqrt(di + epsilon_) / std::sqrt(hi + epsilon_)) * gi;
        paramOut[idx] = paramIn[idx] + lr[0] * ng;
        momentDeltaOut[idx] = decay_ * di + (1.0f - decay_) * ng * ng;
      } else {
        auto offsetI = i * block_size;
        auto offsetIdx = idx * block_size;

#ifndef NDEBUG
        // Upper-bound checks are compiled into debug builds only.
        CAFFE_ENFORCE_GE(
            Input(PARAM).numel(),
            block_size + offsetIdx,
            this->debug_def().input(PARAM),
            ", out of bound, idx:",
            idx,
            " for input i:",
            i,
            " and block size:",
            block_size);
        CAFFE_ENFORCE_GE(
            Input(GRAD).numel(),
            block_size + offsetI,
            this->debug_def().input(GRAD),
            ", out of bound idx, idx:",
            idx,
            " for input i:",
            i);
#endif
        AdadeltaUpdate<Context>(
            block_size,
            paramIn + offsetIdx,
            gradIn + offsetI,
            momentIn + offsetIdx,
            momentDeltaIn + offsetIdx,
            epsilon_,
            decay_,
            lr,
            paramOut + offsetIdx,
            momentOut + offsetIdx,
            momentDeltaOut + offsetIdx,
            &context_);
      }
    }
    return true;
  }

 protected:
  const float epsilon_;
  const float decay_;
  INPUT_TAGS(PARAM, MOMENT_GRAD, MOMENT_DELTA, INDICES, GRAD, LR);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_GRAD, OUTPUT_MOMENT_DELTA);
};

} // namespace caffe2