#ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
#define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_

#include <c10/util/Optional.h>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/conversions.h"
#include "caffe2/utils/math.h"

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
#include <chrono>
#endif

namespace caffe2 {

// This is Caffe's InnerProductOp, with a name that fits its purpose better.
template <
    class Context,
    class Engine = DefaultEngine,
    bool TransposeWeight = true>
class FullyConnectedOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit FullyConnectedOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        axis_(this->template GetSingleArgument<int32_t>("axis", 1)),
        axis_w_(this->template GetSingleArgument<int32_t>("axis_w", 1)),
        float16_compute_(
            this->template GetSingleArgument<bool>("float16_compute", false)) {}
  ~FullyConnectedOp() {}

  template <
      typename T_X,
      typename T_W,
      typename T_B,
      typename T_Y,
      typename MATH>
  bool DoRunWithType() {
#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    std::chrono::time_point<std::chrono::system_clock> t_very_begin, t_begin,
        t_end;
    /* if (VLOG_IS_ON(3)) */
    {
      t_begin = std::chrono::system_clock::now();
      t_very_begin = t_begin;
    }
#endif

    const auto& X = Input(0);
    const auto& W = Input(1);
    const auto& b = Input(2);

    CAFFE_ENFORCE(b.dim() == 1, b.dim());
    // batch size
    const auto canonical_axis = X.canonical_axis_index(axis_);
    const auto M = X.size_to_dim(canonical_axis);
    const auto K = X.size_from_dim(canonical_axis);
    const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
    const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
                                  : W.size_from_dim(canonical_axis_w);

    auto dimErrorString = [&]() {
      return c10::str(
          "Dimension mismatch: ",
          "X: ",
          X.sizes(),
          ", W: ",
          W.sizes(),
          ", b: ",
          b.sizes(),
          ", axis: ",
          axis_,
          ", M: ",
          M,
          ", N: ",
          N,
          ", K: ",
          K);
    };

    // Error checking
    CAFFE_ENFORCE(M == X.numel() / K, dimErrorString());
    CAFFE_ENFORCE(K == W.numel() / N, dimErrorString());
    CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
    CAFFE_ENFORCE(N == b.numel(), dimErrorString());

    Y_shape_cache_ = X.sizes().vec();
    // This is an invariant of canonical_axis, so we can DCHECK.
    DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
    Y_shape_cache_.resize(canonical_axis + 1);
    Y_shape_cache_[canonical_axis] = N;
    auto* Y = Output(0, Y_shape_cache_, at::dtype<T_Y>());
    CAFFE_ENFORCE(M * N == Y->numel(), dimErrorString());

    if (X.numel() == 0) {
      // skip the rest of the computation if X is empty
      Y->template mutable_data<T_Y>();
      return true;
    }

    // default to FLOAT as math.h does.
    TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
    if (fp16_type<MATH>()) {
      math_type = TensorProto_DataType_FLOAT16;
    }

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    /* if (VLOG_IS_ON(3)) */
    {
      t_end = std::chrono::system_clock::now();
      double dt = std::chrono::duration<double>(t_end - t_begin).count();
      LOG(INFO) << "@PERF this=" << this << " before_gemm: " << dt * 1e3
                << " ms";
      t_begin = std::chrono::system_clock::now();
    }
#endif
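    // The forward pass below is done with two GEMM calls (shapes noted here
    // for readability; example dims are illustrative only):
    //   (1) Y = X * W^T when TransposeWeight == true (W stored as N x K),
    //       or Y = X * W when TransposeWeight == false (W stored as K x N);
    //       X is M x K and Y is M x N.
    //   (2) Y += ones(M x 1) * b^T, which broadcasts the length-N bias b
    //       across all M rows using the cached bias_multiplier_ vector.
    // E.g. X: 64 x 1024, W: 4096 x 1024 (TransposeWeight), b: 4096
    //   -> Y: 64 x 4096.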
    // W * x
    math::Gemm<T_X, Context, Engine>(
        CblasNoTrans,
        TransposeWeight ? CblasTrans : CblasNoTrans,
        M,
        N,
        K,
        1,
        X.template data<T_X>(),
        W.template data<T_W>(),
        0,
        Y->template mutable_data<T_Y>(),
        &context_,
        math_type);

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    /* if (VLOG_IS_ON(3)) */
    {
      t_end = std::chrono::system_clock::now();
      double dt = std::chrono::duration<double>(t_end - t_begin).count();
      LOG(INFO) << "@PERF this=" << this << " gemm: " << dt * 1e3 << " ms";
      t_begin = std::chrono::system_clock::now();
    }
#endif
    // Add bias term
    if (!bias_multiplier_.has_value()) {
      bias_multiplier_ =
          caffe2::empty({M}, at::dtype<T_B>().device(Context::GetDeviceType()));
      math::Set<T_B, Context>(
          M,
          convert::To<float, T_B>(1),
          bias_multiplier_->template mutable_data<T_B>(),
          &context_);
    } else if (bias_multiplier_->numel() != M) {
      bias_multiplier_->Resize(M);
      math::Set<T_B, Context>(
          M,
          convert::To<float, T_B>(1),
          bias_multiplier_->template mutable_data<T_B>(),
          &context_);
    }

    math::Gemm<T_B, Context, Engine>(
        CblasNoTrans,
        CblasNoTrans,
        M,
        N,
        1,
        1,
        bias_multiplier_->template data<T_B>(),
        b.template data<T_B>(),
        1,
        Y->template mutable_data<T_Y>(),
        &context_,
        math_type);
#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    /* if (VLOG_IS_ON(3)) */
    {
      t_end = std::chrono::system_clock::now();
      double dt = std::chrono::duration<double>(t_end - t_begin).count();
      LOG(INFO) << "@PERF this=" << this << " add_bias : " << dt * 1e3
                << " ms";
      t_begin = std::chrono::system_clock::now();
    }
#endif

    return true;
  }

  bool RunOnDevice() override {
    return DoRunWithType<
        float, // X
        float, // W
        float, // B
        float, // Y
        float>(); // Math
  }

 protected:
  size_t axis_{1};
  size_t axis_w_{1};
  // A local vector to cache the output shape so we don't need to recreate
  // a vector object every time we run Run().
  vector<int64_t> Y_shape_cache_;
  c10::optional<Tensor> bias_multiplier_;

  bool float16_compute_;
};

template <
    class Context,
    class Engine = DefaultEngine,
    bool TransposeWeight = true>
class FullyConnectedGradientOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit FullyConnectedGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        axis_(this->template GetSingleArgument<int32_t>("axis", 1)),
        axis_w_(this->template GetSingleArgument<int32_t>("axis_w", 1)),
        float16_compute_(
            this->template GetSingleArgument<bool>("float16_compute", false)) {}
  ~FullyConnectedGradientOp() {}

  template <
      typename T_X,
      typename T_W,
      typename T_DY,
      typename T_B,
      typename T_DX,
      typename T_DW,
      typename T_DB,
      typename MATH>
  bool DoRunWithType() {
    const auto& X = Input(0);
    const auto& W = Input(1);
    const auto& dY = Input(2);
    // batch size
    const auto canonical_axis = X.canonical_axis_index(axis_);
    const int M = X.size_to_dim(canonical_axis);
    const int K = X.size_from_dim(canonical_axis);
    const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
    const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
                                  : W.size_from_dim(canonical_axis_w);

    auto dimErrorString = [&]() {
      return c10::str(
          "Dimension mismatch: ",
          "X: ",
          X.sizes(),
          ", W: ",
          W.sizes(),
          ", dY: ",
          dY.sizes(),
          ", axis: ",
          axis_,
          ", M: ",
          M,
          ", N: ",
          N,
          ", K: ",
          K);
    };

    CAFFE_ENFORCE(M * K == X.numel(), dimErrorString());
    CAFFE_ENFORCE(K * N == W.numel(), dimErrorString());

    auto* dW = Output(0, W.sizes(), at::dtype<T_DW>());
    auto* db = Output(1, {N}, at::dtype<T_DB>());

    if (X.numel() == 0) {
      // generate a zero blob for db and dW when X is empty
      math::Set<T_DB, Context>(
          db->numel(),
          convert::To<float, T_DB>(0),
          db->template mutable_data<T_DB>(),
          &context_);
      math::Set<T_DW, Context>(
          dW->numel(),
          convert::To<float, T_DW>(0),
          dW->template mutable_data<T_DW>(),
          &context_);

      if (OutputSize() == 3) {
        Output(2, X.sizes(), at::dtype<T_DX>());
      }

      return true;
    }
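    // The gradients below follow from Y = X * W^T + ones(M x 1) * b^T
    // (shapes shown for TransposeWeight == true; transposes flip otherwise):
    //   dW = dY^T * X        (N x K), one GEMM;
    //   db = dY^T * ones(M)  (N),     one GEMV against bias_multiplier_;
    //   dX = dY * W          (M x K), one GEMM, computed only when the
    //                                 optional third output is requested.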
    // default to FLOAT as math.h does.
    TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
    if (fp16_type<MATH>()) {
      math_type = TensorProto_DataType_FLOAT16;
    }

    // Compute dW
    math::Gemm<T_DY, Context, Engine>(
        CblasTrans,
        CblasNoTrans,
        TransposeWeight ? N : K,
        TransposeWeight ? K : N,
        M,
        1,
        TransposeWeight ? dY.template data<T_DY>() : X.template data<T_X>(),
        TransposeWeight ? X.template data<T_X>() : dY.template data<T_DY>(),
        0,
        dW->template mutable_data<T_DW>(),
        &context_,
        math_type);

    if (!bias_multiplier_.has_value()) {
      bias_multiplier_ =
          caffe2::empty({M}, at::dtype<T_B>().device(Context::GetDeviceType()));
      math::Set<T_B, Context>(
          M,
          convert::To<float, T_B>(1),
          bias_multiplier_->template mutable_data<T_B>(),
          &context_);
    } else if (bias_multiplier_->numel() != M) {
      bias_multiplier_->Resize(M);
      math::Set<T_B, Context>(
          M,
          convert::To<float, T_B>(1),
          bias_multiplier_->template mutable_data<T_B>(),
          &context_);
    }

    // Compute dB
    math::Gemv<T_DY, Context>(
        CblasTrans,
        M,
        N,
        1,
        dY.template data<T_DY>(),
        bias_multiplier_->template data<T_B>(),
        0,
        db->template mutable_data<T_DB>(),
        &context_);

    // Compute dX
    if (OutputSize() == 3) {
      auto* dX = Output(2, X.sizes(), at::dtype<T_DX>());
      math::Gemm<T_DX, Context, Engine>(
          CblasNoTrans,
          TransposeWeight ? CblasNoTrans : CblasTrans,
          M,
          K,
          N,
          1,
          dY.template data<T_DY>(),
          W.template data<T_W>(),
          0,
          dX->template mutable_data<T_DX>(),
          &context_,
          math_type);
    }
    return true;
  }

  bool RunOnDevice() override {
    return DoRunWithType<
        float, // X
        float, // W
        float, // dY
        float, // B
        float, // dX
        float, // dW
        float>(); // dB and Math
  }

 protected:
  size_t axis_{1};
  size_t axis_w_{1};
  c10::optional<Tensor> bias_multiplier_;
  bool float16_compute_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
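// Illustrative only (not part of this header): the CPU instantiation of
// FullyConnectedOp is registered elsewhere (fully_connected_op.cc) under the
// operator name "FC", so a NetDef typically invokes it roughly as:
//
//   op {
//     type: "FC"
//     input: "X"   # M x K data
//     input: "W"   # N x K weights (TransposeWeight == true)
//     input: "b"   # length-N bias
//     output: "Y"  # M x N
//   }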