|
#pragma once
|
|
#include "caffe2/core/context.h"
|
#include "caffe2/core/operator.h"
|
|
namespace caffe2 {
|
|
// RecordShapeOp records the shape of the input tensor to a vector of int. You
|
// mostly don't need this operator explicitly, and it is mostly used in the
|
// autodiff process.
|
template <class Context>
|
class ShapeOp : public Operator<Context> {
|
public:
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
template <class... Args>
|
explicit ShapeOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...),
|
axes_(OperatorBase ::GetRepeatedArgument<int>("axes")) {}
|
|
bool RunOnDevice() override {
|
auto& data = Input(DATA);
|
|
int numDims = data.dim();
|
int numAxes = axes_.size();
|
if (numAxes == 0) {
|
auto* output = Output(0, {numDims}, at::dtype<int64_t>());
|
int64_t* output_data = output->template mutable_data<int64_t>();
|
context_.CopyBytesSameDevice(
|
numDims * sizeof(int64_t), data.sizes().data(), output_data);
|
return true;
|
}
|
|
auto* output = Output(0, {numAxes}, at::dtype<int64_t>());
|
auto src = reinterpret_cast<const char*>(data.sizes().data());
|
auto out = reinterpret_cast<char*>(output->template mutable_data<int64_t>());
|
for (int i = 0; i < numAxes; i++) {
|
auto axis = axes_[i];
|
CAFFE_ENFORCE_LT(axis, numDims, "Axis out of range");
|
CAFFE_ENFORCE_GE(axis, 0, "Each axis should be non-negative");
|
context_.CopyBytesSameDevice(
|
sizeof(int64_t), src + axis * sizeof(int64_t), out);
|
out += sizeof(int64_t);
|
}
|
return true;
|
}
|
|
INPUT_TAGS(DATA);
|
|
private:
|
vector<int> axes_;
|
};
|
|
} // namespace caffe2
|