#ifndef CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_ #define CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_ #include #include "caffe2/core/common_omp.h" #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/core/types.h" #include "caffe2/utils/math.h" #include #include #include namespace caffe2 { template class GatherRangesToDenseOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit GatherRangesToDenseOp(Args&&... args) : Operator(std::forward(args)...), lengths_(this->template GetRepeatedArgument("lengths")) { CAFFE_ENFORCE_GT(lengths_.size(), 0, "There has to be at least one length"); for (auto length : lengths_) { CAFFE_ENFORCE_GT(length, 0, "Each length should be positive"); } } bool RunOnDevice() override { return DispatchHelper>::call( this, this->template Input(RANGES, CPU)); } template bool DoRunWithType() { auto& data = Input(DATA); auto& ranges = Input(RANGES); CAFFE_ENFORCE_EQ(data.dim(), 1, "Data has to be 1-D"); CAFFE_ENFORCE_EQ(ranges.dim(), 3, "Ranges has to be 3-D"); if (InputSize() == 3) { auto& key = Input(KEY); CAFFE_ENFORCE_EQ(key.dim(), 1, "Key has to be 1-D"); CAFFE_ENFORCE( key.dtype().template Match(), "Key has to be type int64_t"); } CAFFE_ENFORCE_EQ( ranges.size(1), lengths_.size(), "Nummber of ranges should match number of lengths"); CAFFE_ENFORCE_EQ( ranges.size(1), OutputSize(), "Nummber of ranges should match number of outputs"); CAFFE_ENFORCE_EQ( ranges.size(2), 2, "Ranges last dimension should be of size 2"); auto* rawData = static_cast(data.raw_data()); auto* rangesData = ranges.template data(); int rangesDataOffset = 0; auto itemsize = data.dtype().itemsize(); auto batchSize = ranges.size(0); vector outputDims{batchSize, 0}; vector outputRawData; for (int i = 0; i < OutputSize(); ++i) { auto* output = Output(i); outputDims[1] = lengths_[i]; output->Resize(outputDims); char* ptr = static_cast(output->raw_mutable_data(data.dtype())); memset(ptr, 0, output->nbytes()); outputRawData.push_back(ptr); } for (int i = 0; i < batchSize; ++i) { for (int j = 0; j < OutputSize(); ++j) { auto rangeStart = rangesData[rangesDataOffset++]; auto rangeLength = rangesData[rangesDataOffset++]; if (rangeLength == 0) { // empty range, will be filled with zeros continue; } CAFFE_ENFORCE_EQ( rangeLength, lengths_[j], "Range lengths missmatch for output #", j); if (InputSize() == 2) { context_.CopyItemsSameDevice( data.dtype(), rangeLength, rawData + rangeStart * itemsize, outputRawData[j] + i * itemsize * lengths_[j]); } else { auto& key = Input(KEY); auto* key_data = key.template data(); vector> buffer; for (int b_i = 0; b_i < rangeLength; ++b_i) { int64_t one_key_item = key_data[rangeStart + b_i]; auto* one_data_item = rawData + (rangeStart + b_i) * itemsize; buffer.emplace_back(one_key_item, one_data_item); } std::sort( buffer.begin(), buffer.end(), [](const std::pair& left, const std::pair& right) { return left.first < right.first; }); for (int b_i = 0; b_i < rangeLength; ++b_i) { // Since this CPU only, directly copy to the destination. std::memcpy( outputRawData[j] + (i * lengths_[j] + b_i) * itemsize, buffer[b_i].second, itemsize); } } } } CAFFE_ENFORCE_EQ(rangesDataOffset, ranges.numel()); return true; } INPUT_TAGS(DATA, RANGES, KEY); private: vector lengths_; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_