// locally_connected_op_impl.h is the templated implementation of the
// locally_connected_op.h file.

#ifndef CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_
#define CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_

#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/flags.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/conv_pool_op_base.h"
#include "caffe2/operators/locally_connected_op.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  auto* Y = Output(0);
  const int image_ndim = X.dim() - 2;
  CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(1);
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(
      shape.C == filter.dim32(image_ndim + 1) * group_,
      "Locally Connected op: input channels do not match: "
      "# of input channels ",
      shape.C,
      " is not equal to kernel channels * group: ",
      filter.dim32(image_ndim + 1),
      " * ",
      group_);
  CAFFE_ENFORCE_EQ(
      shape.M % group_,
      0,
      "The number of output channels is not divisible by group.");

  ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);
  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(*Y);
  const std::vector<int> output_image_dims = GetDims(*Y);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
  }

  int kernel_dims_size = 1;
  for (std::size_t i = 0; i < kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + image_ndim + 2), kernel_[i]);
    kernel_dims_size *= kernel_[i];
  }

  shape.X_dims.assign(X.sizes().cbegin() + 1, X.sizes().cend());
  shape.kernel_size = shape.C / group_ * kernel_dims_size;
  lc_op_util::SetColumnBufferShape(
      shape.N,
      shape.kernel_size,
      shape.output_image_size,
      output_image_dims,
      order_,
      &shape.column_slice_dims,
      &shape.column_dims,
      &shape.column_transposed_dims,
      &shape.column_axes);
  lc_op_util::SetYBufferShape(
      shape.N,
      shape.M,
      shape.output_image_size,
      order_,
      &shape.Y_dims,
      &shape.Y_transposed_dims,
      &shape.Y_axes);

  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* bias_data = nullptr;
  if (InputSize() == 3) {
    const auto& bias = Input(BIAS);
    CAFFE_ENFORCE_EQ(bias.dim(), image_ndim + 1);
    for (int i = 0; i < image_ndim; ++i) {
      CAFFE_ENFORCE_EQ(bias.dim32(i), output_image_dims[i]);
    }
    CAFFE_ENFORCE_EQ(bias.dim32(image_ndim), shape.M);
    bias_data = bias.template data<T>();
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
  }
  T* Y_data = Y->template mutable_data<T>();

  RunOnDeviceWithOrderNCHWImpl(
      shape,
      X_data,
      filter_data,
      bias_data,
      Y_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &Y_transposed_buffer_);

  return true;
}

template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  auto* Y = Output(0);
  CAFFE_ENFORCE_EQ(
      kernel_.size(),
      2,
      "Only 2d locally connected op is supported for NHWC storage type.");
  const int image_ndim = X.dim() - 2;
  CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(3);
  shape.X_dims = {X.dim32(1), X.dim32(2), X.dim32(3)};
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 1), kernel_h());
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 2), kernel_w());
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 3), shape.C);
  ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);
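  // Flattened spatial sizes of X and Y. A locally connected layer gives each
  // output location its own kernel, so the filter's leading dims must match
  // Y's spatial dims; this is checked against GetDims(*Y) below.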
  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(*Y);
  const std::vector<int> output_image_dims = GetDims(*Y);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
  }

  shape.kernel_size = kernel_h() * kernel_w() * shape.C;
  lc_op_util::SetColumnBufferShape(
      shape.N,
      shape.kernel_size,
      shape.output_image_size,
      output_image_dims,
      order_,
      &shape.column_slice_dims,
      &shape.column_dims,
      &shape.column_transposed_dims,
      &shape.column_axes);
  lc_op_util::SetYBufferShape(
      shape.N,
      shape.M,
      shape.output_image_size,
      order_,
      &shape.Y_dims,
      &shape.Y_transposed_dims,
      &shape.Y_axes);

  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* bias_data = nullptr;
  if (InputSize() == 3) {
    const auto& bias = Input(BIAS);
    CAFFE_ENFORCE_EQ(bias.dim(), image_ndim + 1);
    for (int i = 0; i < image_ndim; ++i) {
      CAFFE_ENFORCE_EQ(bias.dim32(i), output_image_dims[i]);
    }
    CAFFE_ENFORCE_EQ(bias.dim32(image_ndim), shape.M);
    bias_data = bias.template data<T>();
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
  }
  T* Y_data = Y->template mutable_data<T>();

  RunOnDeviceWithOrderNHWCImpl(
      shape,
      X_data,
      filter_data,
      bias_data,
      Y_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &Y_transposed_buffer_);

  return true;
}

template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* bias_data,
    T* Y_data,
    Tensor* column_buffer,
    Tensor* column_transposed_buffer,
    Tensor* Y_transposed_buffer) {
  const int input_stride = shape.C / group_ * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  Y_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* Y_transposed_buffer_data =
      Y_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    for (int group_id = 0; group_id < group_; ++group_id) {
      if (kernel_.size() == 2) {
        math::Im2Col<T, Context, StorageOrder::NCHW>(
            shape.C / group_,
            shape.X_dims[1],
            shape.X_dims[2],
            kernel_h(),
            kernel_w(),
            dilation_h(),
            dilation_w(),
            pad_t(),
            pad_l(),
            pad_b(),
            pad_r(),
            stride_h(),
            stride_w(),
            X_data + group_id * input_stride,
            column_buffer_data + group_id * column_stride,
            &context_);
      } else {
        math::Im2ColNd<T, Context, StorageOrder::NCHW>(
            kernel_.size(),
            shape.C * shape.input_image_size,
            column_stride,
            shape.X_dims.data(),
            shape.column_slice_dims.data(),
            kernel_.data(),
            stride_.data(),
            dilation_.data(),
            pads_.data(),
            X_data + group_id * input_stride,
            column_buffer_data + group_id * column_stride,
            &context_);
      }
    }
    X_data += input_stride * group_;
    column_buffer_data += column_stride * group_;
  }
  math::Transpose(
      shape.column_dims.size(),
      shape.column_dims.data(),
      shape.column_axes.data(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
  math::GemmStridedBatched<T, Context>(
      CblasNoTrans,
      CblasNoTrans,
      shape.output_image_size * group_,
      shape.M / group_,
      shape.N,
      shape.kernel_size,
      1.0f,
      filter_data,
      shape.M / group_ * shape.kernel_size,
      column_transposed_buffer->template data<T>(),
      shape.kernel_size * shape.N,
      0.0f,
      Y_transposed_buffer_data,
      shape.M / group_ * shape.N,
      &context_);
  if (bias_data != nullptr) {
    math::Gemm<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        shape.output_image_size * shape.M,
        shape.N,
        1,
        1.0f,
        bias_data,
        bias_multiplier_.template data<T>(),
        1.0f,
        Y_transposed_buffer_data,
        &context_);
  }
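  // Y_transposed_buffer now holds Y in (output_image_size, M, N) order: the
  // batched GEMM ran one (M/group_ x kernel_size) * (kernel_size x N) product
  // per output location and group. Transpose it back into the NCHW layout
  // (N, M, output_image_size) expected by Y_data.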
  math::Transpose(
      shape.Y_transposed_dims.size(),
      shape.Y_transposed_dims.data(),
      shape.Y_axes.data(),
      Y_transposed_buffer_data,
      Y_data,
      &context_);
}

template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* bias_data,
    T* Y_data,
    Tensor* column_buffer,
    Tensor* column_transposed_buffer,
    Tensor* Y_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  Y_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* Y_transposed_buffer_data =
      Y_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    math::Im2Col<T, Context, StorageOrder::NHWC>(
        shape.C,
        shape.X_dims[0],
        shape.X_dims[1],
        kernel_h(),
        kernel_w(),
        dilation_h(),
        dilation_w(),
        pad_t(),
        pad_l(),
        pad_b(),
        pad_r(),
        stride_h(),
        stride_w(),
        X_data + image_id * input_stride,
        column_buffer_data + image_id * column_stride,
        &context_);
  }
  math::Transpose(
      shape.column_dims.size(),
      shape.column_dims.data(),
      shape.column_axes.data(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
  math::GemmStridedBatched<T, Context>(
      CblasNoTrans,
      CblasTrans,
      shape.output_image_size,
      shape.N,
      shape.M,
      shape.kernel_size,
      1.0f,
      column_transposed_buffer->template data<T>(),
      shape.N * shape.kernel_size,
      filter_data,
      shape.kernel_size * shape.M,
      0.0f,
      Y_transposed_buffer_data,
      shape.N * shape.M,
      &context_);
  math::Transpose(
      shape.Y_transposed_dims.size(),
      shape.Y_transposed_dims.data(),
      shape.Y_axes.data(),
      Y_transposed_buffer_data,
      Y_data,
      &context_);
  if (bias_data != nullptr) {
    math::Gemm<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        shape.N,
        shape.output_image_size * shape.M,
        1,
        1.0f,
        bias_multiplier_.template data<T>(),
        bias_data,
        1.0f,
        Y_data,
        &context_);
  }
}

template <typename T, class Context>
bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  const auto& dY = Input(OUTPUT_GRAD);
  const int image_ndim = X.dim() - 2;
  CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(1);
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 1) * group_, shape.C);
  CAFFE_ENFORCE_EQ(shape.M % group_, 0);

  const std::vector<int> input_image_dims = GetDims(X);
  shape.input_image_size = GetDimsSize(X);
  const std::vector<int> output_image_dims = GetDims(dY);
  shape.output_image_size = GetDimsSize(dY);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
  }
  ConvPoolOpBase<Context>::ComputePads(input_image_dims);

  int kernel_dims_size = 1;
  for (std::size_t i = 0; i < kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + image_ndim + 2), kernel_[i]);
    kernel_dims_size *= kernel_[i];
  }

  shape.X_dims.assign(X.sizes().cbegin() + 1, X.sizes().cend());
  shape.kernel_size = shape.C / group_ * kernel_dims_size;
  lc_op_util::SetColumnBufferShape(
      shape.N,
      shape.kernel_size,
      shape.output_image_size,
      output_image_dims,
      order_,
      &shape.column_slice_dims,
      &shape.column_dims,
      &shape.column_transposed_dims,
      &shape.column_axes);
  lc_op_util::SetYBufferShape(
      shape.N,
      shape.M,
      shape.output_image_size,
      order_,
      &shape.Y_dims,
      &shape.Y_transposed_dims,
      &shape.Y_axes);

  auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
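  // Raw data pointers for the kernels below. dX_data and dbias_data stay
  // null unless the corresponding gradient outputs were requested.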
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* dY_data = dY.template data<T>();
  T* dfilter_data = dfilter->template mutable_data<T>();
  T* dX_data = nullptr;
  T* dbias_data = nullptr;
  if (OutputSize() == 3 || (no_bias_ && OutputSize() == 2)) {
    auto* dX = Output(
        no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
    dX_data = dX->template mutable_data<T>();
  }
  if (!no_bias_) {
    std::vector<std::int64_t> dbias_dims;
    std::copy(
        output_image_dims.cbegin(),
        output_image_dims.cend(),
        std::back_inserter(dbias_dims));
    dbias_dims.push_back(shape.M);
    auto* dbias = Output(BIAS_OR_INPUT_GRAD, dbias_dims, at::dtype<T>());
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
    dbias_data = dbias->template mutable_data<T>();
  }
  RunOnDeviceWithOrderNCHWImpl(
      shape,
      X_data,
      filter_data,
      dY_data,
      dfilter_data,
      dX_data,
      dbias_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &dY_transposed_buffer_);

  return true;
}

template <typename T, class Context>
bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  const auto& dY = Input(OUTPUT_GRAD);
  CAFFE_ENFORCE_EQ(
      kernel_.size(),
      2,
      "Only 2d locally connected op is supported for NHWC storage type.");
  const int image_ndim = X.dim() - 2;
  CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(3);
  shape.X_dims = {X.dim32(1), X.dim32(2), X.dim32(3)};
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 1), kernel_h());
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 2), kernel_w());
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 3), shape.C);
  const std::vector<int> input_image_dims = {X.dim32(1), X.dim32(2)};
  ConvPoolOpBase<Context>::ComputePads(input_image_dims);
  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(dY);
  const std::vector<int> output_image_dims = GetDims(dY);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
  }

  shape.kernel_size = kernel_h() * kernel_w() * shape.C;
  lc_op_util::SetColumnBufferShape(
      shape.N,
      shape.kernel_size,
      shape.output_image_size,
      output_image_dims,
      order_,
      &shape.column_slice_dims,
      &shape.column_dims,
      &shape.column_transposed_dims,
      &shape.column_axes);
  lc_op_util::SetYBufferShape(
      shape.N,
      shape.M,
      shape.output_image_size,
      order_,
      &shape.Y_dims,
      &shape.Y_transposed_dims,
      &shape.Y_axes);

  auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* dY_data = dY.template data<T>();
  T* dfilter_data = dfilter->template mutable_data<T>();
  T* dX_data = nullptr;
  T* dbias_data = nullptr;
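  // Output slot convention, as in the NCHW path: with no_bias_ the second
  // output slot holds dX; otherwise it holds dbias, and dX, when requested,
  // is the third output.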
  if (OutputSize() == 3 || (no_bias_ && OutputSize() == 2)) {
    auto* dX = Output(
        no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
    dX_data = dX->template mutable_data<T>();
  }
  if (!no_bias_) {
    std::vector<std::int64_t> dbias_dims;
    std::copy(
        output_image_dims.cbegin(),
        output_image_dims.cend(),
        std::back_inserter(dbias_dims));
    dbias_dims.push_back(shape.M);
    auto* dbias = Output(BIAS_OR_INPUT_GRAD, dbias_dims, at::dtype<T>());
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
    dbias_data = dbias->template mutable_data<T>();
  }
  RunOnDeviceWithOrderNHWCImpl(
      shape,
      X_data,
      filter_data,
      dY_data,
      dfilter_data,
      dX_data,
      dbias_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &dY_transposed_buffer_);

  return true;
}

template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* dY_data,
    T* dfilter_data,
    T* dX_data,
    T* dbias_data,
    Tensor* column_buffer,
    Tensor* column_transposed_buffer,
    Tensor* dY_transposed_buffer) {
  const int input_stride = shape.C / group_ * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  dY_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* dY_transposed_buffer_data =
      dY_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    for (int group_id = 0; group_id < group_; ++group_id) {
      if (kernel_.size() == 2) {
        math::Im2Col<T, Context, StorageOrder::NCHW>(
            shape.C / group_,
            shape.X_dims[1],
            shape.X_dims[2],
            kernel_h(),
            kernel_w(),
            dilation_h(),
            dilation_w(),
            pad_t(),
            pad_l(),
            pad_b(),
            pad_r(),
            stride_h(),
            stride_w(),
            X_data + group_id * input_stride,
            column_buffer_data + group_id * column_stride,
            &context_);
      } else {
        math::Im2ColNd<T, Context, StorageOrder::NCHW>(
            kernel_.size(),
            shape.C * shape.input_image_size,
            column_stride,
            shape.X_dims.data(),
            shape.column_slice_dims.data(),
            kernel_.data(),
            stride_.data(),
            dilation_.data(),
            pads_.data(),
            X_data + group_id * input_stride,
            column_buffer_data + group_id * column_stride,
            &context_);
      }
    }
    X_data += input_stride * group_;
    column_buffer_data += column_stride * group_;
  }
  math::Transpose(
      shape.column_dims.size(),
      shape.column_dims.data(),
      shape.column_axes.data(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
  math::Transpose(
      shape.Y_dims.size(),
      shape.Y_dims.data(),
      shape.Y_axes.data(),
      dY_data,
      dY_transposed_buffer_data,
      &context_);

  // Gradient with respect to filter.
  math::GemmStridedBatched<T, Context>(
      CblasNoTrans,
      CblasTrans,
      shape.output_image_size * group_,
      shape.M / group_,
      shape.kernel_size,
      shape.N,
      1.0f,
      dY_transposed_buffer_data,
      shape.M / group_ * shape.N,
      column_transposed_buffer->template data<T>(),
      shape.N * shape.kernel_size,
      0.0f,
      dfilter_data,
      shape.M / group_ * shape.kernel_size,
      &context_);

  if (dbias_data != nullptr) {
    // Gradient with respect to bias.
    math::Gemv<T, Context>(
        CblasNoTrans,
        shape.output_image_size * shape.M,
        shape.N,
        1.0f,
        dY_transposed_buffer_data,
        bias_multiplier_.template data<T>(),
        0.0f,
        dbias_data,
        &context_);
  }

  if (dX_data != nullptr) {
    // Gradient with respect to X.
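    // Per output location and group: column_transposed = filter^T *
    // dY_transposed, a (kernel_size x M/group_) x (M/group_ x N) product.
    // The inverse transpose and Col2Im below scatter the recovered columns
    // back into dX.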
    math::GemmStridedBatched<T, Context>(
        CblasTrans,
        CblasNoTrans,
        shape.output_image_size * group_,
        shape.kernel_size,
        shape.N,
        shape.M / group_,
        1.0f,
        filter_data,
        shape.kernel_size * shape.M / group_,
        dY_transposed_buffer_data,
        shape.M / group_ * shape.N,
        0.0f,
        column_transposed_buffer->template mutable_data<T>(),
        shape.kernel_size * shape.N,
        &context_);
    math::Transpose(
        shape.column_transposed_dims.size(),
        shape.column_transposed_dims.data(),
        shape.column_axes.data(),
        column_transposed_buffer->template data<T>(),
        column_buffer->template mutable_data<T>(),
        &context_);
    const T* const_column_buffer_data = column_buffer->template data<T>();
    for (int image_id = 0; image_id < shape.N; ++image_id) {
      for (int group_id = 0; group_id < group_; ++group_id) {
        if (kernel_.size() == 2) {
          math::Col2Im<T, Context, StorageOrder::NCHW>(
              shape.C / group_,
              shape.X_dims[1],
              shape.X_dims[2],
              kernel_h(),
              kernel_w(),
              dilation_h(),
              dilation_w(),
              pad_t(),
              pad_l(),
              pad_b(),
              pad_r(),
              stride_h(),
              stride_w(),
              const_column_buffer_data + group_id * column_stride,
              dX_data + group_id * input_stride,
              &context_);
        } else {
          math::Col2ImNd<T, Context, StorageOrder::NCHW>(
              kernel_.size(),
              shape.C * shape.input_image_size,
              column_stride,
              shape.X_dims.data(),
              shape.column_slice_dims.data(),
              kernel_.data(),
              stride_.data(),
              dilation_.data(),
              pads_.data(),
              const_column_buffer_data + group_id * column_stride,
              dX_data + group_id * input_stride,
              &context_);
        }
      }
      dX_data += input_stride * group_;
      const_column_buffer_data += column_stride * group_;
    }
  }
}

template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* dY_data,
    T* dfilter_data,
    T* dX_data,
    T* dbias_data,
    Tensor* column_buffer,
    Tensor* column_transposed_buffer,
    Tensor* dY_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  dY_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* dY_transposed_buffer_data =
      dY_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    math::Im2Col<T, Context, StorageOrder::NHWC>(
        shape.C,
        shape.X_dims[0],
        shape.X_dims[1],
        kernel_h(),
        kernel_w(),
        dilation_h(),
        dilation_w(),
        pad_t(),
        pad_l(),
        pad_b(),
        pad_r(),
        stride_h(),
        stride_w(),
        X_data + image_id * input_stride,
        column_buffer_data + image_id * column_stride,
        &context_);
  }
  math::Transpose(
      shape.column_dims.size(),
      shape.column_dims.data(),
      shape.column_axes.data(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
  math::Transpose(
      shape.Y_dims.size(),
      shape.Y_dims.data(),
      shape.Y_axes.data(),
      dY_data,
      dY_transposed_buffer_data,
      &context_);

  // Gradient with respect to filter.
  math::GemmStridedBatched<T, Context>(
      CblasTrans,
      CblasNoTrans,
      shape.output_image_size,
      shape.M,
      shape.kernel_size,
      shape.N,
      1.0f,
      dY_transposed_buffer_data,
      shape.M * shape.N,
      column_transposed_buffer->template data<T>(),
      shape.N * shape.kernel_size,
      0.0f,
      dfilter_data,
      shape.M * shape.kernel_size,
      &context_);

  if (dbias_data != nullptr) {
    // Gradient with respect to bias.
    math::Gemv<T, Context>(
        CblasTrans,
        shape.N,
        shape.output_image_size * shape.M,
        1.0f,
        dY_data,
        bias_multiplier_.template data<T>(),
        0.0f,
        dbias_data,
        &context_);
  }

  if (dX_data != nullptr) {
    // Gradient with respect to X.
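    // Per output location: column_transposed = dY_transposed * filter, an
    // (N x M) x (M x kernel_size) product; the inverse transpose and Col2Im
    // below accumulate each image's columns back into dX.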
    math::GemmStridedBatched<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        shape.output_image_size,
        shape.N,
        shape.kernel_size,
        shape.M,
        1.0f,
        dY_transposed_buffer_data,
        shape.N * shape.M,
        filter_data,
        shape.M * shape.kernel_size,
        0.0f,
        column_transposed_buffer->template mutable_data<T>(),
        shape.N * shape.kernel_size,
        &context_);
    math::Transpose(
        shape.column_transposed_dims.size(),
        shape.column_transposed_dims.data(),
        shape.column_axes.data(),
        column_transposed_buffer->template data<T>(),
        column_buffer->template mutable_data<T>(),
        &context_);
    const T* const_column_buffer_data = column_buffer->template data<T>();
    for (int image_id = 0; image_id < shape.N; ++image_id) {
      math::Col2Im<T, Context, StorageOrder::NHWC>(
          shape.C,
          shape.X_dims[0],
          shape.X_dims[1],
          kernel_h(),
          kernel_w(),
          dilation_h(),
          dilation_w(),
          pad_t(),
          pad_l(),
          pad_b(),
          pad_r(),
          stride_h(),
          stride_w(),
          const_column_buffer_data,
          dX_data,
          &context_);
      dX_data += input_stride;
      const_column_buffer_data += column_stride;
    }
  }
}

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_