#pragma once #include "caffe2/core/operator.h" namespace caffe2 { template struct FtrlParams { explicit FtrlParams(OperatorBase* op) : alphaInv(1.0 / op->GetSingleArgument("alpha", 0.005f)), beta(op->GetSingleArgument("beta", 1.0f)), lambda1(op->GetSingleArgument("lambda1", 0.001f)), lambda2(op->GetSingleArgument("lambda2", 0.001f)) {} T alphaInv; T beta; T lambda1; T lambda2; }; // TODO(dzhulgakov): implement GPU version if necessary template class FtrlOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; FtrlOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), params_(this) { CAFFE_ENFORCE( !HasArgument("alpha") || ALPHA >= InputSize(), "Cannot specify alpha by both input and argument"); } bool RunOnDevice() override; protected: FtrlParams params_; INPUT_TAGS(VAR, N_Z, GRAD, ALPHA); OUTPUT_TAGS(OUTPUT_VAR, OUTPUT_N_Z); }; template class SparseFtrlOp final : public Operator { public: SparseFtrlOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), params_(this) { CAFFE_ENFORCE( !HasArgument("alpha") || ALPHA >= InputSize(), "Cannot specify alpha by both input and argument"); } bool RunOnDevice() override { // run time learning rate override if (ALPHA < InputSize()) { CAFFE_ENFORCE_EQ(Input(ALPHA).numel(), 1, "alpha should be real-valued"); params_.alphaInv = 1.0 / *(Input(ALPHA).template data()); } // Use run-time polymorphism auto& indices = Input(INDICES); if (indices.template IsType()) { DoRun(); } else if (indices.template IsType()) { DoRun(); } else { LOG(FATAL) << "Unsupported type of INDICES in SparseFtrlOp: " << indices.dtype().name(); } return true; } protected: FtrlParams params_; INPUT_TAGS(VAR, N_Z, INDICES, GRAD, ALPHA); OUTPUT_TAGS(OUTPUT_VAR, OUTPUT_N_Z); private: template void DoRun(); }; }