#ifndef CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_ #define CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_ #include #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" namespace caffe2 { template class InstanceNormOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit InstanceNormOp(Args&&... args) : Operator(std::forward(args)...), OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5), order_(StringToStorageOrder( this->template GetSingleArgument("order", "NCHW"))) { CAFFE_ENFORCE_GE(epsilon_, 0, "Must pass a nonnegative epsilon."); CAFFE_ENFORCE_NE( order_, StorageOrder::UNKNOWN, "order should be either \"NCHW\" or \"NHWC\"."); } bool RunOnDevice() { const auto& X = Input(INPUT); const auto& gamma = Input(SCALE); const auto& beta = Input(BIAS); const int ndim = X.dim(); const int64_t N = X.dim(0); const int64_t C = order_ == StorageOrder::NCHW ? X.dim(1) : X.dim(ndim - 1); const int64_t HxW = X.numel() / (N * C); CAFFE_ENFORCE_EQ(gamma.numel(), C); CAFFE_ENFORCE_EQ(beta.numel(), C); auto* Y = Output(OUTPUT, X.sizes(), at::dtype()); const T* X_data = X.template data(); const T* gamma_data = gamma.template data(); const T* beta_data = beta.template data(); T* Y_data = Y->template mutable_data(); T* mean_data = nullptr; T* rstd_data = nullptr; if (OutputSize() >= 2) { auto* mean = Output(MEAN, {N, C}, at::dtype()); mean_data = mean->template mutable_data(); } else { ReinitializeTensor( &mean_, {N, C}, at::dtype().device(Context::GetDeviceType())); mean_data = mean_.template mutable_data(); } if (OutputSize() >= 3) { auto* rstd = Output(RSTD, {N, C}, at::dtype()); rstd_data = rstd->template mutable_data(); } else { ReinitializeTensor( &rstd_, {N, C}, at::dtype().device(Context::GetDeviceType())); rstd_data = rstd_.template mutable_data(); } switch (order_) { case StorageOrder::NCHW: { return RunOnDeviceWithOrderNCHW( N, C, HxW, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); } case StorageOrder::NHWC: { return RunOnDeviceWithOrderNHWC( N, C, HxW, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); } default: { CAFFE_THROW("Unknown storage order: ", order_); } } } private: bool RunOnDeviceWithOrderNCHW( int64_t N, int64_t C, int64_t HxW, const T* X, const T* gamma, const T* beta, T* Y, T* mean, T* rstd); bool RunOnDeviceWithOrderNHWC( int64_t N, int64_t C, int64_t HxW, const T* X, const T* gamma, const T* beta, T* Y, T* mean, T* rstd); const float epsilon_; const StorageOrder order_; Tensor mean_; Tensor rstd_; Tensor scale_; Tensor bias_; INPUT_TAGS(INPUT, SCALE, BIAS); OUTPUT_TAGS(OUTPUT, MEAN, RSTD); }; template class InstanceNormGradientOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit InstanceNormGradientOp(Args&&... args) : Operator(std::forward(args)...), OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5), order_(StringToStorageOrder( this->template GetSingleArgument("order", "NCHW"))) { CAFFE_ENFORCE_GE(epsilon_, 0, "Must pass a nonnegative epsilon."); CAFFE_ENFORCE_NE( order_, StorageOrder::UNKNOWN, "order should be either \"NCHW\" or \"NHWC\"."); } bool RunOnDevice() { const auto& X = Input(INPUT); const auto& gamma = Input(SCALE); const auto& dY = Input(OUTPUT_GRAD); const int ndim = X.dim(); const int64_t N = X.dim(0); const int64_t C = order_ == StorageOrder::NCHW ? X.dim(1) : X.dim(ndim - 1); const int64_t HxW = X.numel() / (N * C); CAFFE_ENFORCE_EQ(gamma.numel(), C); const T* dY_data = dY.template data(); const T* X_data = X.template data(); const T* gamma_data = gamma.template data(); const T* mean_data = nullptr; const T* rstd_data = nullptr; CAFFE_ENFORCE_GE(InputSize(), 4); CAFFE_ENFORCE_LE(InputSize(), 6); if (InputSize() == 6) { const auto& mean = Input(MEAN); const auto& rstd = Input(RSTD); mean_data = mean.template data(); rstd_data = rstd.template data(); } else { ReinitializeTensor( &mean_, {N, C}, at::dtype().device(Context::GetDeviceType())); ReinitializeTensor( &rstd_, {N, C}, at::dtype().device(Context::GetDeviceType())); ComputeMoments( N, C, HxW, X_data, mean_.template mutable_data(), rstd_.template mutable_data()); mean_data = mean_.template data(); rstd_data = rstd_.template data(); } auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype()); auto* dgamma = Output(SCALE_GRAD, gamma.sizes(), at::dtype()); auto* dbeta = Output(BIAS_GRAD, gamma.sizes(), at::dtype()); T* dX_data = dX->template mutable_data(); T* dgamma_data = dgamma->template mutable_data(); T* dbeta_data = dbeta->template mutable_data(); switch (order_) { case StorageOrder::NCHW: { return RunOnDeviceWithOrderNCHW( N, C, HxW, dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, dgamma_data, dbeta_data); } case StorageOrder::NHWC: { return RunOnDeviceWithOrderNHWC( N, C, HxW, dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, dgamma_data, dbeta_data); } default: { CAFFE_THROW("Unknown storage order: ", order_); } } } private: void ComputeMoments( int64_t N, int64_t C, int64_t HxW, const T* X, T* mean, T* rstd); bool RunOnDeviceWithOrderNCHW( int64_t N, int64_t C, int64_t HxW, const T* dY, const T* X, const T* mean, const T* rstd, const T* gamma, T* dX, T* dgamma, T* dbeta); bool RunOnDeviceWithOrderNHWC( int64_t N, int64_t C, int64_t HxW, const T* dY, const T* X, const T* mean, const T* rstd, const T* gamma, T* dX, T* dgamma, T* dbeta); const float epsilon_; const StorageOrder order_; Tensor mean_; Tensor rstd_; Tensor ds_; Tensor db_; Tensor c1_; Tensor c2_; Tensor c3_; Tensor ones_; INPUT_TAGS(INPUT, SCALE, BIAS, OUTPUT_GRAD, MEAN, RSTD); OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD); }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_