#pragma once
|
|
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
|
|
namespace caffe2 {
|
template <typename F, typename T, class Context>
|
class NGramFromCategoricalOp : public Operator<Context> {
|
public:
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
template <class... Args>
|
explicit NGramFromCategoricalOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...),
|
col_ids_(this->template GetRepeatedArgument<int>("col_ids")),
|
categorical_limits_(
|
this->template GetRepeatedArgument<int>("categorical_limits")),
|
vals_(this->template GetRepeatedArgument<int>("vals")) {
|
col_num_ = col_ids_.size();
|
max_col_id_ = *std::max_element(col_ids_.begin(), col_ids_.end());
|
CAFFE_ENFORCE_EQ(col_num_, categorical_limits_.size());
|
int expected_vals_size = 0;
|
for (auto& l : categorical_limits_) {
|
CAFFE_ENFORCE_GT(l, 0);
|
expected_vals_size += l;
|
}
|
CAFFE_ENFORCE_EQ(expected_vals_size, vals_.size());
|
// compute ngram maps with small end
|
for (auto& j : col_ids_) {
|
CAFFE_ENFORCE_GE(j, 0);
|
ngram_maps_.push_back(std::map<int, int>());
|
}
|
int base = 1;
|
int idx = 0;
|
for (int k = 0; k < col_num_; k++) {
|
int l = categorical_limits_[k];
|
for (int m = 0; m < l; m++) {
|
int v = vals_[idx++];
|
ngram_maps_[k][v] = m * base;
|
}
|
base *= l;
|
}
|
}
|
|
bool RunOnDevice() override {
|
auto& floats = Input(0);
|
auto N = floats.size(0);
|
auto D = floats.size_from_dim(1);
|
const F* floats_data = floats.template data<F>();
|
|
auto* output = Output(0, {N}, at::dtype<T>());
|
auto* output_data = output->template mutable_data<T>();
|
math::Set<T, Context>(output->numel(), 0, output_data, &context_);
|
|
CAFFE_ENFORCE_GT(D, max_col_id_);
|
for (int i = 0; i < N; i++) {
|
for (int k = 0; k < col_num_; k++) {
|
int j = col_ids_[k];
|
int v = round(floats_data[i * D + j]);
|
// for out-of-vocabulary values, we always treat them the same as the
|
// first value specified in vals; if we want to mimic the behavior as
|
// sigrid NGram transform, just push front a random/impossible value at
|
// each segments of vals
|
output_data[i] += ngram_maps_[k].find(v) == ngram_maps_[k].end()
|
? 0
|
: ngram_maps_[k][v];
|
}
|
}
|
return true;
|
}
|
|
private:
|
std::vector<int> col_ids_;
|
std::vector<int> categorical_limits_;
|
std::vector<int> vals_;
|
std::vector<std::map<int, int>> ngram_maps_;
|
int col_num_;
|
int max_col_id_;
|
};
|
} // namespace caffe2
|