#ifndef CAFFE2_OPERATORS_ELEMENTWISE_OPS_H_ #define CAFFE2_OPERATORS_ELEMENTWISE_OPS_H_ #include #include #include #include #include "caffe2/core/common_omp.h" #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/core/tensor.h" #include "caffe2/operators/elementwise_ops_utils.h" #include "caffe2/utils/eigen_utils.h" #include "caffe2/utils/math.h" namespace caffe2 { using NumericTypes = TensorTypes; using IntTypes = TensorTypes; using BoolTypes = TensorTypes; using IntBoolTypes = TensorTypes; // discrete types struct SameTypeAsInput { template using type = T; }; template struct FixedType { template using type = R; }; template < typename InputTypes, class Context, class Functor, class OutputTypeMap = SameTypeAsInput> class UnaryElementwiseWithArgsOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit UnaryElementwiseWithArgsOp(Args&&... args) : Operator(std::forward(args)...), functor_(*this) {} bool RunOnDevice() override { return DispatchHelper::call(this, Input(0)); } template bool DoRunWithType() { const auto& X = Input(0); auto* Y = Output( 0, X.sizes(), at::dtype>()); return functor_( X.numel(), X.template data(), Y->template mutable_data>(), &context_); } private: Functor functor_; }; // UnaryFunctorWithDefaultCtor is a functor that can be used as the functor of // an UnaryElementwiseWithArgsOp. It simply forwards the operator() call into // another functor that doesn't accept arguments in its constructor. template struct UnaryFunctorWithDefaultCtor { explicit UnaryFunctorWithDefaultCtor(OperatorBase& /* op */) {} template bool operator()(const int size, const TIn* X, TOut* Y, Context* context) const { return functor(size, X, Y, context); } Functor functor{}; }; // UnaryElementwiseOp is a wrapper around UnaryElementwiseWithArgsOp, with the // difference that it takes a functor with default constructor, e.g. that does // not need to take into consideration any arguments during operator creation. template < typename InputTypes, class Context, class Functor, class OutputTypeMap = SameTypeAsInput> using UnaryElementwiseOp = UnaryElementwiseWithArgsOp< InputTypes, Context, UnaryFunctorWithDefaultCtor, OutputTypeMap>; template < typename InputTypes, class Context, class Functor, class OutputTypeMap = SameTypeAsInput> class BinaryElementwiseWithArgsOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit BinaryElementwiseWithArgsOp(Args&&... args) : Operator(std::forward(args)...), OP_SINGLE_ARG(bool, "broadcast", legacy_broadcast_, false), OP_SINGLE_ARG(int, "axis", axis_, -1), OP_SINGLE_ARG(string, "axis_str", axis_str_, string("")), OP_SINGLE_ARG(string, "order", order_, "NCHW"), functor_(*this) { if (legacy_broadcast_) { if (axis_ != -1) { // Get axis from an explicit axis argument. CAFFE_ENFORCE_EQ( axis_str_.size(), 0, "Args axis and axis_str cannot be used simultaneously."); } else if (axis_str_.size()) { // Get the axis index semantically. CAFFE_ENFORCE_EQ( axis_str_.size(), 1, "Unsupported axis string", axis_str_); const size_t semantic_axis_ = order_.find(axis_str_); CAFFE_ENFORCE_NE( semantic_axis_, string::npos, "Unrecognizable axis string ", axis_str_, " from order string ", order_); axis_ = semantic_axis_; } else { CAFFE_ENFORCE( axis_ == -1 && axis_str_.empty(), "Do not specify axis or axis_str if broadcast is not enabled."); } } } bool RunOnDevice() override { return DispatchHelper::call(this, Input(0)); } template bool DoRunWithType() { const auto& A = Input(0); const auto& B = Input(1); const T* A_data = A.template data(); const T* B_data = B.template data(); std::vector A_dims; std::vector B_dims; std::vector C_dims; if (legacy_broadcast_) { CAFFE_ENFORCE( !IsInputOutputAlias(1, 0), "In-place is allowed only with the first tensor when " "legacy-broadcasting"); C_dims = A.sizes().vec(); if (B.numel() == 1) { A_dims = {static_cast(A.numel())}; B_dims = {1}; } else { size_t pre, n, post; std::tie(pre, n, post) = elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_); A_dims = { static_cast(pre), static_cast(n), static_cast(post)}; B_dims = {static_cast(n), 1}; } } else { std::copy( A.sizes().cbegin(), A.sizes().cend(), std::back_inserter(A_dims)); std::copy( B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims)); // TODO: change the types to vector auto C_dims_int = elementwise_ops_utils::ComputeBinaryBroadcastForwardDims( A_dims, B_dims); std::copy( C_dims_int.cbegin(), C_dims_int.cend(), std::back_inserter(C_dims)); if (IsInputOutputAlias(0, 0)) { CAFFE_ENFORCE_EQ(C_dims_int, A_dims); } else if (IsInputOutputAlias(1, 0)) { CAFFE_ENFORCE_EQ(C_dims_int, B_dims); } } auto* C = Output( 0, C_dims, at::dtype>()); auto* C_data = C->template mutable_data>(); return functor_.Forward(A_dims, B_dims, A_data, B_data, C_data, &context_); } private: const bool legacy_broadcast_; int axis_; const std::string axis_str_; const std::string order_; Functor functor_; }; template < typename InputTypes, class Context, class Functor, class OutputTypeMap = SameTypeAsInput, class GradientTypeMap = SameTypeAsInput> class BinaryElementwiseWithArgsGradientOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit BinaryElementwiseWithArgsGradientOp(Args&&... args) : Operator(std::forward(args)...), OP_SINGLE_ARG(bool, "broadcast", legacy_broadcast_, false), OP_SINGLE_ARG(int, "axis", axis_, -1), OP_SINGLE_ARG(string, "axis_str", axis_str_, ""), OP_SINGLE_ARG(string, "order", order_, "NCHW"), functor_(*this) { if (legacy_broadcast_) { if (axis_ != -1) { // Get axis from an explicit axis argument. CAFFE_ENFORCE_EQ( axis_str_.size(), 0, "Args axis and axis_str cannot be used simultaneously."); } else if (axis_str_.size()) { // Get the axis index semantically. CAFFE_ENFORCE_EQ( axis_str_.size(), 1, "Unsupported axis string", axis_str_); const size_t semantic_axis_ = order_.find(axis_str_); CAFFE_ENFORCE_NE( semantic_axis_, string::npos, "Unrecognizable axis string ", axis_str_, " from order string ", order_); axis_ = semantic_axis_; } else { CAFFE_ENFORCE( axis_ == -1 && axis_str_.empty(), "Do not specify axis or axis_str if broadcast is not enabled."); } } } bool RunOnDevice() override { return DispatchHelper::call(this, Input(1)); } template bool DoRunWithType() { const auto& dC = Input(0); const auto& A = Input(1); const auto& B = Input(2); vector A_dims; vector B_dims; if (legacy_broadcast_) { if (B.numel() == 1) { A_dims = {static_cast(A.numel())}; B_dims = {1}; } else { size_t pre, n, post; std::tie(pre, n, post) = elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_); A_dims = { static_cast(pre), static_cast(n), static_cast(post)}; B_dims = {static_cast(n), 1}; } } else { std::copy( A.sizes().cbegin(), A.sizes().cend(), std::back_inserter(A_dims)); std::copy( B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims)); } const typename OutputTypeMap::template type* C_data = nullptr; if (InputSize() == 4) { const auto& C = Input(3); C_data = C.template data>(); } const auto* dC_data = dC.template data>(); const T* A_data = A.template data(); const T* B_data = B.template data(); auto* dA = Output( 0, A.sizes(), at::dtype>()); auto* dB = Output( 1, B.sizes(), at::dtype>()); auto* dA_data = dA->template mutable_data>(); auto* dB_data = dB->template mutable_data>(); return functor_.Backward( A_dims, B_dims, dC_data, A_data, B_data, C_data, dA_data, dB_data, &context_); } private: const bool legacy_broadcast_; int axis_; const std::string axis_str_; const std::string order_; Functor functor_; }; template struct BinaryFunctorWithDefaultCtor { explicit BinaryFunctorWithDefaultCtor(OperatorBase& /* op */) {} template bool Forward( const std::vector& A_dims, const std::vector& B_dims, const TIn* A_data, const TIn* B_data, TOut* C_data, Context* context) const { return functor.Forward(A_dims, B_dims, A_data, B_data, C_data, context); } template bool Backward( const std::vector& A_dims, const std::vector& B_dims, const TGrad* dC_data, const TIn* A_data, const TIn* B_data, const TOut* C_data, TGrad* dA_data, TGrad* dB_data, Context* context) const { return functor.Backward( A_dims, B_dims, dC_data, A_data, B_data, C_data, dA_data, dB_data, context); } Functor functor{}; }; // BinaryElementwiseOp is a wrapper around BinaryElementwiseWithArgsOp, with the // difference that it takes a functor with default constructor, e.g. that does // not need to take into consideration any arguments during operator creation. template < typename InputTypes, class Context, class Functor, class TypeMap = SameTypeAsInput> using BinaryElementwiseOp = BinaryElementwiseWithArgsOp< InputTypes, Context, BinaryFunctorWithDefaultCtor, TypeMap>; // BinaryElementwiseGradientOp is a wrapper around // BinaryElementwiseGradientWithArgsOp, with the difference that it takes a // functor with default constructor, e.g. that does not need to take into // consideration any arguments during operator creation. template < typename InputTypes, class Context, class Functor, class OutputTypeMap = SameTypeAsInput, class GradientTypeMap = SameTypeAsInput> using BinaryElementwiseGradientOp = BinaryElementwiseWithArgsGradientOp< InputTypes, Context, BinaryFunctorWithDefaultCtor, OutputTypeMap, GradientTypeMap>; // Forward-only Unary Functors. template struct NotFunctor { bool operator()(const int N, const bool* X, bool* Y, Context* context) const { math::Not(N, X, Y, context); return true; } }; template struct SignFunctor { template bool operator()(const int N, const T* X, T* Y, Context* context) const { math::Sign(N, X, Y, context); return true; } }; // Forward-only Binary Functors. #define C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(FunctorName) \ template \ struct FunctorName##Functor { \ template \ bool Forward( \ const std::vector& A_dims, \ const std::vector& B_dims, \ const TIn* A, \ const TIn* B, \ TOut* C, \ Context* context) const { \ math::FunctorName( \ A_dims.size(), \ A_dims.data(), \ B_dims.size(), \ B_dims.data(), \ A, \ B, \ C, \ context); \ return true; \ } \ }; // Compare functors. C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(EQ); C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(NE); C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(LT); C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(LE); C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(GT); C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(GE); // Logical functors. C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(And); C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(Or); C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(Xor); // Bitwise functors. C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(BitwiseAnd); C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(BitwiseOr); C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(BitwiseXor); #undef C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR namespace SRLHelper { template void sum2one(const T* a, T* y, size_t n); template void RunWithBroadcastFront(const T* a, T* y, size_t pre, size_t n, CPUContext*); template void RunWithBroadcastBack(const T* a, T* y, size_t post, size_t n, CPUContext*); template void RunWithBroadcast2( const T* a, T* y, size_t pre, size_t n, size_t post, CPUContext*); } // namespace SRLHelper // Sum reduction operator that is used for computing the gradient in cases // where the forward op is in broadcast mode. template class SumReduceLikeOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit SumReduceLikeOp(Args&&... args) : Operator(std::forward(args)...), OP_SINGLE_ARG(int, "axis", axis_, -1), OP_SINGLE_ARG(string, "axis_str", axis_str_, ""), OP_SINGLE_ARG(string, "order", order_, "NCHW") { if (axis_ != -1) { // Get axis from an explicit axis argument. CAFFE_ENFORCE_EQ( axis_str_.size(), 0, "Args axis and axis_str cannot be used simultaneously."); } else if (axis_str_.size()) { // Get the axis index semantically. CAFFE_ENFORCE_EQ( axis_str_.size(), 1, "Unsupported axis string", axis_str_); size_t semantic_axis = order_.find(axis_str_); CAFFE_ENFORCE_NE( semantic_axis, string::npos, "Unrecognizable axis string ", axis_str_, " from order string ", order_); axis_ = semantic_axis; } } bool RunOnDevice() override { return DispatchHelper>::call(this, Input(0)); } template bool DoRunWithType(); private: int axis_; string axis_str_; string order_; Tensor ones_{Context::GetDeviceType()}; Tensor sum_buffer_{Context::GetDeviceType()}; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_ELEMENTWISE_OPS_H_