#pragma once #include #include "blobs_queue.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" namespace caffe2 { template class CreateBlobsQueueOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; CreateBlobsQueueOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), ws_(ws), name(operator_def.output().Get(0)) {} bool RunOnDevice() override { const auto capacity = GetSingleArgument("capacity", 1); const auto numBlobs = GetSingleArgument("num_blobs", 1); const auto enforceUniqueName = GetSingleArgument("enforce_unique_name", false); const auto fieldNames = OperatorBase::template GetRepeatedArgument("field_names"); CAFFE_ENFORCE_EQ(this->OutputSize(), 1); auto queuePtr = Operator::Outputs()[0] ->template GetMutable>(); CAFFE_ENFORCE(queuePtr); *queuePtr = std::make_shared( ws_, name, capacity, numBlobs, enforceUniqueName, fieldNames); return true; } private: Workspace* ws_{nullptr}; const std::string name; }; template class EnqueueBlobsOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; using Operator::Operator; bool RunOnDevice() override { CAFFE_ENFORCE(InputSize() > 1); auto queue = Operator::Inputs()[0] ->template Get>(); CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs()); return queue->blockingWrite(this->Outputs()); } private: }; template class DequeueBlobsOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; DequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws) { timeout_secs_ = OperatorBase::GetSingleArgument("timeout_secs", 0); } bool RunOnDevice() override { CAFFE_ENFORCE(InputSize() == 1); auto queue = OperatorBase::Inputs()[0]->template Get>(); CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs()); return queue->blockingRead(this->Outputs(), timeout_secs_); } private: float timeout_secs_; }; template class CloseBlobsQueueOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; using Operator::Operator; bool RunOnDevice() override { CAFFE_ENFORCE_EQ(InputSize(), 1); auto queue = OperatorBase::Inputs()[0]->template Get>(); CAFFE_ENFORCE(queue); queue->close(); return true; } private: }; template class SafeEnqueueBlobsOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; using Operator::Operator; bool RunOnDevice() override { auto queue = Operator::Inputs()[0] ->template Get>(); CAFFE_ENFORCE(queue); auto size = queue->getNumBlobs(); CAFFE_ENFORCE( OutputSize() == size + 1, "Expected " + c10::to_string(size + 1) + ", " + " got: " + c10::to_string(size)); bool status = queue->blockingWrite(this->Outputs()); Output(size)->Resize(); math::Set( 1, !status, Output(size)->template mutable_data(), &context_); return true; } }; template class SafeDequeueBlobsOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; using Operator::Operator; SafeDequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), numRecords_(OperatorBase::GetSingleArgument("num_records", 1)) { CAFFE_ENFORCE_GT(numRecords_, 0); } bool dequeueMany(std::shared_ptr& queue) { auto size = queue->getNumBlobs(); if (blobs_.size() != size) { blobs_.resize(size); blobPtrs_.resize(size); for (int col = 0; col < size; ++col) { blobPtrs_.at(col) = &blobs_.at(col); } } const int kTensorGrowthPct = 40; for (int i = 0; i < numRecords_; ++i) { if (!queue->blockingRead(blobPtrs_)) { // if we read at least one record, status is still true return i > 0; } for (int col = 0; col < size; ++col) { auto* out = this->Output(col); const auto& in = blobPtrs_.at(col)->template Get(); if (i == 0) { out->CopyFrom(in); } else { auto oldSize = out->numel(); CAFFE_ENFORCE( in.dim() > 0, "Empty tensor to dequeue at column ", col, " within ", size, " total columns"); out->Extend(in.sizes()[0], kTensorGrowthPct); auto* dst = (char*)out->raw_mutable_data() + oldSize * in.dtype().itemsize(); context_.template CopyItems( in.meta(), in.numel(), in.raw_data(), dst); } } } return true; } bool dequeueOne(std::shared_ptr& queue) { return queue->blockingRead(this->Outputs()); } bool RunOnDevice() override { CAFFE_ENFORCE(InputSize() == 1); auto queue = Operator::Inputs()[0] ->template Get>(); CAFFE_ENFORCE(queue); auto size = queue->getNumBlobs(); CAFFE_ENFORCE_EQ(OutputSize(), size + 1); bool status = numRecords_ > 1 ? dequeueMany(queue) : dequeueOne(queue); Output(size)->Resize(); math::Set( 1, !status, Output(size)->template mutable_data(), &context_); return true; } private: int numRecords_; std::vector blobs_; std::vector blobPtrs_; }; template class WeightedSampleDequeueBlobsOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; WeightedSampleDequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), table_idx_blob_( OperatorBase::GetSingleArgument("table_idx_blob", -1)) { CAFFE_ENFORCE_LT(table_idx_blob_, OutputSize() - 1); vector weights = OperatorBase::GetRepeatedArgument("weights"); if (weights.empty()) { weights.resize(InputSize(), 1.0f); } CAFFE_ENFORCE_EQ(InputSize(), weights.size()); float sum = accumulate(weights.begin(), weights.end(), 0.0f); CAFFE_ENFORCE(sum > 0.0f, "Sum of weights must be positive"); cumProbs_.resize(weights.size()); for (int i = 0; i < weights.size(); i++) { cumProbs_[i] = weights[i] / sum; CAFFE_ENFORCE_GE( cumProbs_[i], 0.0f, "Each probability must be non-negative"); } std::partial_sum(cumProbs_.begin(), cumProbs_.end(), cumProbs_.begin()); // Put last value to be 1.0001 to avoid numerical issues. cumProbs_.back() = 1.0001f; LOG(INFO) << "Dequeue weights: " << weights; LOG(INFO) << "cumProbs: " << cumProbs_; } bool RunOnDevice() override { float r; math::RandUniform(1, 0.0f, 1.0f, &r, &context_); auto lb = lower_bound(cumProbs_.begin(), cumProbs_.end(), r); CAFFE_ENFORCE(lb != cumProbs_.end(), "Cannot find ", r, " in cumProbs_."); const int32_t idx = lb - cumProbs_.begin(); auto queue = Operator::Inputs()[idx] ->template Get>(); CAFFE_ENFORCE(queue); auto size = queue->getNumBlobs(); CAFFE_ENFORCE_EQ(OutputSize(), size + 1); bool status = queue->blockingRead(this->Outputs()); if (table_idx_blob_ >= 0) { auto* table_idx_blob_out = Output(table_idx_blob_, {1}, at::dtype()); int32_t* data = table_idx_blob_out->template mutable_data(); data[0] = idx; } Output(size)->Resize(); math::Set( 1, !status, Output(size)->template mutable_data(), &context_); return true; } private: vector cumProbs_; int table_idx_blob_; }; } // namespace caffe2