#ifndef CAFFE2_OPERATORS_LAYER_NORM_OP_H_ #define CAFFE2_OPERATORS_LAYER_NORM_OP_H_ #include #include #include "caffe2/core/context.h" #include "caffe2/core/export_caffe2_op_to_c10.h" #include "caffe2/core/operator.h" #include "caffe2/core/types.h" #include "caffe2/utils/math.h" C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(LayerNorm) namespace caffe2 { template class LayerNormOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit LayerNormOp(Args&&... args) : Operator(std::forward(args)...), OP_SINGLE_ARG(int, "axis", axis_, 1), OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f), OP_SINGLE_ARG(bool, "elementwise_affine", elementwise_affine_, false) {} bool RunOnDevice() override { return DispatchHelper>::call(this, Input(0)); } template bool DoRunWithType() { const auto& X = Input(0); auto* Y = Output(0); CAFFE_ENFORCE_GE(X.dim(), 2, "LayerNorm requires input dim >= 2."); const int canonical_axis = X.canonical_axis_index(axis_); std::vector moments_dims( X.sizes().cbegin(), X.sizes().cbegin() + canonical_axis); moments_dims.push_back(1); auto* mean = Output(1, moments_dims, at::dtype()); auto* sigma = Output(2, moments_dims, at::dtype()); const int M = X.size_to_dim(canonical_axis); const int N = X.size_from_dim(canonical_axis); Y->ResizeLike(X); scale_.Resize(M); bias_.Resize(M); const T* X_data = X.template data(); T* Y_data = Y->template mutable_data(); T* mean_data = mean->template mutable_data(); T* sigma_data = sigma->template mutable_data(); T* scale_data = scale_.template mutable_data(); T* bias_data = bias_.template mutable_data(); const std::array X_dims = {M, N}; const std::array Y_dims = {M, 1}; math::Moments( 2, X_dims.data(), Y_dims.data(), X_data, mean_data, sigma_data, &context_); ComputeSigmaAndFusedParams( M, epsilon_, mean_data, sigma_data, sigma_data, scale_data, bias_data); const T* gamma_data = nullptr; const T* beta_data = nullptr; if (elementwise_affine_) { CAFFE_ENFORCE_EQ(InputSize(), 3); const auto& gamma = Input(1); const auto& beta = Input(2); CAFFE_ENFORCE_EQ(gamma.numel(), N); CAFFE_ENFORCE_EQ(beta.numel(), N); gamma_data = gamma.template data(); beta_data = beta.template data(); } LayerNormForward( M, N, X_data, scale_data, bias_data, gamma_data, beta_data, Y_data); return true; } private: template void ComputeSigmaAndFusedParams( const int N, const float eps, const T* mean, const T* var, T* stddev, T* scale, T* bias); template void LayerNormForward( const int M, const int N, const T* X, const T* scale, const T* bias, const T* gamma, const T* beta, T* Y); const int axis_; const float epsilon_; const bool elementwise_affine_; Tensor scale_{Context::GetDeviceType()}; Tensor bias_{Context::GetDeviceType()}; }; template class LayerNormGradientOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit LayerNormGradientOp(Args&&... args) : Operator(std::forward(args)...), OP_SINGLE_ARG(int, "axis", axis_, 1), OP_SINGLE_ARG(bool, "elementwise_affine", elementwise_affine_, false) {} bool RunOnDevice() override { return DispatchHelper>::call(this, Input(0)); } template bool DoRunWithType() { const auto& dY = Input(0); const auto& Y = Input(1); const auto& mean = Input(2); const auto& sigma = Input(3); const auto& X = Input(4); const int canonical_axis = X.canonical_axis_index(axis_); const int M = X.size_to_dim(canonical_axis); const int N = X.size_from_dim(canonical_axis); auto* dX = Output(0, X.sizes(), at::dtype()); ReinitializeTensor( &ds_, {M}, at::dtype().device(Context::GetDeviceType())); ReinitializeTensor( &db_, {M}, at::dtype().device(Context::GetDeviceType())); ReinitializeTensor( &rstd_, {M}, at::dtype().device(Context::GetDeviceType())); ReinitializeTensor( &X_scale_, {M}, at::dtype().device(Context::GetDeviceType())); ReinitializeTensor( &bias_, {M}, at::dtype().device(Context::GetDeviceType())); const T* dY_data = dY.template data(); const T* X_data = X.template data(); const T* mean_data = mean.template data(); const T* sigma_data = sigma.template data(); T* dX_data = dX->template mutable_data(); T* ds_data = ds_.template mutable_data(); T* db_data = db_.template mutable_data(); T* rstd_data = rstd_.template mutable_data(); T* X_scale_data = X_scale_.template mutable_data(); T* bias_data = bias_.template mutable_data(); const T* gamma_data = nullptr; T* dgamma_data = nullptr; T* dbeta_data = nullptr; T* g_scale_data = nullptr; if (elementwise_affine_) { const auto& gamma = Input(5); auto* dgamma = Output(1, gamma.sizes(), at::dtype()); auto* dbeta = Output(2, gamma.sizes(), at::dtype()); ReinitializeTensor( &g_scale_, {M}, at::dtype().device(Context::GetDeviceType())); gamma_data = gamma.template data(); dgamma_data = dgamma->template mutable_data(); dbeta_data = dbeta->template mutable_data(); g_scale_data = g_scale_.template mutable_data(); } ComputeInternalGradients( M, N, dY_data, X_data, gamma_data, dX_data, ds_data, db_data); ComputeFusedParams( M, N, mean_data, sigma_data, ds_data, db_data, rstd_data, X_scale_data, bias_data, g_scale_data); if (elementwise_affine_) { GammaBetaBackward( M, N, dX_data, dY_data, rstd_data, g_scale_data, dgamma_data, dbeta_data); } LayerNormBackward( M, N, dY_data, X_data, gamma_data, rstd_data, X_scale_data, bias_data, dX_data); return true; } private: template void ComputeInternalGradients( const int M, const int N, const T* dY, const T* X, const T* gamma, T* dYxX, T* ds, T* db); template void ComputeFusedParams( const int M, const int N, const T* mean, const T* sigma, const T* ds, const T* db, T* rstd, T* X_scale, T* bias, T* g_scale); template void LayerNormBackward( const int M, const int N, const T* dY, const T* X, const T* gamma, const T* dY_scale, const T* X_scale, const T* bias, T* dX); template void GammaBetaBackward( const int M, const int N, const T* dYxX, const T* dY, const T* rstd, const T* g_scale, T* dgamma, T* dbeta); const int axis_; const bool elementwise_affine_; Tensor ds_; Tensor db_; Tensor rstd_; Tensor X_scale_; Tensor bias_; Tensor g_scale_; Tensor ones_; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_LAYER_NORM_OP_H_