#ifndef CAFFE2_OPERATORS_CLIP_OP_H_ #define CAFFE2_OPERATORS_CLIP_OP_H_ #include #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" namespace caffe2 { template class ClipOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit ClipOp(Args&&... args) : Operator(std::forward(args)...), min_(std::numeric_limits::lowest()), max_(std::numeric_limits::max()) { if (HasArgument("min")) { min_ = static_cast(this->template GetSingleArgument("min", 0)); } if (HasArgument("max")) { max_ = static_cast(this->template GetSingleArgument("max", 0)); } } bool RunOnDevice() override; protected: T min_; T max_; }; template class ClipGradientOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit ClipGradientOp(Args&&... args) : Operator(std::forward(args)...), min_(std::numeric_limits::lowest()), max_(std::numeric_limits::max()) { if (HasArgument("min")) { min_ = static_cast(this->template GetSingleArgument("min", 0)); } if (HasArgument("max")) { max_ = static_cast(this->template GetSingleArgument("max", 0)); } } bool RunOnDevice() override; protected: T min_; T max_; // Input: Y, dY; Output: dX }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_CLIP_OP_H_