#ifndef CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_ #define CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_ #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/core/tensor.h" #include "caffe2/utils/math.h" namespace caffe2 { template class BisectPercentileOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit BisectPercentileOp(Args&&... args) : Operator(std::forward(args)...), pct_raw_(OperatorBase::GetRepeatedArgument( "percentile_raw", vector{})), pct_mapping_(OperatorBase::GetRepeatedArgument( "percentile_mapping", vector{})), pct_lower_(OperatorBase::GetRepeatedArgument( "percentile_lower", vector{})), pct_upper_(OperatorBase::GetRepeatedArgument( "percentile_upper", vector{})), pct_lens_( OperatorBase::GetRepeatedArgument("lengths", vector{})) { CAFFE_ENFORCE_EQ( pct_raw_.size(), pct_mapping_.size(), "Feature (raw) data and percentile value dimension should match."); CAFFE_ENFORCE_EQ( pct_raw_.size(), pct_lower_.size(), "Feature (raw) data and lower bound dimension should match."); CAFFE_ENFORCE_EQ( pct_raw_.size(), pct_upper_.size(), "Feature (raw) data and upper bound dimension should match."); n_features = pct_lens_.size(); index.reserve(n_features + 1); index[0] = 0; for (int i = 1; i <= n_features; ++i) { index[i] = index[i - 1] + pct_lens_[i - 1]; } CAFFE_ENFORCE_EQ( index[n_features], // The sum of lengths_data pct_raw_.size(), "Sum of lengths should be equal to the total number of percentile " "mapping data samples"); } bool RunOnDevice() override { // Input const auto& raw = Input(RAW); CAFFE_ENFORCE_EQ(raw.dim(), 2); const auto batch_size = raw.size(0); const auto num_features = raw.size(1); CAFFE_ENFORCE_EQ(num_features, pct_lens_.size()); const float* raw_data = raw.template data(); // Output auto* pct = Output(PCT, raw.sizes(), at::dtype()); float* pct_output = pct->template mutable_data(); // Compute percentile for each raw feature value int feature_start_index = 0; int feature_length = 0; int cur_index = 0; for (int i = 0; i < num_features; ++i) { cur_index = i; feature_start_index = index[i]; feature_length = pct_lens_[i]; for (int j = 0; j < batch_size; ++j) { pct_output[cur_index] = compute_percentile( pct_raw_.begin() + feature_start_index, pct_mapping_.begin() + feature_start_index, pct_lower_.begin() + feature_start_index, pct_upper_.begin() + feature_start_index, feature_length, raw_data[cur_index]); cur_index += num_features; } } return true; } protected: INPUT_TAGS(RAW); OUTPUT_TAGS(PCT); private: int n_features; vector pct_raw_; vector pct_mapping_; vector pct_lower_; vector pct_upper_; vector pct_lens_; vector index; vector> fast_pct; const float kEPSILON = 1e-10; int binary_search( const std::vector::iterator& data, int lo, int hi, float val) { int mid; bool low_cond, high_cond; while (lo < hi) { mid = (lo + hi) >> 1; low_cond = (data[mid] <= val); high_cond = (val < data[mid + 1]); if (low_cond && high_cond) { return mid; } else if (!low_cond) { hi = mid - 1; } else { lo = mid + 1; } } return lo; } float compute_percentile( const std::vector::iterator& pct_raw_it, const std::vector::iterator& pct_mapping_it, const std::vector::iterator& pct_lower_it, const std::vector::iterator& pct_upper_it, const int size, const float val) { // Corner cases where no interpolation is needed. if (val < pct_raw_it[0]) { return 0.; } if (val > pct_raw_it[size - 1]) { return 1.; } float result; // Interpolation by binary search const auto k = binary_search(pct_raw_it, 0, size - 1, val); if (pct_raw_it[k] == val) { // Exact match result = pct_mapping_it[k]; } else { // interpolation float w = (val - pct_raw_it[k]) / (pct_raw_it[k + 1] - pct_raw_it[k] + kEPSILON); result = (1 - w) * pct_upper_it[k] + w * pct_lower_it[k + 1]; } return result; } }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_