#ifndef CAFFE2_OPERATORS_RELU_N_OP_H_ #define CAFFE2_OPERATORS_RELU_N_OP_H_ #include #include "caffe2/operators/elementwise_ops.h" namespace caffe2 { template struct ReluNFunctor { explicit ReluNFunctor(OperatorBase& op) : n(op.GetSingleArgument("n", 6.0f)) { CAFFE_ENFORCE_GT(n, 0, "n should be greater than 0"); } template bool operator()(const int N, const T* X, T* Y, Context* context) const; const float n; }; template struct ReluNGradientFunctor { explicit ReluNGradientFunctor(OperatorBase& op) : n(op.GetSingleArgument("n", 6.0f)) { CAFFE_ENFORCE_GT(n, 0, "n should be greater than 0"); } template bool Forward( const std::vector& Y_dims, const std::vector& dY_dims, const T* Y, const T* dY, T* dX, Context* context) const; const float n; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_RELU_N_OP_H_