#ifndef CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_ #define CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_ #include "caffe2/core/context.h" #include "caffe2/core/export_caffe2_op_to_c10.h" #include "caffe2/core/operator.h" C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(PiecewiseLinearTransform); namespace caffe2 { template class PiecewiseLinearTransformOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit PiecewiseLinearTransformOp(Args&&... args) : Operator(std::forward(args)...) { binary_ = this->template GetSingleArgument("binary", false); // Retrieve transform params (i.e., the linear functions). bounds_from_arg_ = this->template GetRepeatedArgument("bounds"); slopes_from_arg_ = this->template GetRepeatedArgument("slopes"); intercepts_from_arg_ = this->template GetRepeatedArgument("intercepts"); transform_param_from_arg_ = CheckTransParamFromArg(); } bool RunOnDevice() override { return binary_ ? TransformBinary() : TransformGeneral(); } private: // num_func_per_group is the number of pieces of linear functions of // each group. // num_group: The number of groups of linear functions. Each group is for // transforming one column of predictions. void InferNumFunctionsPerGroup( const int64_t num_bounds, const int64_t num_slopes, const int64_t num_intercepts, int64_t* num_func_per_group, int64_t* num_group) { CAFFE_ENFORCE_EQ(num_slopes, num_intercepts); // This is based on the facts: // 1. in each group, the num of bounds minus the num of slopes is 1; // 2. each group has the same number of pieces. *num_group = num_bounds - num_slopes; CAFFE_ENFORCE_GT(*num_group, 0); if (binary_) { CAFFE_ENFORCE_EQ(*num_group, 1); } *num_func_per_group = num_slopes / *num_group; CAFFE_ENFORCE_GT(*num_func_per_group, 0); CAFFE_ENFORCE_EQ(num_slopes % *num_group, 0); } bool CheckBoundsSorted( const T* bounds, const int64_t num_bounds_per_group, const int64_t num_group) { const T* start = bounds; for (int64_t i = 0; i < num_group; i++) { if (!std::is_sorted(start, start + num_bounds_per_group)) { return false; } start += num_bounds_per_group; } return true; } // Returns true if the transform params from arg are valid. // Otherwise, we will assume the transform params will pass from Input blobs. bool CheckTransParamFromArg() { int good_param = 0; good_param += bounds_from_arg_.size() > 0; good_param += slopes_from_arg_.size() > 0; good_param += intercepts_from_arg_.size() > 0; CAFFE_ENFORCE( good_param == 0 || good_param == 3, "bounds, slopes, intercepts must be all set or all not set"); if (good_param == 3) { int64_t num_func_per_group; int64_t num_group; InferNumFunctionsPerGroup( bounds_from_arg_.size(), slopes_from_arg_.size(), intercepts_from_arg_.size(), &num_func_per_group, &num_group); CAFFE_ENFORCE( CheckBoundsSorted( bounds_from_arg_.data(), num_func_per_group + 1, num_group), "bounds must be sorted for each group"); } return good_param == 3; } void setUpTensors(int64_t& num_func_per_group, int64_t& num_group, int64_t M); void GetTransParamData( const T** bounds, const T** slopes, const T** intercepts, int64_t* num_func_per_group, int64_t* num_group) { int64_t num_bounds; int64_t num_slopes; int64_t num_intercepts; if (transform_param_from_arg_) { CAFFE_ENFORCE_EQ(InputSize(), 1); *bounds = bounds_from_arg_.data(); *slopes = slopes_from_arg_.data(); *intercepts = intercepts_from_arg_.data(); num_bounds = bounds_from_arg_.size(); num_slopes = slopes_from_arg_.size(); num_intercepts = intercepts_from_arg_.size(); } else { CAFFE_ENFORCE_EQ(InputSize(), 4); auto& bounds_input = Input(BOUNDS); auto& slopes_input = Input(SLOPES); auto& intercepts_input = Input(INTERCEPTS); *bounds = bounds_input.template data(); *slopes = slopes_input.template data(); *intercepts = intercepts_input.template data(); num_bounds = bounds_input.numel(); num_slopes = slopes_input.numel(); num_intercepts = intercepts_input.numel(); } InferNumFunctionsPerGroup( num_bounds, num_slopes, num_intercepts, num_func_per_group, num_group); } bool TransformGeneral() { auto& X = Input(0); CAFFE_ENFORCE_EQ(X.dim(), 2); int64_t N = X.dim32(0); int64_t M = X.dim32(1); auto* Y = Output(0, X.sizes(), at::dtype()); const auto* Xdata = X.template data(); T* Ydata = Y->template mutable_data(); const T* bounds; const T* slopes; const T* intercepts; int64_t num_func_per_group; int64_t num_group; GetTransParamData( &bounds, &slopes, &intercepts, &num_func_per_group, &num_group); CAFFE_ENFORCE_EQ(num_group, M); for (int64_t j = 0; j < M; ++j) { const T* bounds_group = bounds + j * (num_func_per_group + 1); const T* slopes_group = slopes + j * num_func_per_group; const T* intercepts_group = intercepts + j * num_func_per_group; for (int64_t i = 0; i < N; ++i) { Ydata[i * M + j] = PiecewiseLinearTransform( Xdata[i * M + j], bounds_group, slopes_group, intercepts_group, num_func_per_group); } } return true; } bool TransformBinary() { auto& X = Input(PREDICTIONS); CAFFE_ENFORCE(X.dim() == 1 || X.dim() == 2); int64_t N = X.dim32(0); int64_t M = X.dim() == 2 ? X.dim32(1) : 1; CAFFE_ENFORCE( M == 1 || M == 2, "If binary is set to true, the input must be Nx2 or Nx1 tensor"); auto* Y = Output(0, X.sizes(), at::dtype()); const auto* Xdata = X.template data(); T* Ydata = Y->template mutable_data(); const T* bounds; const T* slopes; const T* intercepts; int64_t num_func_per_group; int64_t num_group; GetTransParamData( &bounds, &slopes, &intercepts, &num_func_per_group, &num_group); CAFFE_ENFORCE_EQ(num_group, 1); if (M == 1) { for (int64_t i = 0; i < N; ++i) { Ydata[i] = PiecewiseLinearTransform( Xdata[i], bounds, slopes, intercepts, num_func_per_group); } } else { for (int64_t i = 0; i < N; ++i) { Ydata[i * M + 1] = PiecewiseLinearTransform( Xdata[i * M + 1], bounds, slopes, intercepts, num_func_per_group); Ydata[i * M] = 1.0f - Ydata[i * M + 1]; } } return true; } T PiecewiseLinearTransform( const T x, const T* bounds, const T* slopes, const T* intercepts, const int64_t num_func_per_group) { T y = 0; // deal with samples out of bounds // make it the same as the upper/lower bound value if (x <= bounds[0]) { y = slopes[0] * bounds[0] + intercepts[0]; } else if (x >= bounds[num_func_per_group]) { y = slopes[num_func_per_group - 1] * bounds[num_func_per_group] + intercepts[num_func_per_group - 1]; } else { auto low_bound = std::lower_bound(bounds, bounds + num_func_per_group + 1, x); int bounds_idx = low_bound - bounds - 1; // compute the piecewise linear transformation as Y y = slopes[bounds_idx] * x + intercepts[bounds_idx]; } return y; } private: bool binary_; vector bounds_from_arg_; vector slopes_from_arg_; vector intercepts_from_arg_; Tensor bounds_device_{Context::GetDeviceType()}; Tensor intercepts_device_{Context::GetDeviceType()}; Tensor slopes_device_{Context::GetDeviceType()}; bool gpu_copied_ = false; // If true, the piecewise linear functions are passed through args, // otherwise, they are passed through Input blobs. bool transform_param_from_arg_; INPUT_TAGS(PREDICTIONS, BOUNDS, SLOPES, INTERCEPTS); }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_