#ifndef CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
|
#define CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
|
|
#include <array>
|
|
#include "caffe2/core/context.h"
|
#include "caffe2/core/operator.h"
|
#include "caffe2/utils/math.h"
|
|
namespace caffe2 {
|
|
template <typename T, class Context>
|
class InstanceNormOp final : public Operator<Context> {
|
public:
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
template <class... Args>
|
explicit InstanceNormOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...),
|
OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5),
|
order_(StringToStorageOrder(
|
this->template GetSingleArgument<string>("order", "NCHW"))) {
|
CAFFE_ENFORCE_GE(epsilon_, 0, "Must pass a nonnegative epsilon.");
|
CAFFE_ENFORCE_NE(
|
order_,
|
StorageOrder::UNKNOWN,
|
"order should be either \"NCHW\" or \"NHWC\".");
|
}
|
|
bool RunOnDevice() {
|
const auto& X = Input(INPUT);
|
const auto& gamma = Input(SCALE);
|
const auto& beta = Input(BIAS);
|
const int ndim = X.dim();
|
const int64_t N = X.dim(0);
|
const int64_t C = order_ == StorageOrder::NCHW ? X.dim(1) : X.dim(ndim - 1);
|
const int64_t HxW = X.numel() / (N * C);
|
CAFFE_ENFORCE_EQ(gamma.numel(), C);
|
CAFFE_ENFORCE_EQ(beta.numel(), C);
|
auto* Y = Output(OUTPUT, X.sizes(), at::dtype<T>());
|
const T* X_data = X.template data<T>();
|
const T* gamma_data = gamma.template data<T>();
|
const T* beta_data = beta.template data<T>();
|
T* Y_data = Y->template mutable_data<T>();
|
T* mean_data = nullptr;
|
T* rstd_data = nullptr;
|
if (OutputSize() >= 2) {
|
auto* mean = Output(MEAN, {N, C}, at::dtype<T>());
|
mean_data = mean->template mutable_data<T>();
|
} else {
|
ReinitializeTensor(
|
&mean_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
|
mean_data = mean_.template mutable_data<T>();
|
}
|
if (OutputSize() >= 3) {
|
auto* rstd = Output(RSTD, {N, C}, at::dtype<T>());
|
rstd_data = rstd->template mutable_data<T>();
|
} else {
|
ReinitializeTensor(
|
&rstd_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
|
rstd_data = rstd_.template mutable_data<T>();
|
}
|
switch (order_) {
|
case StorageOrder::NCHW: {
|
return RunOnDeviceWithOrderNCHW(
|
N,
|
C,
|
HxW,
|
X_data,
|
gamma_data,
|
beta_data,
|
Y_data,
|
mean_data,
|
rstd_data);
|
}
|
case StorageOrder::NHWC: {
|
return RunOnDeviceWithOrderNHWC(
|
N,
|
C,
|
HxW,
|
X_data,
|
gamma_data,
|
beta_data,
|
Y_data,
|
mean_data,
|
rstd_data);
|
}
|
default: {
|
CAFFE_THROW("Unknown storage order: ", order_);
|
}
|
}
|
}
|
|
private:
|
bool RunOnDeviceWithOrderNCHW(
|
int64_t N,
|
int64_t C,
|
int64_t HxW,
|
const T* X,
|
const T* gamma,
|
const T* beta,
|
T* Y,
|
T* mean,
|
T* rstd);
|
|
bool RunOnDeviceWithOrderNHWC(
|
int64_t N,
|
int64_t C,
|
int64_t HxW,
|
const T* X,
|
const T* gamma,
|
const T* beta,
|
T* Y,
|
T* mean,
|
T* rstd);
|
|
const float epsilon_;
|
const StorageOrder order_;
|
|
Tensor mean_;
|
Tensor rstd_;
|
Tensor scale_;
|
Tensor bias_;
|
|
INPUT_TAGS(INPUT, SCALE, BIAS);
|
OUTPUT_TAGS(OUTPUT, MEAN, RSTD);
|
};
|
|
template <typename T, class Context>
|
class InstanceNormGradientOp final : public Operator<Context> {
|
public:
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
template <class... Args>
|
explicit InstanceNormGradientOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...),
|
OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5),
|
order_(StringToStorageOrder(
|
this->template GetSingleArgument<string>("order", "NCHW"))) {
|
CAFFE_ENFORCE_GE(epsilon_, 0, "Must pass a nonnegative epsilon.");
|
CAFFE_ENFORCE_NE(
|
order_,
|
StorageOrder::UNKNOWN,
|
"order should be either \"NCHW\" or \"NHWC\".");
|
}
|
|
bool RunOnDevice() {
|
const auto& X = Input(INPUT);
|
const auto& gamma = Input(SCALE);
|
const auto& dY = Input(OUTPUT_GRAD);
|
const int ndim = X.dim();
|
const int64_t N = X.dim(0);
|
const int64_t C = order_ == StorageOrder::NCHW ? X.dim(1) : X.dim(ndim - 1);
|
const int64_t HxW = X.numel() / (N * C);
|
CAFFE_ENFORCE_EQ(gamma.numel(), C);
|
const T* dY_data = dY.template data<T>();
|
const T* X_data = X.template data<T>();
|
const T* gamma_data = gamma.template data<T>();
|
const T* mean_data = nullptr;
|
const T* rstd_data = nullptr;
|
CAFFE_ENFORCE_GE(InputSize(), 4);
|
CAFFE_ENFORCE_LE(InputSize(), 6);
|
if (InputSize() == 6) {
|
const auto& mean = Input(MEAN);
|
const auto& rstd = Input(RSTD);
|
mean_data = mean.template data<T>();
|
rstd_data = rstd.template data<T>();
|
} else {
|
ReinitializeTensor(
|
&mean_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
|
ReinitializeTensor(
|
&rstd_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
|
ComputeMoments(
|
N,
|
C,
|
HxW,
|
X_data,
|
mean_.template mutable_data<T>(),
|
rstd_.template mutable_data<T>());
|
mean_data = mean_.template data<T>();
|
rstd_data = rstd_.template data<T>();
|
}
|
|
auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype<T>());
|
auto* dgamma = Output(SCALE_GRAD, gamma.sizes(), at::dtype<T>());
|
auto* dbeta = Output(BIAS_GRAD, gamma.sizes(), at::dtype<T>());
|
T* dX_data = dX->template mutable_data<T>();
|
T* dgamma_data = dgamma->template mutable_data<T>();
|
T* dbeta_data = dbeta->template mutable_data<T>();
|
|
switch (order_) {
|
case StorageOrder::NCHW: {
|
return RunOnDeviceWithOrderNCHW(
|
N,
|
C,
|
HxW,
|
dY_data,
|
X_data,
|
mean_data,
|
rstd_data,
|
gamma_data,
|
dX_data,
|
dgamma_data,
|
dbeta_data);
|
}
|
case StorageOrder::NHWC: {
|
return RunOnDeviceWithOrderNHWC(
|
N,
|
C,
|
HxW,
|
dY_data,
|
X_data,
|
mean_data,
|
rstd_data,
|
gamma_data,
|
dX_data,
|
dgamma_data,
|
dbeta_data);
|
}
|
default: {
|
CAFFE_THROW("Unknown storage order: ", order_);
|
}
|
}
|
}
|
|
private:
|
void ComputeMoments(
|
int64_t N,
|
int64_t C,
|
int64_t HxW,
|
const T* X,
|
T* mean,
|
T* rstd);
|
|
bool RunOnDeviceWithOrderNCHW(
|
int64_t N,
|
int64_t C,
|
int64_t HxW,
|
const T* dY,
|
const T* X,
|
const T* mean,
|
const T* rstd,
|
const T* gamma,
|
T* dX,
|
T* dgamma,
|
T* dbeta);
|
|
bool RunOnDeviceWithOrderNHWC(
|
int64_t N,
|
int64_t C,
|
int64_t HxW,
|
const T* dY,
|
const T* X,
|
const T* mean,
|
const T* rstd,
|
const T* gamma,
|
T* dX,
|
T* dgamma,
|
T* dbeta);
|
|
const float epsilon_;
|
const StorageOrder order_;
|
|
Tensor mean_;
|
Tensor rstd_;
|
Tensor ds_;
|
Tensor db_;
|
Tensor c1_;
|
Tensor c2_;
|
Tensor c3_;
|
Tensor ones_;
|
|
INPUT_TAGS(INPUT, SCALE, BIAS, OUTPUT_GRAD, MEAN, RSTD);
|
OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
|
};
|
|
} // namespace caffe2
|
|
#endif // CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
|