#pragma once #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" namespace caffe2 { template class CAFFE2_API SparseNormalizeOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit SparseNormalizeOp(Args&&... args) : Operator(std::forward(args)...), use_max_norm_( this->template GetSingleArgument("use_max_norm", true)), norm_(this->template GetSingleArgument("norm", 1.0)) { CAFFE_ENFORCE_GE(norm_, 0, "norm should be bigger than 0"); } bool RunOnDevice() override; template bool DoRunWithType(); protected: bool use_max_norm_; float norm_; INPUT_TAGS(PARAM, INDICES); OUTPUT_TAGS(OUTPUT_PARAM); }; } // namespace caffe2