#pragma once #include "caffe2/core/operator.h" #include "caffe2/core/timer.h" namespace caffe2 { template void fp32_momentum_sgd_update( int N, const float* g, const float* m, float* ng, float* nm, const float* lr, float momentum, bool nesterov, float weight_decay, float* param, Context* /*context*/) {} template class FP32MomentumSGDUpdateOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; FP32MomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), momentum_(this->template GetSingleArgument("momentum", 0.0)), weight_decay_( this->template GetSingleArgument("weight_decay", 0.0)), nesterov_(this->template GetSingleArgument("nesterov", 0)) {} bool RunOnDevice() override { auto device_type = Context::GetDeviceType(); // Iter live on the CPU CAFFE_ENFORCE(OperatorBase::InputIsTensorType(GRAD, device_type)); CAFFE_ENFORCE(OperatorBase::InputIsTensorType(MOMENTUM, device_type)); CAFFE_ENFORCE(Input(LR).size() == 1); CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENTUM).size()); Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD)); Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM)); fp32_momentum_sgd_update( Input(GRAD).size(), Input(GRAD).template data(), Input(MOMENTUM).template data(), Output(OUTPUT_GRAD)->template mutable_data(), Output(OUTPUT_MOMENTUM)->template mutable_data(), Input(LR).template data(), momentum_, nesterov_, weight_decay_, Output(OUTPUT_PARAM)->template mutable_data(), &context_); return true; } protected: float momentum_{0.9}; float weight_decay_{0.0}; bool nesterov_; INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM); OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM); }; }