#ifndef CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_ #define CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_ #include "caffe2/core/context_gpu.h" #include "caffe2/core/cudnn_wrappers.h" #include "caffe2/core/operator.h" #include "caffe2/core/tensor.h" #include "caffe2/core/types.h" namespace caffe2 { class CuDNNActivationOpBase : public Operator { public: USE_OPERATOR_FUNCTIONS(CUDAContext); template explicit CuDNNActivationOpBase(Args&&... args) : Operator(std::forward(args)...), cudnn_wrapper_(&context_) { CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_)); CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&act_desc_)); } virtual ~CuDNNActivationOpBase() { CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_)); CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(act_desc_)); } protected: void SetTensorDescriptor( const cudnnDataType_t data_type, const int data_size) { if (data_size != input_size_) { // Since the best performance is obtained when the tesor is HW-packed, we // put X.size() to W. input_size_ = data_size; CUDNN_ENFORCE(cudnnSetTensor4dDescriptor( data_desc_, GetCudnnTensorFormat(StorageOrder::NCHW), data_type, 1, 1, 1, input_size_)); } } CuDNNWrapper cudnn_wrapper_; cudnnTensorDescriptor_t data_desc_; cudnnActivationDescriptor_t act_desc_; int input_size_ = 0; }; template class CuDNNActivationOp final : public CuDNNActivationOpBase { public: USE_OPERATOR_FUNCTIONS(CUDAContext); template explicit CuDNNActivationOp(Args&&... args) : CuDNNActivationOpBase(std::forward(args)...) { CUDNN_ENFORCE(cudnnSetActivationDescriptor( act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0)); } bool RunOnDevice() override { return DispatchHelper>::call(this, Input(0)); } template bool DoRunWithType() { const auto& X = Input(0); auto* Y = Output(0, X.sizes(), at::dtype()); if (X.numel() == 0) { Y->template mutable_data(); return true; } this->SetTensorDescriptor(cudnnTypeWrapper::type, X.numel()); CUDNN_ENFORCE(cudnnActivationForward( this->cudnn_wrapper_.inline_cudnn_handle(), this->act_desc_, cudnnTypeWrapper::kOne(), this->data_desc_, X.template data(), cudnnTypeWrapper::kZero(), this->data_desc_, Y->template mutable_data())); return true; } }; template class CuDNNActivationGradientOp final : public CuDNNActivationOpBase { public: USE_OPERATOR_FUNCTIONS(CUDAContext); template explicit CuDNNActivationGradientOp(Args&&... args) : CuDNNActivationOpBase(std::forward(args)...) { CUDNN_ENFORCE(cudnnSetActivationDescriptor( act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0)); } bool RunOnDevice() override { return DispatchHelper>::call(this, Input(0)); } template bool DoRunWithType() { const auto& Y = Input(0); const auto& dY = Input(1); auto* dX = Output(0, Y.sizes(), at::dtype()); if (Y.numel() == 0) { dX->template mutable_data(); return true; } this->SetTensorDescriptor(cudnnTypeWrapper::type, Y.numel()); CUDNN_ENFORCE(cudnnActivationBackward( this->cudnn_wrapper_.inline_cudnn_handle(), this->act_desc_, cudnnTypeWrapper::kOne(), this->data_desc_, Y.template data(), this->data_desc_, dY.template data(), this->data_desc_, Y.template data(), // Use Y_data as placeholder here. cudnnTypeWrapper::kZero(), this->data_desc_, dX->template mutable_data())); return true; } }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_