#ifndef CAFFE2_OPERATORS_CHANNEL_SHUFFLE_OP_H_ #define CAFFE2_OPERATORS_CHANNEL_SHUFFLE_OP_H_ #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" namespace caffe2 { template class ChannelShuffleOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit ChannelShuffleOp(Args&&... args) : Operator(std::forward(args)...), order_(StringToStorageOrder( this->template GetSingleArgument("order", "NCHW"))), OP_SINGLE_ARG(int, "group", group_, 1) { CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN); } bool RunOnDevice() override { return order_ == StorageOrder::NCHW ? RunOnDeviceWithOrderNCHW() : RunOnDeviceWithOrderNHWC(); } bool RunOnDeviceWithOrderNCHW(); bool RunOnDeviceWithOrderNHWC(); private: const StorageOrder order_; const int group_; }; template class ChannelShuffleGradientOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit ChannelShuffleGradientOp(Args&&... args) : Operator(std::forward(args)...), order_(StringToStorageOrder( this->template GetSingleArgument("order", "NCHW"))), OP_SINGLE_ARG(int, "group", group_, 1) { CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN); } bool RunOnDevice() override { return order_ == StorageOrder::NCHW ? RunOnDeviceWithOrderNCHW() : RunOnDeviceWithOrderNHWC(); } bool RunOnDeviceWithOrderNCHW(); bool RunOnDeviceWithOrderNHWC(); private: const StorageOrder order_; const int group_; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_CHANNEL_SHUFFLE_OP_H_