#ifndef CAFFE2_OPERATORS_REDUCE_OPS_H_
#define CAFFE2_OPERATORS_REDUCE_OPS_H_

#include <algorithm>
#include <functional>
#include <numeric>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <typename InputTypes, class Context, class Reducer>
class ReduceOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit ReduceOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        axes_(this->template GetRepeatedArgument<int>("axes")),
        OP_SINGLE_ARG(bool, "keepdims", keep_dims_, true) {}

  bool RunOnDevice() override {
    return DispatchHelper<InputTypes>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    const auto& X = Input(0);
    const int ndim = X.dim();
    const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
    if (axes_.empty()) {
      axes_.resize(ndim);
      std::iota(axes_.begin(), axes_.end(), 0);
    } else {
      for (auto& axis : axes_) {
        axis = X.canonical_axis_index(axis);
      }
      std::sort(axes_.begin(), axes_.end());
      CAFFE_ENFORCE_GE(axes_.front(), 0, "Axes ids must be non-negative.");
      CAFFE_ENFORCE_LT(
          axes_.back(),
          ndim,
          "Axes ids must be smaller than the dimensions of input.");
    }
    std::vector<int64_t> output_dims;
    output_dims.reserve(ndim);
    std::size_t cur_axis = 0;
    for (int i = 0; i < ndim; ++i) {
      if (cur_axis < axes_.size() && i == axes_[cur_axis]) {
        if (keep_dims_) {
          output_dims.push_back(1);
        }
        ++cur_axis;
      } else {
        output_dims.push_back(X_dims[i]);
      }
    }
    auto* Y = Output(0, output_dims, at::dtype<T>());
    std::vector<int> Y_dims = X_dims;
    for (const int axis : axes_) {
      Y_dims[axis] = 1;
    }
    return reducer_.template Forward<T>(
        X_dims,
        Y_dims,
        X.template data<T>(),
        Y->template mutable_data<T>(),
        &context_);
  }

 private:
  std::vector<int> axes_;
  const int keep_dims_;
  const Reducer reducer_{};
};
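// Worked shape example (illustrative; not from the original header): for an
// input X of shape (2, 3, 4) with axes = {0, 2} and keepdims = true,
// DoRunWithType produces output_dims = (1, 3, 1) and Y_dims = {1, 3, 1}.
// With keepdims = false the output shape shrinks to (3), while Y_dims, which
// drives the reduction kernel, stays {1, 3, 1}.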
template <typename InputTypes, class Context, class Reducer>
class ReduceGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit ReduceGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        axes_(this->template GetRepeatedArgument<int>("axes")) {}

  bool RunOnDevice() override {
    return DispatchHelper<InputTypes>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    const auto& dY = Input(0);
    const auto& X = Input(1);
    const auto& Y = Input(2);
    const int ndim = X.dim();
    if (axes_.empty()) {
      axes_.resize(ndim);
      std::iota(axes_.begin(), axes_.end(), 0);
    } else {
      for (auto& axis : axes_) {
        axis = X.canonical_axis_index(axis);
      }
      std::sort(axes_.begin(), axes_.end());
      CAFFE_ENFORCE_GE(axes_.front(), 0, "Axes ids must be non-negative.");
      CAFFE_ENFORCE_LT(
          axes_.back(),
          ndim,
          "Axes ids must be smaller than the dimensions of input.");
    }
    const std::vector<int> dX_dims(X.sizes().cbegin(), X.sizes().cend());
    std::vector<int> dY_dims = dX_dims;
    for (const int axis : axes_) {
      dY_dims[axis] = 1;
    }
    auto* dX = Output(0, X.sizes(), at::dtype<T>());
    return reducer_.template Backward<T>(
        dY_dims,
        dX_dims,
        dY.template data<T>(),
        X.template data<T>(),
        Y.template data<T>(),
        dX->template mutable_data<T>(),
        &context_);
  }

 private:
  std::vector<int> axes_;
  const Reducer reducer_{};
};

template <class Context>
struct MinReducer {
  template <typename T>
  bool Forward(
      const std::vector<int>& X_dims,
      const std::vector<int>& Y_dims,
      const T* X_data,
      T* Y_data,
      Context* context) const {
    math::ReduceMin<T, Context>(
        X_dims.size(),
        X_dims.data(),
        Y_dims.data(),
        T(1),
        X_data,
        Y_data,
        context);
    return true;
  }

  template <typename T>
  bool Backward(
      const std::vector<int>& dY_dims,
      const std::vector<int>& dX_dims,
      const T* dY_data,
      const T* X_data,
      const T* Y_data,
      T* dX_data,
      Context* context) const;
};

template <class Context>
struct MaxReducer {
  template <typename T>
  bool Forward(
      const std::vector<int>& X_dims,
      const std::vector<int>& Y_dims,
      const T* X_data,
      T* Y_data,
      Context* context) const {
    math::ReduceMax<T, Context>(
        X_dims.size(),
        X_dims.data(),
        Y_dims.data(),
        T(1),
        X_data,
        Y_data,
        context);
    return true;
  }

  template <typename T>
  bool Backward(
      const std::vector<int>& dY_dims,
      const std::vector<int>& dX_dims,
      const T* dY_data,
      const T* X_data,
      const T* Y_data,
      T* dX_data,
      Context* context) const;
};

template <class Context>
struct SumReducer {
  template <typename T>
  bool Forward(
      const std::vector<int>& X_dims,
      const std::vector<int>& Y_dims,
      const T* X_data,
      T* Y_data,
      Context* context) const {
    math::ReduceSum<T, Context>(
        X_dims.size(),
        X_dims.data(),
        Y_dims.data(),
        T(1),
        X_data,
        Y_data,
        context);
    return true;
  }

  template <typename T>
  bool Backward(
      const std::vector<int>& dY_dims,
      const std::vector<int>& dX_dims,
      const T* dY_data,
      const T* /* X_data */,
      const T* /* Y_data */,
      T* dX_data,
      Context* context) const {
    math::Broadcast<T, Context>(
        dY_dims.size(),
        dY_dims.data(),
        dX_dims.size(),
        dX_dims.data(),
        T(1),
        dY_data,
        dX_data,
        context);
    return true;
  }
};

template <class Context>
struct MeanReducer {
  template <typename T>
  bool Forward(
      const std::vector<int>& X_dims,
      const std::vector<int>& Y_dims,
      const T* X_data,
      T* Y_data,
      Context* context) const {
    math::ReduceMean<T, Context>(
        X_dims.size(),
        X_dims.data(),
        Y_dims.data(),
        T(1),
        X_data,
        Y_data,
        context);
    return true;
  }

  template <typename T>
  bool Backward(
      const std::vector<int>& dY_dims,
      const std::vector<int>& dX_dims,
      const T* dY_data,
      const T* /* X_data */,
      const T* /* Y_data */,
      T* dX_data,
      Context* context) const {
    const int dY_size = std::accumulate(
        dY_dims.cbegin(), dY_dims.cend(), 1, std::multiplies<int>());
    const int dX_size = std::accumulate(
        dX_dims.cbegin(), dX_dims.cend(), 1, std::multiplies<int>());
    math::Broadcast<T, Context>(
        dY_dims.size(),
        dY_dims.data(),
        dX_dims.size(),
        dX_dims.data(),
        static_cast<T>(dY_size) / static_cast<T>(dX_size),
        dY_data,
        dX_data,
        context);
    return true;
  }
};
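// Worked example for MeanReducer::Backward (illustrative; not from the
// original header): reducing a (2, 3, 4) input over axes {0, 2} gives
// dY_size = 3 and dX_size = 24, so math::Broadcast scatters each dY element
// back to the 8 positions it averaged over with scale 3 / 24 = 1 / 8,
// matching d(mean)/dx = 1 / N for each of the N reduced elements.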
template <class Context>
struct L1Reducer {
  template <typename T>
  bool Forward(
      const std::vector<int>& X_dims,
      const std::vector<int>& Y_dims,
      const T* X_data,
      T* Y_data,
      Context* context) const {
    math::ReduceL1<T, Context>(
        X_dims.size(),
        X_dims.data(),
        Y_dims.data(),
        T(1),
        X_data,
        Y_data,
        context);
    return true;
  }

  template <typename T>
  bool Backward(
      const std::vector<int>& dY_dims,
      const std::vector<int>& dX_dims,
      const T* dY_data,
      const T* X_data,
      const T* Y_data,
      T* dX_data,
      Context* context) const;
};

template <class Context>
struct L2Reducer {
  template <typename T>
  bool Forward(
      const std::vector<int>& X_dims,
      const std::vector<int>& Y_dims,
      const T* X_data,
      T* Y_data,
      Context* context) const {
    math::ReduceL2<T, Context>(
        X_dims.size(),
        X_dims.data(),
        Y_dims.data(),
        T(1),
        X_data,
        Y_data,
        context);
    return true;
  }

  template <typename T>
  bool Backward(
      const std::vector<int>& dY_dims,
      const std::vector<int>& dX_dims,
      const T* dY_data,
      const T* X_data,
      const T* Y_data,
      T* dX_data,
      Context* context) const;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_REDUCE_OPS_H_
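// Usage sketch (illustrative; the concrete registrations live in
// reduce_ops.cc, not in this header): each reducer is paired with ReduceOp /
// ReduceGradientOp when an operator is registered, e.g.
//
//   REGISTER_CPU_OPERATOR(
//       ReduceSum,
//       ReduceOp<
//           TensorTypes<std::int32_t, std::int64_t, float, double>,
//           CPUContext,
//           SumReducer<CPUContext>>);
//
// A "ReduceSum" op with arguments axes = [0, 2] and keepdims = 1 then sums
// over those dimensions and keeps them as size-1 entries in the output shape.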