#pragma once #include "caffe2/core/export_caffe2_op_to_c10.h" #include "caffe2/core/context.h" #include "caffe2/core/operator.h" C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(ResizeNearest); namespace caffe2 { template class ResizeNearestOp final : public Operator { public: template explicit ResizeNearestOp(Args&&... args) : Operator(std::forward(args)...), width_scale_(1), height_scale_(1), order_(StringToStorageOrder( this->template GetSingleArgument("order", "NCHW"))) { if (HasArgument("width_scale")) { width_scale_ = static_cast( this->template GetSingleArgument("width_scale", 1)); } if (HasArgument("height_scale")) { height_scale_ = static_cast( this->template GetSingleArgument("height_scale", 1)); } CAFFE_ENFORCE_GT(width_scale_, 0); CAFFE_ENFORCE_GT(height_scale_, 0); CAFFE_ENFORCE(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC); } USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override; bool RunOnDeviceWithOrderNCHW(); bool RunOnDeviceWithOrderNHWC(); protected: T width_scale_; T height_scale_; StorageOrder order_; }; template class ResizeNearestGradientOp final : public Operator { public: template explicit ResizeNearestGradientOp(Args&&... args) : Operator(std::forward(args)...), width_scale_(1), height_scale_(1), order_(StringToStorageOrder( this->template GetSingleArgument("order", "NCHW"))) { width_scale_ = static_cast( this->template GetSingleArgument("width_scale", 1)); height_scale_ = static_cast( this->template GetSingleArgument("height_scale", 1)); CAFFE_ENFORCE_GT(width_scale_, 0); CAFFE_ENFORCE_GT(height_scale_, 0); CAFFE_ENFORCE(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC); } USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override; bool RunOnDeviceWithOrderNCHW(); bool RunOnDeviceWithOrderNHWC(); protected: T width_scale_; T height_scale_; StorageOrder order_; }; } // namespace caffe2