#pragma once #include "caffe2/core/operator.h" namespace caffe2 { template void momentum_sgd_update( const int N, const float* g, const float* m, float* ng, float* nm, const float* lr, const float momentum, const bool nesterov, float* param, Context* /*context*/) { const float LR = lr[0]; for (auto i = 0; i < N; ++i) { if (!nesterov) { const float adjusted_gradient = LR * g[i] + momentum * m[i]; nm[i] = adjusted_gradient; ng[i] = adjusted_gradient; } else { const float mi = m[i]; const float mi_new = momentum * mi + LR * g[i]; nm[i] = mi_new; ng[i] = (1 + momentum) * mi_new - momentum * mi; } if (param) { param[i] -= ng[i]; } } } template class MomentumSGDOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; MomentumSGDOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), momentum_(this->template GetSingleArgument("momentum", 0.0)), nesterov_(this->template GetSingleArgument("nesterov", 0)) {} bool RunOnDevice() override { auto device_type = Context::GetDeviceType(); // Iter live on the CPU CAFFE_ENFORCE(OperatorBase::InputIsTensorType(GRAD, device_type)); CAFFE_ENFORCE(OperatorBase::InputIsTensorType(MOMENTUM, device_type)); CAFFE_ENFORCE(Input(LR).numel() == 1); CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENTUM).numel()); Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD)); Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM)); momentum_sgd_update( Input(GRAD).numel(), Input(GRAD).template data(), Input(MOMENTUM).template data(), Output(OUTPUT_GRAD)->template mutable_data(), Output(OUTPUT_MOMENTUM)->template mutable_data(), Input(LR).template data(), momentum_, nesterov_, NULL, &context_); return true; } protected: T momentum_{0.9}; bool nesterov_; INPUT_TAGS(GRAD, MOMENTUM, LR); OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM); }; template class MomentumSGDUpdateOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; MomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), momentum_(this->template GetSingleArgument("momentum", 0.0)), nesterov_(this->template GetSingleArgument("nesterov", 0)) {} bool RunOnDevice() override { auto device_type = Context::GetDeviceType(); // Iter live on the CPU CAFFE_ENFORCE(OperatorBase::InputIsTensorType(GRAD, device_type)); CAFFE_ENFORCE(OperatorBase::InputIsTensorType(MOMENTUM, device_type)); CAFFE_ENFORCE_EQ(Input(LR).numel(), 1); CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(MOMENTUM).numel()); Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD)); Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM)); momentum_sgd_update( Input(GRAD).numel(), Input(GRAD).template data(), Input(MOMENTUM).template data(), Output(OUTPUT_GRAD)->template mutable_data(), Output(OUTPUT_MOMENTUM)->template mutable_data(), Input(LR).template data(), momentum_, nesterov_, Output(OUTPUT_PARAM)->template mutable_data(), &context_); return true; } protected: T momentum_{0.9}; bool nesterov_; INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM); OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM); }; template class SparseMomentumSGDUpdateOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; SparseMomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), momentum_(this->template GetSingleArgument("momentum", 0.0)), nesterov_(this->template GetSingleArgument("nesterov", 0)) {} bool RunOnDevice() override { // Resize [potentially] out-of-place blobs Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD)); // Enforce shapes CAFFE_ENFORCE_EQ(Input(LR).numel(), 1); CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENTUM).numel()); CAFFE_ENFORCE_EQ( Input(PARAM).size_from_dim(1), Input(GRAD).size_from_dim(Input(INDICES).dim())); return DispatchHelper>::call( this, Input(INDICES)); } template bool DoRunWithType() { auto block_size = Input(PARAM).numel() / Input(PARAM).size(0); auto n = Input(GRAD).numel() / block_size; const auto* gradIn = Input(GRAD).template data(); const auto* momentumIn = Input(MOMENTUM).template data(); const auto* lr = Input(LR).template data(); // const auto* paramIn = Input(PARAM).template data(); const auto* indices = Input(INDICES).template data(); auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data(); auto* momentumOut = Output(OUTPUT_MOMENTUM)->template mutable_data(); auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data(); for (auto i = 0; i < n; ++i) { auto idx = indices[i]; auto offsetI = i * block_size; auto offsetIdx = idx * block_size; CAFFE_ENFORCE(offsetIdx + block_size <= Input(PARAM).numel()); CAFFE_ENFORCE(offsetI + block_size <= Input(GRAD).numel()); momentum_sgd_update( block_size, gradIn + offsetI, momentumIn + offsetIdx, gradOut + offsetI, momentumOut + offsetIdx, lr, momentum_, nesterov_, paramOut + offsetIdx, &context_); } return true; } protected: T momentum_; bool nesterov_; INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM, INDICES); OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM); }; }