#pragma once #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/utils/eigen_utils.h" #include "caffe2/utils/math.h" namespace caffe2 { namespace detail { template void VariableLengthSequencePadding( int N, int B, int M, T* X, const int32_t* seqLengths, const T padValue, Context* /*context*/) { for (int j = 0; j < B; j++) { for (int i = seqLengths[j]; i < N; i++) { EigenVectorArrayMap(X + B * M * i + M * j, M).setConstant(padValue); } } } } // namespace detail template class VariableLengthSequencePaddingOp : public Operator { public: template explicit VariableLengthSequencePaddingOp(Args&&... args) : Operator(std::forward(args)...) {} USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override { const auto N = Input(INPUT).size(0); const auto B = Input(INPUT).size(1); const auto M = Input(INPUT).size(2); auto X = Output(OUTPUT)->template mutable_data(); auto seqLengths = Input(SEQ_LENGTHS).template data(); detail::VariableLengthSequencePadding( N, B, M, X, seqLengths, 0, &context_); return true; } protected: INPUT_TAGS(INPUT, SEQ_LENGTHS); OUTPUT_TAGS(OUTPUT); }; } // namespace caffe2