#pragma once #include #include #include "caffe2/core/db.h" #include "caffe2/core/logging.h" #include "caffe2/core/stats.h" #include "caffe2/queue/blobs_queue.h" namespace caffe2 { namespace db { namespace { const std::string& GetStringFromBlob(Blob* blob) { if (blob->template IsType()) { return blob->template Get(); } else if (blob->template IsType()) { return *blob->template Get().template data(); } else { CAFFE_THROW("Unsupported Blob type"); } } } class BlobsQueueDBCursor : public Cursor { public: explicit BlobsQueueDBCursor( std::shared_ptr queue, int key_blob_index, int value_blob_index, float timeout_secs) : queue_(queue), key_blob_index_(key_blob_index), value_blob_index_(value_blob_index), timeout_secs_(timeout_secs), inited_(false), valid_(false) { LOG(INFO) << "BlobsQueueDBCursor constructed"; CAFFE_ENFORCE(queue_ != nullptr, "queue is null"); CAFFE_ENFORCE(value_blob_index_ >= 0, "value_blob_index < 0"); } virtual ~BlobsQueueDBCursor() {} void Seek(const string& /* unused */) override { CAFFE_THROW("Seek is not supported."); } bool SupportsSeek() override { return false; } void SeekToFirst() override { // not applicable } void Next() override { unique_ptr blob = make_unique(); vector blob_vector{blob.get()}; auto success = queue_->blockingRead(blob_vector, timeout_secs_); if (!success) { LOG(ERROR) << "Timed out reading from BlobsQueue or it is closed"; valid_ = false; return; } if (key_blob_index_ >= 0) { key_ = GetStringFromBlob(blob_vector[key_blob_index_]); } value_ = GetStringFromBlob(blob_vector[value_blob_index_]); valid_ = true; } string key() override { if (!inited_) { Next(); inited_ = true; } return key_; } string value() override { if (!inited_) { Next(); inited_ = true; } return value_; } bool Valid() override { return valid_; } private: std::shared_ptr queue_; int key_blob_index_; int value_blob_index_; float timeout_secs_; bool inited_; string key_; string value_; bool valid_; }; class BlobsQueueDB : public DB { public: BlobsQueueDB( const string& source, Mode mode, std::shared_ptr queue, int key_blob_index = -1, int value_blob_index = 0, float timeout_secs = 0.0) : DB(source, mode), queue_(queue), key_blob_index_(key_blob_index), value_blob_index_(value_blob_index), timeout_secs_(timeout_secs) { LOG(INFO) << "BlobsQueueDB constructed"; } virtual ~BlobsQueueDB() { Close(); } void Close() override {} unique_ptr NewCursor() override { return make_unique( queue_, key_blob_index_, value_blob_index_, timeout_secs_); } unique_ptr NewTransaction() override { CAFFE_THROW("Not implemented."); } private: std::shared_ptr queue_; int key_blob_index_; int value_blob_index_; float timeout_secs_; }; } // namespace db } // namespace caffe2