#ifndef CAFFE2_OPERATORS_ELEMENTWISE_ADD_OP_H_ #define CAFFE2_OPERATORS_ELEMENTWISE_ADD_OP_H_ #include #include #include #include "caffe2/operators/elementwise_ops.h" #include "caffe2/operators/elementwise_ops_utils.h" #include "caffe2/utils/math.h" namespace caffe2 { template struct AddFunctor { template bool Forward( const std::vector& A_dims, const std::vector& B_dims, const TIn* A, const TIn* B, TOut* C, Context* context) const { math::Add( A_dims.size(), A_dims.data(), B_dims.size(), B_dims.data(), A, B, C, context); return true; } template bool Backward( const std::vector& A_dims, const std::vector& B_dims, const TGrad* dC, const TIn* /* A */, const TIn* /* B */, const TOut* /* C */, TGrad* dA, TGrad* dB, Context* context) const { const std::vector C_dims = elementwise_ops_utils::ComputeBinaryBroadcastForwardDims( A_dims, B_dims); std::vector A_back_dims; std::vector B_back_dims; elementwise_ops_utils::ComputeBinaryBroadcastBackwardDims( A_dims, B_dims, &A_back_dims, &B_back_dims); math::ReduceSum( C_dims.size(), C_dims.data(), A_back_dims.data(), TGrad(1), dC, dA, context); math::ReduceSum( C_dims.size(), C_dims.data(), B_back_dims.data(), TGrad(1), dC, dB, context); return true; } }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_ELEMENTWISE_ADD_OP_H_