#pragma once #include "caffe2/core/common_omp.h" #include "caffe2/core/operator.h" namespace caffe2 { template void rmsprop_update( int N, const float* g, const float* ms, const float* mom, float* ng, float* nms, float* nmom, float decay, float momentum, float epsilon, const float* lr, Context* context); template class RmsPropOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; RmsPropOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), decay_(this->template GetSingleArgument("decay", 0.9f)), momentum_(this->template GetSingleArgument("momentum", 0.0f)), epsilon_(this->template GetSingleArgument("epsilon", 1e-5f)) {} bool RunOnDevice() override { CAFFE_ENFORCE(Input(LR).numel() == 1); CAFFE_ENFORCE(Input(GRAD).numel() == Input(MEAN_SQUARES).numel()); CAFFE_ENFORCE(Input(GRAD).numel() == Input(OUTPUT_MOMENTUM).numel()); Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD)); Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD)); Output(OUTPUT_MEAN_SQUARES)->ResizeLike(Input(MEAN_SQUARES)); Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM)); rmsprop_update( Input(GRAD).numel(), Input(GRAD).template data(), Input(MEAN_SQUARES).template data(), Input(MOMENTUM).template data(), Output(OUTPUT_GRAD)->template mutable_data(), Output(OUTPUT_MEAN_SQUARES)->template mutable_data(), Output(OUTPUT_MOMENTUM)->template mutable_data(), decay_, momentum_, epsilon_, Input(LR).template data(), &context_); return true; } protected: T decay_{0.9}; T momentum_{0.0}; T epsilon_{1e-8}; INPUT_TAGS(GRAD, MEAN_SQUARES, MOMENTUM, LR); OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MEAN_SQUARES, OUTPUT_MOMENTUM); }; }