#ifndef CAFFE2_OPERATORS_INDEX_HASH_OPS_H_ #define CAFFE2_OPERATORS_INDEX_HASH_OPS_H_ #include "caffe2/core/asan.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" namespace caffe2 { template class IndexHashOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit IndexHashOp(Args&&... args) : Operator(std::forward(args)...), seed_(this->template GetSingleArgument("seed", 0)), modulo_(this->template GetSingleArgument("modulo", 0)) { CAFFE_ENFORCE_GT(modulo_, 0, "MODULO should be > 0"); } bool RunOnDevice() override { return DispatchHelper>::call( this, Input(INDICES)); } template bool DoRunWithType() { auto& indices = Input(INDICES); auto* hashed_indices = Output(HASHED_INDICES, indices.sizes(), at::dtype()); CAFFE_ENFORCE_GE( static_cast(std::numeric_limits::max()), modulo_, "MODULO shouldn't be larger than the numeric limit of the indices"); auto N = indices.numel(); auto* indices_data = indices.template data(); auto* hashed_indices_data = hashed_indices->template mutable_data(); for (auto i = 0; i < N; i++) { hashed_indices_data[i] = hash(indices_data[i]); } return true; } protected: template CAFFE2_NO_SANITIZE("signed-integer-overflow") T hash(T id) { int8_t* bytes = (int8_t*)&id; T hashed = seed_ * 0xDEADBEEF; for (int i = 0; i < sizeof(T) / sizeof(int8_t); i++) { hashed = hashed * 65537 + bytes[i]; } // We want the result of the modulo to be positive. This works under the // assumption that modulo_ > 0 which is enforced in the constructor. auto modHashed = hashed % modulo_; return modHashed >= 0 ? modHashed : modHashed + modulo_; } private: INPUT_TAGS(INDICES); OUTPUT_TAGS(HASHED_INDICES); int64_t seed_; int64_t modulo_; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_INDEX_HASH_OPS_H_