#pragma once #include "rebatching_queue.h" namespace caffe2 { using RebatchingQueuePtr = std::unique_ptr; class CreateRebatchingQueueOp : public Operator { public: CreateRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws) {} bool RunOnDevice() override { *OperatorBase::Output(0) = RebatchingQueuePtr(new RebatchingQueue( OperatorBase::GetSingleArgument("capacity", 1), OperatorBase::GetSingleArgument("num_blobs", 1))); return true; } }; class EnqueueRebatchingQueueOp : public Operator { public: EnqueueRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), enqueueBatch_( OperatorBase::GetSingleArgument("enqueue_batch", false)) {} bool RunOnDevice() override { auto& queue = Inputs()[0]->template Get(); CHECK(queue); CAFFE_ENFORCE_EQ(InputSize(), queue->numBlobs() + 1); std::vector inputTensors; inputTensors.reserve(InputSize() - 1); for (int i = 1; i < InputSize(); ++i) { inputTensors.push_back(&Input(i)); } return enqueueBatch_ ? queue->enqueueMany(context_, inputTensors) : queue->enqueueOne(context_, inputTensors); } private: const bool enqueueBatch_; }; class DequeueRebatchingQueueOp : public Operator { public: DequeueRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), numElements_(OperatorBase::GetSingleArgument("num_elements", 1)) {} bool RunOnDevice() override { auto& queue = Inputs()[0]->template Get(); CHECK(queue); std::vector outputTensors; outputTensors.reserve(OutputSize()); for (int i = 0; i < OutputSize(); ++i) { outputTensors.push_back(Output(i)); } return queue->dequeue(context_, numElements_, outputTensors); } private: int numElements_; }; class CloseRebatchingQueueOp : public Operator { public: CloseRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws) {} bool RunOnDevice() override { CAFFE_ENFORCE_EQ(InputSize(), 1); auto& queue = Inputs()[0]->template Get(); CAFFE_ENFORCE(queue); queue->close(); return true; } }; } // caffe2