#ifndef CAFFE2_OPERATORS_FIND_OP_H_ #define CAFFE2_OPERATORS_FIND_OP_H_ #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include namespace caffe2 { template class FindOp final : public Operator { public: template explicit FindOp(Args&&... args) : Operator(std::forward(args)...), missing_value_( this->template GetSingleArgument("missing_value", -1)) {} USE_OPERATOR_CONTEXT_FUNCTIONS; USE_DISPATCH_HELPER; bool RunOnDevice() { return DispatchHelper>::call(this, Input(0)); } protected: template bool DoRunWithType() { auto& idx = Input(0); auto& needles = Input(1); auto* res_indices = Output(0, needles.sizes(), at::dtype()); const T* idx_data = idx.template data(); const T* needles_data = needles.template data(); T* res_data = res_indices->template mutable_data(); auto idx_size = idx.numel(); // Use an arbitrary cut-off for when to use brute-force // search. For larger needle sizes we first put the // index into a map if (needles.numel() < 16) { // Brute force O(nm) for (int i = 0; i < needles.numel(); i++) { T x = needles_data[i]; T res = static_cast(missing_value_); for (int j = idx_size - 1; j >= 0; j--) { if (idx_data[j] == x) { res = j; break; } } res_data[i] = res; } } else { // O(n + m) std::unordered_map idx_map; for (int j = 0; j < idx_size; j++) { idx_map[idx_data[j]] = j; } for (int i = 0; i < needles.numel(); i++) { T x = needles_data[i]; auto it = idx_map.find(x); res_data[i] = (it == idx_map.end() ? missing_value_ : it->second); } } return true; } protected: int missing_value_; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_FIND_OP_H_