#ifndef CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_ #define CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_ #include #include #include "caffe2/core/context.h" #include "caffe2/core/operator.h" namespace caffe2 { template class MergeIdListsOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; USE_SIMPLE_CTOR_DTOR(MergeIdListsOp); template bool DoRunWithType() { auto& first_lengths = Input(0); CAFFE_ENFORCE_EQ(first_lengths.dim(), 1, "LENGTHS should be 1-D"); const auto batch_size = first_lengths.numel(); auto* out_lengths = Output(0, first_lengths.sizes(), at::dtype()); auto* out_lengths_data = out_lengths->template mutable_data(); /** * Loop to figure out how much space to reserve for output * and perform checks. */ auto M = 0; for (size_t i = 0; i < InputSize(); i += 2) { auto& lengths = Input(i); CAFFE_ENFORCE_EQ(lengths.dim(), 1, "LENGTHS should be 1-D"); CAFFE_ENFORCE_EQ(lengths.numel(), batch_size, "LENGTHS should be equal"); auto& values = Input(i + 1); CAFFE_ENFORCE_EQ(values.dim(), 1, "VALUES should be 1-D"); M += values.numel(); } auto* out_values = Output(1, {M}, at::dtype()); T* out_values_data = out_values->template mutable_data(); auto pos = 0; // TODO(badri): Use unordered_set if performance is an issue std::set deduped; std::vector offsets(InputSize(), 0); for (auto sample = 0; sample < batch_size; sample++) { for (size_t i = 0; i < InputSize(); i += 2) { auto& lengths = Input(i); const auto* lengths_data = lengths.template data(); auto& values = Input(i + 1); const T* values_data = values.template data(); const auto length = lengths_data[sample]; for (auto j = offsets[i]; j < offsets[i] + length; j++) { deduped.insert(values_data[j]); } offsets[i] += length; } for (auto val : deduped) { out_values_data[pos++] = val; } out_lengths_data[sample] = deduped.size(); deduped.clear(); } out_values->Resize(pos); return true; } bool RunOnDevice() override { return DispatchHelper>::call(this, Input(1)); } }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_