#pragma once

#include <cmath>

#include "caffe2/core/operator.h"

namespace caffe2 {

// Computes the bias-corrected update direction only (ng), along with the new
// moment estimates. g: gradient; m/v: first/second moment inputs; nm/nv: the
// updated moments.
template <typename Context>
void adam_update(
    int N,
    const float* g, const float* m, const float* v,
    float* ng, float* nm, float* nv,
    float beta1, float beta2, float eps_hat, float correction,
    const float* lr,
    Context* /*context*/) {
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    ng[i] = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
  }
}

// Computes the updated parameter (nw) in addition to the new moments.
template <typename Context>
void adam_compute(
    int N,
    const float* w, const float* g, const float* m, const float* v,
    float* nw, float* nm, float* nv,
    float beta1, float beta2, float eps_hat, float correction,
    const float* lr,
    Context* /*context*/) {
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    nw[i] = w[i] + lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
  }
}

// Like adam_compute, but also writes the bias-corrected gradient (ng) so that
// downstream consumers can reuse it.
template <typename Context>
void adam_compute_output_grad(
    int N,
    const float* w, const float* g, const float* m, const float* v,
    float* nw, float* nm, float* nv, float* ng,
    float beta1, float beta2, float eps_hat, float correction,
    const float* lr,
    Context* /*context*/) {
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    float ngi = ng[i] = correction * mi / (std::sqrt(vi) + eps_hat);
    nw[i] = w[i] + lr[0] * ngi;
  }
}
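// All three helpers above implement the same Adam step and differ only in
// which of {updated gradient, updated parameter} they write. With the two
// bias-correction factors folded into the single `correction` scalar that the
// operators below compute, the update is:
//
//   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   correction_t = sqrt(1 - beta2^t) / (1 - beta1^t)
//   w_t = w_{t-1} + lr * correction_t * m_t / (sqrt(v_t) + eps_hat)
//
// Note the `+` in the last line: Caffe2 conventionally feeds a negative
// learning rate, so this is still a descent step. A minimal illustrative call
// (a sketch only; N, t, and the float buffers w, g, m, v, nw, nm, nv are
// assumed to be set up by the caller):
//
//   CPUContext context;
//   float lr = -0.001f;
//   float correction =
//       std::sqrt(1.f - std::pow(0.999f, t)) / (1.f - std::pow(0.9f, t));
//   adam_compute<CPUContext>(
//       N, w, g, m, v, nw, nm, nv,
//       0.9f, 0.999f, 1e-8f, correction, &lr, &context);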
template <typename T, class Context>
class AdamOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  AdamOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        beta1_(this->template GetSingleArgument<float>("beta1", 0.9f)),
        beta2_(this->template GetSingleArgument<float>("beta2", 0.999f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // Iter lives on the CPU.
    CAFFE_ENFORCE(OperatorBase::InputIsTensorType(ITER, CPU));
    CAFFE_ENFORCE(Input(LR).numel() == 1);
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(PARAM).numel());
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_1).numel());
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_2).numel());
    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));
    Output(OUTPUT_MOMENT_2)->ResizeLike(Input(MOMENT_2));

    const auto iter =
        OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0];
    const auto t = iter + 1;
    // Fold both Adam bias corrections into one scalar.
    const auto correction =
        std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
    if (OutputSize() == 3) {
      adam_compute(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(MOMENT_1).template data<T>(),
          Input(MOMENT_2).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_2)->template mutable_data<T>(),
          beta1_, beta2_, epsilon_, correction,
          Input(LR).template data<T>(),
          &context_);
    } else {
      // A fourth output requests the bias-corrected gradient as well.
      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
      adam_compute_output_grad(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(MOMENT_1).template data<T>(),
          Input(MOMENT_2).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_2)->template mutable_data<T>(),
          Output(OUTPUT_GRAD)->template mutable_data<T>(),
          beta1_, beta2_, epsilon_, correction,
          Input(LR).template data<T>(),
          &context_);
    }
    return true;
  }

 protected:
  T beta1_{0.9};
  T beta2_{0.999};
  T epsilon_{1e-8};
  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, GRAD, LR, ITER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};
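// The dense operator above is typically exposed to nets through the usual
// Caffe2 registration in a companion .cc file; a sketch for orientation (the
// CPU case, assuming the standard registration macro):
//
//   REGISTER_CPU_OPERATOR(Adam, AdamOp<float, CPUContext>);
//
// The op then consumes (param, moment1, moment2, grad, lr, iter) and writes
// the updated (param, moment1, moment2); binding a fourth output additionally
// yields the bias-corrected gradient via the else-branch above.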
template <typename T, class Context>
class SparseAdamOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseAdamOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        beta1_(this->template GetSingleArgument<float>("beta1", 0.9f)),
        beta2_(this->template GetSingleArgument<float>("beta2", 0.999f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_2).numel());
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);

    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    const auto iter =
        OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0];
    const auto t = iter + 1;
    const auto correction =
        std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));

    auto block_size = Input(PARAM).numel() / Input(PARAM).size(0);
    auto n = Input(GRAD).numel() / block_size;

    const auto* paramIn = Input(PARAM).template data<T>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<T>();
    const auto* moment1In = Input(MOMENT_1).template data<T>();
    const auto* moment2In = Input(MOMENT_2).template data<T>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
    auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
    auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();

    if (OutputSize() == 3) {
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];
        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          paramOut[idx] = paramIn[idx] +
              lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

#ifndef NDEBUG
          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(), block_size + offsetIdx,
              this->debug_def().input(PARAM), ", out of bound, idx:", idx,
              " for input i:", i, " and block size:", block_size);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(), block_size + offsetI,
              this->debug_def().input(GRAD), ", out of bound idx, idx:", idx,
              " for input i:", i);
#endif

          adam_compute(
              block_size,
              paramIn + offsetIdx,
              gradIn + offsetI,
              moment1In + offsetIdx,
              moment2In + offsetIdx,
              paramOut + offsetIdx,
              moment1Out + offsetIdx,
              moment2Out + offsetIdx,
              beta1_, beta2_, epsilon_, correction, lr,
              &context_);
        }
      }
    } else {
      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
      auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];
        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          float ngi = gradOut[i] =
              correction * mi / (std::sqrt(vi) + epsilon_);
          paramOut[idx] = paramIn[idx] + lr[0] * ngi;
        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

#ifndef NDEBUG
          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(), block_size + offsetIdx,
              this->debug_def().input(PARAM), ", out of bound, idx:", idx,
              " for input i:", i, " and block size:", block_size);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(), block_size + offsetI,
              this->debug_def().input(GRAD), ", out of bound idx, idx:", idx,
              " for input i:", i);
#endif

          adam_compute_output_grad(
              block_size,
              paramIn + offsetIdx,
              gradIn + offsetI,
              moment1In + offsetIdx,
              moment2In + offsetIdx,
              paramOut + offsetIdx,
              moment1Out + offsetIdx,
              moment2Out + offsetIdx,
              gradOut + offsetI,
              beta1_, beta2_, epsilon_, correction, lr,
              &context_);
        }
      }
    }
    return true;
  }

 protected:
  T beta1_;
  T beta2_;
  T epsilon_;
  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};
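// SparseAdamOp performs the same update as AdamOp but touches only the
// parameter rows named by INDICES; all other rows of the outputs keep their
// previous values and moments. For example, with a 4 x 2 PARAM and
// INDICES = {2, 0}, GRAD must be 2 x 2 and only rows 2 and 0 are rewritten.
// The block_size == 1 fast path above is the scalar-per-row case, where the
// per-block helper call would add pointer-arithmetic overhead for nothing.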
template <typename T, class Context>
class RowWiseSparseAdamOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  RowWiseSparseAdamOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        beta1_(this->template GetSingleArgument<float>("beta1", 0.9f)),
        beta2_(this->template GetSingleArgument<float>("beta2", 0.999f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // Enforce shapes. MOMENT_2 holds one value per parameter row.
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(PARAM).sizes()[0], Input(MOMENT_2).numel());
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);

    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    const auto iter =
        OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0];
    const auto t = iter + 1;
    const auto correction =
        std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));

    auto block_size = Input(PARAM).numel() / Input(PARAM).size(0);
    auto n = Input(GRAD).numel() / block_size;

    const auto* paramIn = Input(PARAM).template data<T>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<T>();
    const auto* moment1In = Input(MOMENT_1).template data<T>();
    const auto* moment2In = Input(MOMENT_2).template data<T>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
    auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
    auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();

    if (OutputSize() == 3) {
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];
        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          paramOut[idx] = paramIn[idx] +
              lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

#ifndef NDEBUG
          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(), block_size + offsetIdx,
              this->debug_def().input(PARAM), ", out of bound, idx:", idx,
              " for input i:", i, " and block size:", block_size);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(), block_size + offsetI,
              this->debug_def().input(GRAD), ", out of bound idx, idx:", idx,
              " for input i:", i);
#endif

          const float* w = paramIn + offsetIdx;
          const float* g = gradIn + offsetI;
          const float* m1 = moment1In + offsetIdx;
          const float* m2 = moment2In + idx;
          float* nw = paramOut + offsetIdx;
          float* nm1 = moment1Out + offsetIdx;
          float* nm2 = moment2Out + idx;
          float m2_sum = 0.;
          for (auto j = 0; j < block_size; ++j) {
            float gj = g[j];
            m2_sum += gj * gj;
          }
          // One second moment per row: average the squared gradient over the
          // block.
          float vi = nm2[0] =
              m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
          for (auto j = 0; j < block_size; ++j) {
            float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
            nw[j] = w[j] +
                lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
          }
        }
      }
    } else {
      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
      auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];
        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          float ngi = gradOut[i] =
              correction * mi / (std::sqrt(vi) + epsilon_);
          paramOut[idx] = paramIn[idx] + lr[0] * ngi;
        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

#ifndef NDEBUG
          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(), block_size + offsetIdx,
              this->debug_def().input(PARAM), ", out of bound, idx:", idx,
              " for input i:", i, " and block size:", block_size);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(), block_size + offsetI,
              this->debug_def().input(GRAD), ", out of bound idx, idx:", idx,
              " for input i:", i);
#endif

          const float* w = paramIn + offsetIdx;
          const float* g = gradIn + offsetI;
          const float* m1 = moment1In + offsetIdx;
          const float* m2 = moment2In + idx;
          float* nw = paramOut + offsetIdx;
          float* nm1 = moment1Out + offsetIdx;
          float* nm2 = moment2Out + idx;
          float* ng = gradOut + offsetI;
          float m2_sum = 0.;
          for (auto j = 0; j < block_size; ++j) {
            float gj = g[j];
            m2_sum += gj * gj;
          }
          float vi = nm2[0] =
              m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
          for (auto j = 0; j < block_size; ++j) {
            float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
            float ngi = ng[j] = correction * mi / (std::sqrt(vi) + epsilon_);
            nw[j] = w[j] + lr[0] * ngi;
          }
        }
      }
    }
    return true;
  }

 protected:
  T beta1_;
  T beta2_;
  T epsilon_;
  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};
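// RowWiseSparseAdamOp differs from SparseAdamOp only in the second moment: it
// keeps a single v per parameter row (MOMENT_2 has PARAM.sizes()[0] entries)
// by averaging the squared gradient over each block, trading a slightly
// coarser preconditioner for block_size-times less second-moment storage.
// Per-row computation, for a row r with block size B (illustrative summary of
// the loop above):
//
//   v[r]    = beta2 * v[r] + (1 - beta2) * (sum_j g[j]^2) / B
//   m[r][j] = beta1 * m[r][j] + (1 - beta1) * g[j]
//   w[r][j] += lr * correction * m[r][j] / (sqrt(v[r]) + epsilon)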
} // namespace caffe2