#ifndef CAFFE2_OPERATORS_SWISH_OP_H_ #define CAFFE2_OPERATORS_SWISH_OP_H_ #include "caffe2/operators/elementwise_ops.h" #include "caffe2/utils/math.h" namespace caffe2 { template struct SwishFunctor { template bool operator()(const int N, const T* X, T* Y, Context* context) const; }; template class SwishGradientOp final : public Operator { public: USE_SIMPLE_CTOR_DTOR(SwishGradientOp) USE_OPERATOR_CONTEXT_FUNCTIONS; template bool DoRunWithType(); bool RunOnDevice() override { return DispatchHelper>::call(this, Input(X)); } protected: INPUT_TAGS(X, Y, DY); OUTPUT_TAGS(DX); }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_SWISH_OP_H_