#ifndef CAFFE2_OPERATORS_ARG_OPS_H_ #define CAFFE2_OPERATORS_ARG_OPS_H_ #include #include #include #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/core/types.h" namespace caffe2 { template class ArgOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit ArgOp(Args&&... args) : Operator(std::forward(args)...), OP_SINGLE_ARG(int, "axis", axis_, -1), OP_SINGLE_ARG(bool, "keepdims", keep_dims_, true) {} bool RunOnDevice() override { return DispatchHelper< TensorTypes>:: call(this, Input(0)); } template bool DoRunWithType() { const auto& X = Input(0); const int ndim = X.dim(); if (axis_ == -1) { axis_ = ndim - 1; } CAFFE_ENFORCE_GE(axis_, 0); CAFFE_ENFORCE_LT(axis_, ndim); const std::vector X_dims(X.sizes().cbegin(), X.sizes().cend()); std::vector Y_dims; Y_dims.reserve(ndim); int prev_size = 1; int next_size = 1; for (int i = 0; i < axis_; ++i) { Y_dims.push_back(X_dims[i]); prev_size *= X_dims[i]; } if (keep_dims_) { Y_dims.push_back(1); } for (int i = axis_ + 1; i < ndim; ++i) { Y_dims.push_back(X_dims[i]); next_size *= X_dims[i]; } auto* Y = Output(0, Y_dims, at::dtype()); const int n = X_dims[axis_]; return reducer_( prev_size, next_size, n, X.template data(), Y->template mutable_data(), &context_); } private: int axis_; const bool keep_dims_; Reducer reducer_{}; }; template struct ArgMaxReducer { template bool operator()( const int prev_size, const int next_size, const int n, const T* X, int64_t* Y, Context* context) const; }; template struct ArgMinReducer { template bool operator()( const int prev_size, const int next_size, const int n, const T* X, int64_t* Y, Context* context) const; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_ARG_OPS_H_