#ifndef CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
|
#define CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
|
|
#include <algorithm>
|
#include <vector>
|
#include "caffe2/core/context.h"
|
#include "caffe2/core/operator.h"
|
#include "caffe2/utils/math.h"
|
|
namespace caffe2 {
|
|
template <class Context, bool Forward>
|
class PackRNNSequenceOpBase : public Operator<Context> {
|
public:
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
template <class... Args>
|
explicit PackRNNSequenceOpBase(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...) {}
|
|
bool RunOnDevice() override {
|
return DispatchHelper<TensorTypes<int32_t, int64_t, float, double>>::call(
|
this, Input(0));
|
}
|
|
template <typename ValT>
|
bool DoRunWithType() {
|
// The value is copied from the sequence to the pack
|
// if Forward is true, and vice versa
|
int dim_offset = Forward ? 1 : 2;
|
auto& values = Input(0);
|
CAFFE_ENFORCE_GT(values.dim(), dim_offset);
|
|
// block_size is the size for each individual feature
|
int64_t block_size = values.size_from_dim(dim_offset);
|
auto values_vec = values.template data<ValT>();
|
|
auto& lengths = Input(LENGTHS);
|
CAFFE_ENFORCE_EQ(lengths.dim(), 1);
|
const auto cols = lengths.numel();
|
const int32_t* lengths_vec = lengths.template data<int32_t>();
|
// the total number of rows is defined as the max number from lengths
|
// if when the lengths is empty, we set rows = 0 to support zero lengths
|
const auto rows =
|
cols ? *std::max_element(lengths_vec, lengths_vec + cols) : 0;
|
CAFFE_ENFORCE_GE(rows, 0);
|
int length_sum = 0;
|
if (cols > 0) {
|
math::Sum<int, Context>(cols, lengths_vec, &length_sum, &context_);
|
}
|
|
vector<int64_t> shape;
|
// the output shape is rows * cols for the pack,
|
// or length_sum for the sequence
|
if (Forward) {
|
shape.push_back(rows);
|
shape.push_back(cols);
|
} else {
|
shape.push_back(length_sum);
|
}
|
// insert the dim for the feature
|
shape.insert(
|
shape.end(), values.sizes().begin() + dim_offset, values.sizes().end());
|
|
auto* output = Output(OUTPUTVALUE, shape, at::dtype<ValT>());
|
|
auto output_data = output->template mutable_data<ValT>();
|
// initialize output_data with zero, as it is the default value for padding
|
// when certain length is smaller than rows
|
math::Set<ValT, Context>(output->numel(), 0, output_data, &context_);
|
|
int32_t offset = 0;
|
for (int c = 0; c < cols; c++) {
|
for (int r = 0; r < lengths_vec[c]; r++) {
|
auto input_offset = Forward ? (offset + r) : (r * cols + c);
|
auto output_offset = Forward ? (r * cols + c) : (offset + r);
|
context_.CopyItemsSameDevice(
|
values.dtype(),
|
block_size,
|
values_vec + input_offset * block_size,
|
output_data + output_offset * block_size);
|
}
|
offset += lengths_vec[c];
|
}
|
return true;
|
}
|
|
private:
|
INPUT_TAGS(INPUTVALUE, LENGTHS);
|
OUTPUT_TAGS(OUTPUTVALUE);
|
};
|
} // namespace caffe2
|
|
#endif // CAFFE2_OPERATORS_PACK_RNN_SEQUENCE_OP_H_
|