#ifndef CAFFE2_OPERATORS_MOMENTS_OP_H_ #define CAFFE2_OPERATORS_MOMENTS_OP_H_ #include #include #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" namespace caffe2 { template class MomentsOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit MomentsOp(Args&&... args) : Operator(std::forward(args)...), axes_(this->template GetRepeatedArgument("axes")), OP_SINGLE_ARG(bool, "keepdims", keep_dims_, true) {} bool RunOnDevice() override { const auto& X = Input(0); const int ndim = X.dim(); if (axes_.empty()) { axes_.resize(ndim); std::iota(axes_.begin(), axes_.end(), 0); } else { std::sort(axes_.begin(), axes_.end()); CAFFE_ENFORCE_GE(axes_.front(), 0, "Axes ids must be non-negative."); CAFFE_ENFORCE_LT( axes_.back(), ndim, "Axes ids must be smaller than the dimensions of input."); } const std::vector X_dims(X.sizes().cbegin(), X.sizes().cend()); std::vector Y_dims = X_dims; for (const int axis : axes_) { Y_dims[axis] = 1; } std::vector output_dims; output_dims.reserve(ndim); std::size_t cur_axis = 0; for (int i = 0; i < ndim; ++i) { if (cur_axis < axes_.size() && i == axes_[cur_axis]) { if (keep_dims_) { output_dims.push_back(1); } ++cur_axis; } else { output_dims.push_back(X_dims[i]); } } auto* mean = Output(0, output_dims, at::dtype()); auto* var = Output(1, output_dims, at::dtype()); math::Moments( X_dims.size(), X_dims.data(), Y_dims.data(), X.template data(), mean->template mutable_data(), var->template mutable_data(), &context_); return true; } private: std::vector axes_; const int keep_dims_; }; template class MomentsGradientOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit MomentsGradientOp(Args&&... args) : Operator(std::forward(args)...), axes_(this->template GetRepeatedArgument("axes")) {} bool RunOnDevice() override { const auto& dmean = Input(0); const auto& dvariance = Input(1); const auto& X = Input(2); const auto& mean = Input(3); const int ndim = X.dim(); if (axes_.empty()) { axes_.resize(ndim); std::iota(axes_.begin(), axes_.end(), 0); } else { std::sort(axes_.begin(), axes_.end()); CAFFE_ENFORCE_GE(axes_.front(), 0, "Axes ids must be non-negative."); CAFFE_ENFORCE_LT( axes_.back(), ndim, "Axes ids must be smaller than the dimensions of input."); } const std::vector dX_dims(X.sizes().cbegin(), X.sizes().cend()); std::vector dY_dims = dX_dims; for (const int axis : axes_) { dY_dims[axis] = 1; } auto* dX = Output(0, X.sizes(), at::dtype()); return Compute( dY_dims, dX_dims, dmean.template data(), dvariance.template data(), X.template data(), mean.template data(), dX->template mutable_data()); } private: bool Compute( const std::vector& dY_dims, const std::vector& dX_dims, const T* dmean_data, const T* dvariance_data, const T* X_data, const T* mean_data, T* dX_data); std::vector axes_; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_MOMENTS_OP_H_