// Copyright 2004-present Facebook. All Rights Reserved. #ifndef BOX_WITH_NMS_AND_LIMIT_OP_H_ #define BOX_WITH_NMS_AND_LIMIT_OP_H_ #include "caffe2/core/export_caffe2_op_to_c10.h" #include "caffe2/core/context.h" #include "caffe2/core/operator.h" C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(BoxWithNMSLimit) namespace caffe2 { // C++ implementation of function insert_box_results_with_nms_and_limit() template class BoxWithNMSLimitOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit BoxWithNMSLimitOp(Args&&... args) : Operator(std::forward(args)...), score_thres_( this->template GetSingleArgument("score_thresh", 0.05)), nms_thres_(this->template GetSingleArgument("nms", 0.3)), detections_per_im_( this->template GetSingleArgument("detections_per_im", 100)), soft_nms_enabled_( this->template GetSingleArgument("soft_nms_enabled", false)), soft_nms_method_str_(this->template GetSingleArgument( "soft_nms_method", "linear")), soft_nms_sigma_( this->template GetSingleArgument("soft_nms_sigma", 0.5)), soft_nms_min_score_thres_(this->template GetSingleArgument( "soft_nms_min_score_thres", 0.001)), rotated_(this->template GetSingleArgument("rotated", false)), cls_agnostic_bbox_reg_(this->template GetSingleArgument( "cls_agnostic_bbox_reg", false)), input_boxes_include_bg_cls_(this->template GetSingleArgument( "input_boxes_include_bg_cls", true)), output_classes_include_bg_cls_(this->template GetSingleArgument( "output_classes_include_bg_cls", true)), legacy_plus_one_( this->template GetSingleArgument("legacy_plus_one", true)) { CAFFE_ENFORCE( soft_nms_method_str_ == "linear" || soft_nms_method_str_ == "gaussian", "Unexpected soft_nms_method"); soft_nms_method_ = (soft_nms_method_str_ == "linear") ? 1 : 2; // When input `boxes` doesn't inlcude background class, the score will skip // background class and start with foreground classes directly, and put the // background class in the end, i.e. score[:, 0:NUM_CLASSES-1] represents // foreground classes and score[:,NUM_CLASSES] represents background class. input_scores_fg_cls_starting_id_ = (int)input_boxes_include_bg_cls_; } ~BoxWithNMSLimitOp() {} bool RunOnDevice() override; protected: // TEST.SCORE_THRESH float score_thres_ = 0.05; // TEST.NMS float nms_thres_ = 0.3; // TEST.DETECTIONS_PER_IM int detections_per_im_ = 100; // TEST.SOFT_NMS.ENABLED bool soft_nms_enabled_ = false; // TEST.SOFT_NMS.METHOD std::string soft_nms_method_str_ = "linear"; unsigned int soft_nms_method_ = 1; // linear // TEST.SOFT_NMS.SIGMA float soft_nms_sigma_ = 0.5; // Lower-bound on updated scores to discard boxes float soft_nms_min_score_thres_ = 0.001; // Set for RRPN case to handle rotated boxes. Inputs should be in format // [ctr_x, ctr_y, width, height, angle (in degrees)]. bool rotated_{false}; // MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG bool cls_agnostic_bbox_reg_{false}; // Whether input `boxes` includes background class. If true, boxes will have // shape of (N, (num_fg_class+1) * 4or5), otherwise (N, num_fg_class * 4or5) bool input_boxes_include_bg_cls_{true}; // Whether output `classes` includes background class. If true, index 0 will // represent background, and valid outputs start from 1. bool output_classes_include_bg_cls_{true}; // The index where foreground starts in scoures. Eg. if 0 represents // background class then foreground class starts with 1. int input_scores_fg_cls_starting_id_{1}; // The infamous "+ 1" for box width and height dating back to the DPM days bool legacy_plus_one_{true}; // Map a class id (starting with background and then foreground) from (0, 1, // ..., NUM_FG_CLASSES) to it's matching value in box inline int get_box_cls_index(int bg_fg_cls_id) { if (cls_agnostic_bbox_reg_) { return 0; } else if (!input_boxes_include_bg_cls_) { return bg_fg_cls_id - 1; } else { return bg_fg_cls_id; } } // Map a class id (starting with background and then foreground) from (0, 1, // ..., NUM_FG_CLASSES) to it's matching value in score inline int get_score_cls_index(int bg_fg_cls_id) { return bg_fg_cls_id - 1 + input_scores_fg_cls_starting_id_; } }; } // namespace caffe2 #endif // BOX_WITH_NMS_AND_LIMIT_OP_H_