#pragma once #include #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" namespace caffe2 { template class KeySplitOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit KeySplitOp(Args&&... args) : Operator(std::forward(args)...), categorical_limit_( this->template GetSingleArgument("categorical_limit", 0)) { CAFFE_ENFORCE_GT(categorical_limit_, 0); } bool RunOnDevice() override { auto& keys = Input(0); int N = keys.numel(); const T* keys_data = keys.template data(); std::vector counts(categorical_limit_); std::vector eids(categorical_limit_); for (int k = 0; k < categorical_limit_; k++) { counts[k] = 0; } for (int i = 0; i < N; i++) { int k = keys_data[i]; CAFFE_ENFORCE_GT(categorical_limit_, k); CAFFE_ENFORCE_GE(k, 0); counts[k]++; } for (int k = 0; k < categorical_limit_; k++) { auto* eid = Output(k, {counts[k]}, at::dtype()); eids[k] = eid->template mutable_data(); counts[k] = 0; } for (int i = 0; i < N; i++) { int k = keys_data[i]; eids[k][counts[k]++] = i; } return true; } private: int categorical_limit_; }; } // namespace caffe2