#ifndef CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
#define CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <typename T, class Context>
class LabelCrossEntropyOp final : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyOp);
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  bool RunOnDevice() override;

 protected:
  static constexpr T kLOG_THRESHOLD() {
    return static_cast<T>(1e-20);
  }
  // Input: X, label
  // Output: Y
};

template <typename T, class Context>
class LabelCrossEntropyGradientOp final : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyGradientOp);
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  bool RunOnDevice() override;

 protected:
  // Input: X, label, dY
  // Output: dX. There is no gradient with respect to the label.
  static constexpr T kLOG_THRESHOLD() {
    return static_cast<T>(1e-20);
  }
};

// Hacky: turns a vector of probabilities into a 2-column matrix with
// complementary probabilities for binary classification
template <typename T, class Context>
class MakeTwoClassOp final : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(MakeTwoClassOp);
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  bool RunOnDevice() override;

 protected:
  // Input: X
  // Output: Y = vstack(1-X, X)
};

template <typename T, class Context>
class MakeTwoClassGradientOp final : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(MakeTwoClassGradientOp);
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  bool RunOnDevice() override;

 protected:
  // Input: dY
  // Output: dX
};

template <typename T, class Context>
class SigmoidCrossEntropyWithLogitsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit SigmoidCrossEntropyWithLogitsOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        log_D_trick_(
            this->template GetSingleArgument<bool>("log_D_trick", false)),
        unjoined_lr_loss_(
            this->template GetSingleArgument<bool>("unjoined_lr_loss", false)) {
    CAFFE_ENFORCE(
        !(log_D_trick_ && unjoined_lr_loss_),
        "log_D_trick_ and unjoined_lr_loss_ cannot be set as True simultaneously");
  }

  bool RunOnDevice() override;

 protected:
  bool log_D_trick_;
  bool unjoined_lr_loss_;
};

template <typename T, class Context>
class SigmoidCrossEntropyWithLogitsGradientOp final
    : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit SigmoidCrossEntropyWithLogitsGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        log_D_trick_(
            this->template GetSingleArgument<bool>("log_D_trick", false)),
        unjoined_lr_loss_(
            this->template GetSingleArgument<bool>("unjoined_lr_loss", false)) {
  }

  bool RunOnDevice() override;

 protected:
  bool log_D_trick_;
  bool unjoined_lr_loss_;
};

template <typename T, class Context>
class WeightedSigmoidCrossEntropyWithLogitsOp final
    : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(WeightedSigmoidCrossEntropyWithLogitsOp);
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  bool RunOnDevice() override;
};

template <typename T, class Context>
class WeightedSigmoidCrossEntropyWithLogitsGradientOp final
    : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(WeightedSigmoidCrossEntropyWithLogitsGradientOp);
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  bool RunOnDevice() override;
};

template <typename T, class Context>
class CAFFE2_API CrossEntropyOp final : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(CrossEntropyOp);
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  bool RunOnDevice() override;

 protected:
  // Input: X, label
  // Output: Y
  static constexpr T kLOG_THRESHOLD() {
    return static_cast<T>(1e-20);
  }
};

template <typename T, class Context>
class CAFFE2_API CrossEntropyGradientOp final : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(CrossEntropyGradientOp);
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  bool RunOnDevice() override;

 protected:
  // Input: X, label, dY
  // Output: dX. There is no gradient with respect to the label.
  static constexpr T kLOG_THRESHOLD() {
    return static_cast<T>(1e-20);
  }
};

}  // namespace caffe2

#endif  // CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_