#ifndef CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_
#define CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_

#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <class Context>
class BatchMatMulOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit BatchMatMulOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(bool, "trans_a", trans_a_, false),
        OP_SINGLE_ARG(bool, "trans_b", trans_b_, false),
        OP_SINGLE_ARG(bool, "broadcast", broadcast_, false) {}

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    const auto& A = Input(0);
    const auto& B = Input(1);
    const int A_ndim = A.dim();
    const int B_ndim = B.dim();
    const std::vector<std::int64_t> A_dims = A.sizes().vec();
    const std::vector<std::int64_t> B_dims = B.sizes().vec();
    const T* A_data = A.template data<T>();
    const T* B_data = B.template data<T>();

    // Vector-vector: reduce to a dot product producing a single element.
    if (A_ndim == 1 && B_ndim == 1) {
      CAFFE_ENFORCE_EQ(A.numel(), B.numel());
      auto* Y = Output(0, {1}, at::dtype<T>());
      T* Y_data = Y->template mutable_data<T>();
      math::Dot<T, Context>(A.numel(), A_data, B_data, Y_data, &context_);
      return true;
    }

    // Vector-matrix: A is a vector of size N, contracted against the
    // second-to-last (or last, if trans_b_) dim of B.
    if (A_ndim == 1) {
      const int N = A.numel();
      if (trans_b_) {
        CAFFE_ENFORCE_EQ(B_dims[B_ndim - 1], N);
      } else {
        CAFFE_ENFORCE_EQ(B_dims[B_ndim - 2], N);
      }
      std::vector<std::int64_t> Y_dims(B_ndim - 1);
      if (trans_b_) {
        std::copy_n(B_dims.cbegin(), B_ndim - 1, Y_dims.begin());
      } else {
        std::copy_n(B_dims.cbegin(), B_ndim - 2, Y_dims.begin());
        Y_dims.back() = B_dims.back();
      }
      auto* Y = Output(0, Y_dims, at::dtype<T>());
      T* Y_data = Y->template mutable_data<T>();
      if (trans_b_) {
        // All batches of B share the contraction dim, so one GEMV over the
        // flattened rows of B suffices.
        const int M = B.numel() / N;
        math::Gemv<T, Context>(
            CblasNoTrans, M, N, 1.0f, B_data, A_data, 0.0f, Y_data, &context_);
      } else {
        const int M = B_dims[B_ndim - 1];
        const int batch_size = B.numel() / (M * N);
        if (batch_size == 1) {
          math::Gemv<T, Context>(
              CblasTrans, N, M, 1.0f, B_data, A_data, 0.0f, Y_data, &context_);
        } else {
          // Treat A as an N x 1 matrix broadcast across batches (stride 0).
          math::GemmStridedBatched<T, Context>(
              CblasTrans,
              CblasNoTrans,
              batch_size,
              M,
              1,
              N,
              1.0f,
              B_data,
              M * N,
              A_data,
              0,
              0.0f,
              Y_data,
              M,
              &context_);
        }
      }
      return true;
    }

    // Matrix-vector: B is a vector of size N, contracted against the last
    // (or second-to-last, if trans_a_) dim of A.
    if (B_ndim == 1) {
      const int N = B.numel();
      if (trans_a_) {
        CAFFE_ENFORCE_EQ(A_dims[A_ndim - 2], N);
      } else {
        CAFFE_ENFORCE_EQ(A_dims[A_ndim - 1], N);
      }
      const std::vector<std::int64_t> Y_dims(
          A_dims.cbegin(), A_dims.cbegin() + A_ndim - 1);
      auto* Y = Output(0, Y_dims, at::dtype<T>());
      T* Y_data = Y->template mutable_data<T>();
      if (trans_a_) {
        const int M = A_dims[A_ndim - 1];
        const int batch_size = A.numel() / (M * N);
        if (batch_size == 1) {
          math::Gemv<T, Context>(
              CblasTrans, N, M, 1.0f, A_data, B_data, 0.0f, Y_data, &context_);
        } else {
          // Treat B as an N x 1 matrix broadcast across batches (stride 0).
          math::GemmStridedBatched<T, Context>(
              CblasTrans,
              CblasNoTrans,
              batch_size,
              M,
              1,
              N,
              1.0f,
              A_data,
              M * N,
              B_data,
              0,
              0.0f,
              Y_data,
              M,
              &context_);
        }
      } else {
        // The batches of A are contiguous rows, so one GEMV covers them all.
        const int M = A.numel() / N;
        math::Gemv<T, Context>(
            CblasNoTrans, M, N, 1.0f, A_data, B_data, 0.0f, Y_data, &context_);
      }
      return true;
    }
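    // General case: both A and B are (batches of) matrices. The trailing two
    // dims are the GEMM dims; any leading dims are batch dims that are
    // broadcast NumPy-style against each other when the "broadcast" argument
    // is set. Illustrative shapes, assuming trans_a_ = trans_b_ = false:
    //   A: (B, M, K),    B: (B, K, N)  ->  Y: (B, M, N)
    //   A: (3, 1, M, K), B: (4, K, N)  ->  Y: (3, 4, M, N)  (broadcast = 1)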
    const int M = trans_a_ ? A_dims[A_ndim - 1] : A_dims[A_ndim - 2];
    const int K = trans_a_ ? A_dims[A_ndim - 2] : A_dims[A_ndim - 1];
    if (trans_b_) {
      CAFFE_ENFORCE_EQ(B_dims[B_ndim - 1], K);
    } else {
      CAFFE_ENFORCE_EQ(B_dims[B_ndim - 2], K);
    }
    const int N = trans_b_ ? B_dims[B_ndim - 2] : B_dims[B_ndim - 1];
    const int ndim = std::max(A_ndim, B_ndim);
    std::vector<std::int64_t> A_broadcast_dims(ndim);
    std::vector<std::int64_t> B_broadcast_dims(ndim);
    std::vector<std::int64_t> Y_broadcast_dims(ndim);
    math::utils::ComputeBroadcastBinaryOpDims(
        A_ndim - 2,
        A_dims.data(),
        B_ndim - 2,
        B_dims.data(),
        A_broadcast_dims.data(),
        B_broadcast_dims.data(),
        Y_broadcast_dims.data());
    Y_broadcast_dims[ndim - 2] = M;
    Y_broadcast_dims[ndim - 1] = N;
    auto* Y = Output(0, Y_broadcast_dims, at::dtype<T>());
    T* Y_data = Y->template mutable_data<T>();
    const int batch_dim = ndim - 2;
    const bool is_broadcast_dims = !std::equal(
        A_broadcast_dims.cbegin(),
        A_broadcast_dims.cbegin() + batch_dim,
        B_broadcast_dims.cbegin());
    if (is_broadcast_dims) {
      // Mismatched batch dims are only legal when broadcasting was requested.
      CAFFE_ENFORCE(broadcast_);
    }
    const std::int64_t A_batch_size = std::accumulate(
        A_broadcast_dims.cbegin(),
        A_broadcast_dims.cbegin() + batch_dim,
        1LL,
        std::multiplies<std::int64_t>());
    const std::int64_t B_batch_size = std::accumulate(
        B_broadcast_dims.cbegin(),
        B_broadcast_dims.cbegin() + batch_dim,
        1LL,
        std::multiplies<std::int64_t>());
    const std::int64_t Y_batch_size = std::accumulate(
        Y_broadcast_dims.cbegin(),
        Y_broadcast_dims.cbegin() + batch_dim,
        1LL,
        std::multiplies<std::int64_t>());
    if (Y_batch_size == 0) {
      return true;
    }
    if (A_batch_size == 1 && B_batch_size == 1) {
      // No batching at all: a single GEMM.
      math::Gemm<T, Context>(
          trans_a_ ? CblasTrans : CblasNoTrans,
          trans_b_ ? CblasTrans : CblasNoTrans,
          M,
          N,
          K,
          1.0f,
          A_data,
          B_data,
          0.0f,
          Y_data,
          &context_);
    } else if (A_batch_size == 1) {
      if (M == 1 && trans_b_) {
        // A is a single row/column of length K; with B transposed, all
        // batches of B stack into one (B_batch_size * N) x K GEMV.
        math::Gemv<T, Context>(
            CblasNoTrans,
            B_batch_size * N,
            K,
            1.0f,
            B_data,
            A_data,
            0.0f,
            Y_data,
            &context_);
      } else {
        // Broadcast the single A across batches with stride 0.
        math::GemmStridedBatched<T, Context>(
            trans_a_ ? CblasTrans : CblasNoTrans,
            trans_b_ ? CblasTrans : CblasNoTrans,
            Y_batch_size,
            M,
            N,
            K,
            1.0f,
            A_data,
            0,
            B_data,
            K * N,
            0.0f,
            Y_data,
            M * N,
            &context_);
      }
    } else if (B_batch_size == 1) {
      if (!trans_a_) {
        // The batches of A are contiguous rows, so fold them into a single
        // GEMM with A_batch_size * M rows.
        math::Gemm<T, Context>(
            CblasNoTrans,
            trans_b_ ? CblasTrans : CblasNoTrans,
            A_batch_size * M,
            N,
            K,
            1.0f,
            A_data,
            B_data,
            0.0f,
            Y_data,
            &context_);
      } else {
        // Broadcast the single B across batches with stride 0.
        math::GemmStridedBatched<T, Context>(
            CblasTrans,
            trans_b_ ? CblasTrans : CblasNoTrans,
            Y_batch_size,
            M,
            N,
            K,
            1.0f,
            A_data,
            M * K,
            B_data,
            0,
            0.0f,
            Y_data,
            M * N,
            &context_);
      }
    } else if (!is_broadcast_dims) {
      // Equal batch dims: plain strided batched GEMM.
      math::GemmStridedBatched<T, Context>(
          trans_a_ ? CblasTrans : CblasNoTrans,
          trans_b_ ? CblasTrans : CblasNoTrans,
          Y_batch_size,
          M,
          N,
          K,
          1.0f,
          A_data,
          M * K,
          B_data,
          K * N,
          0.0f,
          Y_data,
          M * N,
          &context_);
    } else {
      // True broadcasting: map each output batch index back to the
      // (possibly broadcast) batch indices of A and B, materialize one
      // pointer per output batch, and run a pointer-array batched GEMM.
      std::vector<const T*> A_ptr(Y_batch_size);
      std::vector<const T*> B_ptr(Y_batch_size);
      std::vector<T*> Y_ptr(Y_batch_size);
      std::vector<std::int64_t> index(batch_dim);
      for (std::int64_t i = 0; i < Y_batch_size; ++i) {
        const std::int64_t A_index = math::utils::GetIndexFromDims(
            batch_dim, A_broadcast_dims.data(), index.data());
        const std::int64_t B_index = math::utils::GetIndexFromDims(
            batch_dim, B_broadcast_dims.data(), index.data());
        A_ptr[i] = A_data + A_index * M * K;
        B_ptr[i] = B_data + B_index * K * N;
        Y_ptr[i] = Y_data + i * M * N;
        math::utils::IncreaseIndexInDims(
            batch_dim, Y_broadcast_dims.data(), index.data());
      }
      math::GemmBatched<T, Context>(
          trans_a_ ? CblasTrans : CblasNoTrans,
          trans_b_ ? CblasTrans : CblasNoTrans,
          Y_batch_size,
          M,
          N,
          K,
          1.0f,
          A_ptr.data(),
          B_ptr.data(),
          0.0f,
          Y_ptr.data(),
          &context_);
    }
    return true;
  }

 private:
  const bool trans_a_;
  const bool trans_b_;
  const bool broadcast_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_
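// Usage sketch (illustrative): driving this operator through the Caffe2
// workspace API, using CreateOperatorDef from "caffe2/utils/proto_utils.h"
// and Workspace::RunOperatorOnce. Blob names and shapes are made up for the
// example.
//
//   caffe2::Workspace ws;
//   auto* A = BlobGetMutableTensor(ws.CreateBlob("A"), caffe2::CPU);
//   A->Resize(2, 3, 4); // batch of two 3x4 matrices
//   std::fill_n(A->mutable_data<float>(), A->numel(), 1.0f);
//   auto* B = BlobGetMutableTensor(ws.CreateBlob("B"), caffe2::CPU);
//   B->Resize(2, 4, 5); // batch of two 4x5 matrices
//   std::fill_n(B->mutable_data<float>(), B->numel(), 1.0f);
//   ws.RunOperatorOnce(
//       caffe2::CreateOperatorDef("BatchMatMul", "", {"A", "B"}, {"Y"}));
//   // "Y" now holds a (2, 3, 5) tensor; every entry equals 4.0f.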