#ifndef CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_ #define CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_ #include #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/core/tensor.h" #include "caffe2/utils/math.h" namespace caffe2 { template class ClipTensorByScalingOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; ClipTensorByScalingOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws) { threshold_ = this->template GetSingleArgument("threshold", 0.0); CAFFE_ENFORCE_GT(threshold_, 0, "Threshold must be greater than 0"); } bool RunOnDevice() override { const auto& input_tensor = Input(0); CAFFE_ENFORCE_GT(input_tensor.numel(), 0); const auto& val = Input(1); CAFFE_ENFORCE_EQ(val.numel(), 1); const auto* input_tensor_data = input_tensor.template data(); const auto* val_data = val.template data(); auto* clipped = Output(0, input_tensor.sizes(), at::dtype()); float* clipped_tensor_data = clipped->template mutable_data(); if (InputSize() > 2) { const auto& additional_threshold = Input(2); CAFFE_ENFORCE_EQ(additional_threshold.numel(), 1); threshold_ *= *(additional_threshold.template data()); } if (*val_data > threshold_) { float ratio = threshold_ / *val_data; math::Scale( clipped->numel(), ratio, input_tensor_data, clipped_tensor_data, &context_); } else { if (input_tensor_data != clipped_tensor_data) { clipped->CopyFrom(input_tensor, /*async*/ true); } } return true; } private: float threshold_; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_