// Copyright 2004-present Facebook. All Rights Reserved. #pragma once #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" namespace caffe2 { template class LambdaRankNdcgOp final : public Operator { public: template explicit LambdaRankNdcgOp(Args&&... args) : Operator(std::forward(args)...), use_ndcg_as_loss_( this->template GetSingleArgument("use_ndcg_as_loss", false)), use_idcg_normalization_(this->template GetSingleArgument( "use_idcg_normalization", true)), use_exp_gain_( this->template GetSingleArgument("use_exp_gain", true)) {} USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override; private: INPUT_TAGS(PRED, REL, SESSION_LENS); OUTPUT_TAGS(LOSS, DPRED); void ResizeInvLogITensor(int); void ComputeDiscounts(int*, int); float LambdaRankNdcgSession( int start_index, int end_index, const Tensor& y, const Tensor& r, Tensor** dy); bool use_ndcg_as_loss_; bool use_idcg_normalization_; bool use_exp_gain_; Tensor gain_; Tensor discount_; Tensor rank_idx_; Tensor ideal_idx_; Tensor lambda_; Tensor inv_log_i_; }; template class LambdaRankNdcgGradientOp final : public Operator { public: USE_SIMPLE_CTOR_DTOR(LambdaRankNdcgGradientOp); USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override; private: INPUT_TAGS(Y, SESSION_LENS, DY_CACHE, DLOSS); OUTPUT_TAGS(DY); }; } // namespace caffe2