#ifndef CAFFE2_OPERATORS_GROUP_NORM_OP_H_ #define CAFFE2_OPERATORS_GROUP_NORM_OP_H_ #include #include #include #include "caffe2/core/common.h" #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" namespace caffe2 { template class GroupNormOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit GroupNormOp(Args&&... args) : Operator(std::forward(args)...), OP_SINGLE_ARG(int, "group", group_, 32), OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5), order_(StringToStorageOrder( this->template GetSingleArgument("order", "NCHW"))), OP_SINGLE_ARG(bool, OpSchema::Arg_IsTest, is_test_, true) { CAFFE_ENFORCE_NE( order_, StorageOrder::UNKNOWN, "order should be either \"NCHW\" or \"NHWC\"."); if (!is_test_) { CAFFE_ENFORCE_EQ(OutputSize(), 3); } } bool RunOnDevice() override { const auto& X = Input(INPUT); const auto& gamma = Input(GAMMA); const auto& beta = Input(BETA); const int ndim = X.dim(); const int N = X.dim32(0); const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1); const size_t HxW = order_ == StorageOrder::NCHW ? X.size_from_dim(2) : X.size_between_dim(0, ndim - 1); CAFFE_ENFORCE_EQ(C % group_, 0); CAFFE_ENFORCE_EQ(gamma.numel(), C); CAFFE_ENFORCE_EQ(beta.numel(), C); const int G = group_; const int K = C / G; auto* Y = Output(OUTPUT, X.sizes(), at::dtype()); if (N == 0) { return true; } T* mu_data = nullptr; T* rsig_data = nullptr; if (OutputSize() == 3) { auto* mu = Output(MU, {N, G}, at::dtype()); auto* rsig = Output(INV_SIGMA, {N, G}, at::dtype()); mu_data = mu->template mutable_data(); rsig_data = rsig->template mutable_data(); } else { ReinitializeTensor( &mu_, {N, G}, at::dtype().device(Context::GetDeviceType())); ReinitializeTensor( &rsig_, {N, G}, at::dtype().device(Context::GetDeviceType())); mu_data = mu_.template mutable_data(); rsig_data = rsig_.template mutable_data(); } if (order_ == StorageOrder::NCHW) { return RunOnDeviceWithOrderNCHW( N, G, K, HxW, X.template data(), gamma.template data(), beta.template data(), Y->template mutable_data(), mu_data, rsig_data); } else { return RunOnDeviceWithOrderNHWC( N, G, K, HxW, X.template data(), gamma.template data(), beta.template data(), Y->template mutable_data(), mu_data, rsig_data); } } private: bool RunOnDeviceWithOrderNCHW( const int N, const int G, const int K, const int HxW, const T* X, const T* gamma, const T* beta, T* Y, T* mu, T* rsig) { const int C = G * K; ReinitializeTensor( &scale_, {N, C}, at::dtype().device(Context::GetDeviceType())); ReinitializeTensor( &bias_, {N, C}, at::dtype().device(Context::GetDeviceType())); T* scale_data = scale_.template mutable_data(); T* bias_data = bias_.template mutable_data(); const std::array X_dims = {N * G, K * HxW}; const std::array Y_dims = {N * G, 1}; math::Moments( 2, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_); math::InvStd( N * G, static_cast(epsilon_), rsig, rsig, &context_); ComputeFusedParams(N, G, K, mu, rsig, gamma, beta, scale_data, bias_data); GroupNormForwardNCHW(N, C, HxW, X, scale_data, bias_data, Y); return true; } bool RunOnDeviceWithOrderNHWC( const int N, const int G, const int K, const int HxW, const T* X, const T* gamma, const T* beta, T* Y, T* mu, T* rsig) { const int C = G * K; ReinitializeTensor( &scale_, {N, C}, at::dtype().device(Context::GetDeviceType())); ReinitializeTensor( &bias_, {N, C}, at::dtype().device(Context::GetDeviceType())); T* scale_data = scale_.template mutable_data(); T* bias_data = bias_.template mutable_data(); const std::array X_dims = {N, HxW, G, K}; const std::array Y_dims = {N, 1, G, 1}; math::Moments( 4, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_); math::InvStd( N * G, static_cast(epsilon_), rsig, rsig, &context_); ComputeFusedParams(N, G, K, mu, rsig, gamma, beta, scale_data, bias_data); GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y); return true; } void ComputeFusedParams( int N, int G, int K, const T* mu, const T* rsig, const T* gamma, const T* beta, T* scale, T* bias); void GroupNormForwardNCHW( const int N, const int C, const int HxW, const T* X, const T* scale, const T* bias, T* Y); void GroupNormForwardNHWC( const int N, const int C, const int HxW, const T* X, const T* scale, const T* bias, T* Y); const int group_; const float epsilon_; const StorageOrder order_; const bool is_test_; Tensor mu_; Tensor rsig_; Tensor scale_; Tensor bias_; // Input: X, gamma, beta // Output: Y, mu, inv_sig INPUT_TAGS(INPUT, GAMMA, BETA); OUTPUT_TAGS(OUTPUT, MU, INV_SIGMA); }; template class GroupNormGradientOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit GroupNormGradientOp(Args&&... args) : Operator(std::forward(args)...), OP_SINGLE_ARG(int, "group", group_, 32), order_(StringToStorageOrder( this->template GetSingleArgument("order", "NCHW"))) { CAFFE_ENFORCE_NE( order_, StorageOrder::UNKNOWN, "order should be either \"NCHW\" or \"NHWC\"."); } bool RunOnDevice() override { const auto& dY = Input(OUTPUT_GRAD); const auto& X = Input(INPUT); const auto& gamma = Input(GAMMA); const auto& beta = Input(BETA); const auto& mu = Input(MU); const auto& rsig = Input(INV_SIGMA); const int ndim = X.dim(); const int N = X.dim32(0); const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1); const int HxW = X.numel() / (N * C); CAFFE_ENFORCE_EQ(C % group_, 0); CAFFE_ENFORCE_EQ(gamma.numel(), C); CAFFE_ENFORCE_EQ(beta.numel(), C); const int G = group_; const int K = C / G; auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype()); auto* dgamma = Output(GAMMA_GRAD, gamma.sizes(), at::dtype()); auto* dbeta = Output(BETA_GRAD, beta.sizes(), at::dtype()); if (order_ == StorageOrder::NCHW) { return RunOnDeviceWithOrderNCHW( N, G, K, HxW, dY.template data(), X.template data(), mu.template data(), rsig.template data(), gamma.template data(), dX->template mutable_data(), dgamma->template mutable_data(), dbeta->template mutable_data()); } else { return RunOnDeviceWithOrderNHWC( N, G, K, HxW, dY.template data(), X.template data(), mu.template data(), rsig.template data(), gamma.template data(), dX->template mutable_data(), dgamma->template mutable_data(), dbeta->template mutable_data()); } } protected: bool RunOnDeviceWithOrderNCHW( int N, int G, int K, int HxW, const T* dY_data, const T* X_data, const T* mu_data, const T* rsig_data, const T* gamma_data, T* dX_data, T* dgamma_data, T* dbeta_data); bool RunOnDeviceWithOrderNHWC( int N, int G, int K, int HxW, const T* dY_data, const T* X_data, const T* mu_data, const T* rsig_data, const T* gamma_data, T* dX_data, T* dgamma_data, T* dbeta_data); const int group_; const StorageOrder order_; Tensor ds_; Tensor db_; Tensor dY_scale_; Tensor X_scale_; Tensor bias_; Tensor ones_; // Input: dY, X, gamma, beta, mu, inv_sig // Output: dX, dgamma, dbeta INPUT_TAGS(OUTPUT_GRAD, INPUT, GAMMA, BETA, MU, INV_SIGMA); OUTPUT_TAGS(INPUT_GRAD, GAMMA_GRAD, BETA_GRAD); }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_GROUP_NORM_OP_H_