#ifndef CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_ #define CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_ #include #include "caffe2/operators/elementwise_ops.h" #include "caffe2/utils/math.h" namespace caffe2 { namespace { template void ComputeMulGradient( const int ndim, const int* A_dims, const int* B_dims, const int* C_dims, const TGrad* dC, const TIn* A, const TIn* B, TGrad* dA, TGrad* dB, CPUContext* context) { const int A_size = std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies()); const int B_size = std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies()); const int C_size = std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies()); math::Set(A_size, TGrad(0), dA, context); math::Set(B_size, TGrad(0), dB, context); std::vector index(ndim, 0); for (int C_index = 0; C_index < C_size; ++C_index) { const int A_index = math::utils::GetIndexFromDims(ndim, A_dims, index.data()); const int B_index = math::utils::GetIndexFromDims(ndim, B_dims, index.data()); dA[A_index] += dC[C_index] * B[B_index]; dB[B_index] += dC[C_index] * A[A_index]; math::utils::IncreaseIndexInDims(ndim, C_dims, index.data()); } } } // namespace template struct MulFunctor { template bool Forward( const std::vector& A_dims, const std::vector& B_dims, const TIn* A, const TIn* B, TOut* C, Context* context) const { math::Mul( A_dims.size(), A_dims.data(), B_dims.size(), B_dims.data(), A, B, C, context); return true; } template bool Backward( const std::vector& A_dims, const std::vector& B_dims, const TGrad* dC_data, const TIn* A_data, const TIn* B_data, const TOut* C_data, TGrad* dA_data, TGrad* dB_data, Context* context) const; }; template <> template bool MulFunctor::Backward( const std::vector& A_dims, const std::vector& B_dims, const TGrad* dC, const TIn* A, const TIn* B, const TOut* /* C */, TGrad* dA, TGrad* dB, CPUContext* context) const { if (A_dims == B_dims) { const int size = std::accumulate( A_dims.cbegin(), A_dims.cend(), 1, std::multiplies()); math::Mul(size, dC, B, dA, context); math::Mul(size, dC, A, dB, context); return true; } const int ndim = std::max(A_dims.size(), B_dims.size()); std::vector A_broadcast_dims(ndim); std::vector B_broadcast_dims(ndim); std::vector C_broadcast_dims(ndim); math::utils::ComputeBroadcastBinaryOpDims( A_dims.size(), A_dims.data(), B_dims.size(), B_dims.data(), A_broadcast_dims.data(), B_broadcast_dims.data(), C_broadcast_dims.data()); ComputeMulGradient( ndim, A_broadcast_dims.data(), B_broadcast_dims.data(), C_broadcast_dims.data(), dC, A, B, dA, dB, context); return true; } } // namespace caffe2 #endif // CAFFE2_OPERATORS_ELEMENTWISE_MUL_OP_H_