#pragma once #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/perfkernels/embedding_lookup.h" namespace caffe2 { // A templated class that implements SparseLengths[Sum,WeightedSum,Mean]. template < typename T, // output type class InputTypes, // supported input types, such as TensorTypes bool USE_WEIGHT = 0, // Whether it is SparseLengthsWeightedSum bool USE_MEAN = 0, // Whether this is SparseLengthsMean bool USE_POSITIONAL_WEIGHT = 0 // USE_WEIGHT = 1 and USE_POSITIONAL_WEIGHT = 1 // -> SparseLengthsPositionalWeightedSum > class CPUSparseLengthsReductionOp : public Operator { public: USE_OPERATOR_FUNCTIONS(CPUContext); template explicit CPUSparseLengthsReductionOp(Args&&... args) : Operator(std::forward(args)...) { static_assert( !(USE_WEIGHT & USE_MEAN), "Cannot both specify weight and mean."); } ~CPUSparseLengthsReductionOp() {} // Currently, we support float and at::Half inputs for input data type, and // int32_t and int64_t for the index type. bool RunOnDevice() override { return DispatchHelper::call(this, Input(DATA)); } template bool DoRunWithType() { return DispatchHelper, InputType>::call( this, Input(INDICES)); } template bool DoRunWithType2() { auto& dataInput = Input(DATA); auto& indicesInput = Input(INDICES); auto& lengthsInput = Input(LENGTHS); CAFFE_ENFORCE_EQ(1, indicesInput.dim(), "INDICES must be a vector"); CAFFE_ENFORCE_EQ(1, lengthsInput.dim(), "LENGTHS must be a vector"); const int64_t N = dataInput.size(0); const int D = dataInput.size_from_dim(1); const int64_t M = lengthsInput.size(0); const int64_t indices_size = indicesInput.numel(); auto shape = dataInput.sizes().vec(); shape[0] = M; auto* output = Output(0, shape, at::dtype()); T* out_data = output->template mutable_data(); const InputType* in_data = dataInput.template data(); const IndexType* indices = indicesInput.template data(); const int* lengths = lengthsInput.template data(); const T* in_weight = nullptr; if (USE_WEIGHT) { // static if auto& weightInput = Input(WEIGHT); CAFFE_ENFORCE_EQ(1, weightInput.dim(), "WEIGHT must be a vector"); if (!USE_POSITIONAL_WEIGHT) { CAFFE_ENFORCE_EQ( weightInput.numel(), indices_size, "Weight should have the same length as indices."); } in_weight = weightInput.template data(); } // delegate work to perfkernel that branches based on architecture EmbeddingLookup( D, M, indices_size, N, in_data, indices, lengths, in_weight, nullptr, // scale_bias field is only used in SparseLengths8BitsRowwiseOp USE_MEAN, out_data); return true; } enum { DATA = 0, // Data input. WEIGHT = 1, // Weight input used in SparseLengthsWeightedSum INDICES = 1 + USE_WEIGHT, // 1 in SparseLengths[Sum,Mean] and // 2 in SparseLengthsWeightedSum LENGTHS = 2 + USE_WEIGHT, // 2 in SparseLengths[Sum, Mean], // 3 in SparseLengthsWeightedSum }; }; } // namespace caffe2