#ifndef CAFFE2_OPERATORS_ARG_OPS_H_
|
#define CAFFE2_OPERATORS_ARG_OPS_H_
|
|
#include <algorithm>
|
#include <iterator>
|
#include <vector>
|
|
#include "caffe2/core/context.h"
|
#include "caffe2/core/operator.h"
|
#include "caffe2/core/types.h"
|
|
namespace caffe2 {
|
|
template <class Context, class Reducer>
|
class ArgOp final : public Operator<Context> {
|
public:
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
template <class... Args>
|
explicit ArgOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...),
|
OP_SINGLE_ARG(int, "axis", axis_, -1),
|
OP_SINGLE_ARG(bool, "keepdims", keep_dims_, true) {}
|
|
bool RunOnDevice() override {
|
return DispatchHelper<
|
TensorTypes<std::int32_t, std::int64_t, float, double>>::
|
call(this, Input(0));
|
}
|
|
template <typename T>
|
bool DoRunWithType() {
|
const auto& X = Input(0);
|
|
const int ndim = X.dim();
|
if (axis_ == -1) {
|
axis_ = ndim - 1;
|
}
|
CAFFE_ENFORCE_GE(axis_, 0);
|
CAFFE_ENFORCE_LT(axis_, ndim);
|
const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
|
std::vector<int64_t> Y_dims;
|
Y_dims.reserve(ndim);
|
int prev_size = 1;
|
int next_size = 1;
|
for (int i = 0; i < axis_; ++i) {
|
Y_dims.push_back(X_dims[i]);
|
prev_size *= X_dims[i];
|
}
|
if (keep_dims_) {
|
Y_dims.push_back(1);
|
}
|
for (int i = axis_ + 1; i < ndim; ++i) {
|
Y_dims.push_back(X_dims[i]);
|
next_size *= X_dims[i];
|
}
|
auto* Y = Output(0, Y_dims, at::dtype<int64_t>());
|
const int n = X_dims[axis_];
|
return reducer_(
|
prev_size,
|
next_size,
|
n,
|
X.template data<T>(),
|
Y->template mutable_data<int64_t>(),
|
&context_);
|
}
|
|
private:
|
int axis_;
|
const bool keep_dims_;
|
Reducer reducer_{};
|
};
|
|
// Reduction functor intended for use as the `Reducer` of ArgOp.
//
// Only the call operator is declared here; its definition lives elsewhere.
// Judging by the name it presumably writes, for each slice, the index of the
// maximum element along the reduced axis -- confirm against the definition.
//
// Arguments, as supplied by ArgOp::DoRunWithType:
//   prev_size - product of the input extents before the reduced axis.
//   next_size - product of the input extents after the reduced axis.
//   n         - extent of the reduced axis.
//   X         - input data (prev_size * n * next_size elements of type T).
//   Y         - int64 output buffer (prev_size * next_size elements).
//   context   - device context of the calling operator.
// The bool result is returned unchanged from ArgOp::DoRunWithType.
template <typename Context>
struct ArgMaxReducer {
  template <typename T>
  bool operator()(
      int prev_size,
      int next_size,
      int n,
      const T* X,
      int64_t* Y,
      Context* context) const;
};
|
|
// Reduction functor intended for use as the `Reducer` of ArgOp.
//
// Only the call operator is declared here; its definition lives elsewhere.
// Judging by the name it presumably writes, for each slice, the index of the
// minimum element along the reduced axis -- confirm against the definition.
//
// Arguments, as supplied by ArgOp::DoRunWithType:
//   prev_size - product of the input extents before the reduced axis.
//   next_size - product of the input extents after the reduced axis.
//   n         - extent of the reduced axis.
//   X         - input data (prev_size * n * next_size elements of type T).
//   Y         - int64 output buffer (prev_size * next_size elements).
//   context   - device context of the calling operator.
// The bool result is returned unchanged from ArgOp::DoRunWithType.
template <typename Context>
struct ArgMinReducer {
  template <typename T>
  bool operator()(
      int prev_size,
      int next_size,
      int n,
      const T* X,
      int64_t* Y,
      Context* context) const;
};
|
|
} // namespace caffe2
|
|
#endif // CAFFE2_OPERATORS_ARG_OPS_H_
|