#pragma once #include "caffe2/core/operator.h" #include "caffe2/core/timer.h" namespace caffe2 { template void fp16_momentum_sgd_update( int N, const at::Half* g, const at::Half* m, at::Half* ng, at::Half* nm, const float* lr, float momentum, bool nesterov, float weight_decay, bool fp32_update, at::Half* param, Context* /*context*/) {} template class FP16MomentumSGDUpdateOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; FP16MomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), momentum_(this->template GetSingleArgument("momentum", 0.0)), weight_decay_( this->template GetSingleArgument("weight_decay", 0.0)), nesterov_(this->template GetSingleArgument("nesterov", 0)), // when set, fp32_update will read in the fp16 data but // perform all the compute in fp32 precision. fp32_update_(this->template GetSingleArgument("fp32_update", 0)) {} bool RunOnDevice() override { auto device_type = Context::GetDeviceType(); // Iter live on the CPU CAFFE_ENFORCE(OperatorBase::InputIsTensorType(GRAD, device_type)); CAFFE_ENFORCE(OperatorBase::InputIsTensorType(MOMENTUM, device_type)); CAFFE_ENFORCE(Input(LR).size() == 1); CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENTUM).size()); Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD)); Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM)); fp16_momentum_sgd_update( Input(GRAD).size(), Input(GRAD).template data(), Input(MOMENTUM).template data(), Output(OUTPUT_GRAD)->template mutable_data(), Output(OUTPUT_MOMENTUM)->template mutable_data(), Input(LR).template data(), momentum_, nesterov_, weight_decay_, fp32_update_, Output(OUTPUT_PARAM)->template mutable_data(), &context_); return true; } protected: float momentum_{0.9}; float weight_decay_{0.0}; bool nesterov_; bool fp32_update_; INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM); OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM); }; }