#ifndef CAFFE2_CORE_DB_H_ #define CAFFE2_CORE_DB_H_ #include #include "c10/util/Registry.h" #include "caffe2/core/blob_serialization.h" #include "caffe2/proto/caffe2_pb.h" namespace caffe2 { namespace db { /** * The mode of the database, whether we are doing a read, write, or creating * a new database. */ enum Mode { READ, WRITE, NEW }; /** * An abstract class for the cursor of the database while reading. */ class CAFFE2_API Cursor { public: Cursor() { } virtual ~Cursor() { } /** * Seek to a specific key (or if the key does not exist, seek to the * immediate next). This is optional for dbs, and in default, SupportsSeek() * returns false meaning that the db cursor does not support it. */ virtual void Seek(const string& key) = 0; virtual bool SupportsSeek() { return false; } /** * Seek to the first key in the database. */ virtual void SeekToFirst() = 0; /** * Go to the next location in the database. */ virtual void Next() = 0; /** * Returns the current key. */ virtual string key() = 0; /** * Returns the current value. */ virtual string value() = 0; /** * Returns whether the current location is valid - for example, if we have * reached the end of the database, return false. */ virtual bool Valid() = 0; C10_DISABLE_COPY_AND_ASSIGN(Cursor); }; /** * An abstract class for the current database transaction while writing. */ class CAFFE2_API Transaction { public: Transaction() { } virtual ~Transaction() { } /** * Puts the key value pair to the database. */ virtual void Put(const string& key, const string& value) = 0; /** * Commits the current writes. */ virtual void Commit() = 0; C10_DISABLE_COPY_AND_ASSIGN(Transaction); }; /** * An abstract class for accessing a database of key-value pairs. */ class CAFFE2_API DB { public: DB(const string& /*source*/, Mode mode) : mode_(mode) {} virtual ~DB() { } /** * Closes the database. */ virtual void Close() = 0; /** * Returns a cursor to read the database. The caller takes the ownership of * the pointer. */ virtual std::unique_ptr NewCursor() = 0; /** * Returns a transaction to write data to the database. The caller takes the * ownership of the pointer. */ virtual std::unique_ptr NewTransaction() = 0; protected: Mode mode_; C10_DISABLE_COPY_AND_ASSIGN(DB); }; // Database classes are registered by their names so we can do optional // dependencies. C10_DECLARE_REGISTRY(Caffe2DBRegistry, DB, const string&, Mode); #define REGISTER_CAFFE2_DB(name, ...) \ C10_REGISTER_CLASS(Caffe2DBRegistry, name, __VA_ARGS__) /** * Returns a database object of the given database type, source and mode. The * caller takes the ownership of the pointer. If the database type is not * supported, a nullptr is returned. The caller is responsible for examining the * validity of the pointer. */ inline unique_ptr CreateDB( const string& db_type, const string& source, Mode mode) { auto result = Caffe2DBRegistry()->Create(db_type, source, mode); VLOG(1) << ((!result) ? "not found db " : "found db ") << db_type; return result; } /** * Returns whether or not a database exists given the database type and path. */ inline bool DBExists(const string& db_type, const string& full_db_name) { // Warning! We assume that creating a DB throws an exception if the DB // does not exist. If the DB constructor does not follow this design // pattern, // the returned output (the existence tensor) can be wrong. try { std::unique_ptr db( caffe2::db::CreateDB(db_type, full_db_name, caffe2::db::READ)); return true; } catch (...) { return false; } } /** * A reader wrapper for DB that also allows us to serialize it. */ class CAFFE2_API DBReader { public: friend class DBReaderSerializer; DBReader() {} DBReader( const string& db_type, const string& source, const int32_t num_shards = 1, const int32_t shard_id = 0) { Open(db_type, source, num_shards, shard_id); } explicit DBReader(const DBReaderProto& proto) { Open(proto.db_type(), proto.source()); if (proto.has_key()) { CAFFE_ENFORCE(cursor_->SupportsSeek(), "Encountering a proto that needs seeking but the db type " "does not support it."); cursor_->Seek(proto.key()); } num_shards_ = 1; shard_id_ = 0; } explicit DBReader(std::unique_ptr db) : db_type_(""), source_(""), db_(std::move(db)) { CAFFE_ENFORCE(db_.get(), "Passed null db"); cursor_ = db_->NewCursor(); } void Open( const string& db_type, const string& source, const int32_t num_shards = 1, const int32_t shard_id = 0) { // Note(jiayq): resetting is needed when we re-open e.g. leveldb where no // concurrent access is allowed. cursor_.reset(); db_.reset(); db_type_ = db_type; source_ = source; db_ = CreateDB(db_type_, source_, READ); CAFFE_ENFORCE( db_, "Cannot find db implementation of type ", db_type, " (while trying to open ", source_, ")"); InitializeCursor(num_shards, shard_id); } void Open( unique_ptr&& db, const int32_t num_shards = 1, const int32_t shard_id = 0) { cursor_.reset(); db_.reset(); db_ = std::move(db); CAFFE_ENFORCE(db_.get(), "Passed null db"); InitializeCursor(num_shards, shard_id); } public: /** * Read a set of key and value from the db and move to next. Thread safe. * * The string objects key and value must be created by the caller and * explicitly passed in to this function. This saves one additional object * copy. * * If the cursor reaches its end, the reader will go back to the head of * the db. This function can be used to enable multiple input ops to read * the same db. * * Note(jiayq): we loosen the definition of a const function here a little * bit: the state of the cursor is actually changed. However, this allows * us to pass in a DBReader to an Operator without the need of a duplicated * output blob. */ void Read(string* key, string* value) const { CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized."); std::unique_lock mutex_lock(reader_mutex_); *key = cursor_->key(); *value = cursor_->value(); // In sharded mode, each read skips num_shards_ records for (uint32_t s = 0; s < num_shards_; s++) { cursor_->Next(); if (!cursor_->Valid()) { MoveToBeginning(); break; } } } /** * @brief Seeks to the first key. Thread safe. */ void SeekToFirst() const { CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized."); std::unique_lock mutex_lock(reader_mutex_); MoveToBeginning(); } /** * Returns the underlying cursor of the db reader. * * Note that if you directly use the cursor, the read will not be thread * safe, because there is no mechanism to stop multiple threads from * accessing the same cursor. You should consider using Read() explicitly. */ inline Cursor* cursor() const { VLOG(1) << "Usually for a DBReader you should use Read() to be " "thread safe. Consider refactoring your code."; return cursor_.get(); } private: void InitializeCursor(const int32_t num_shards, const int32_t shard_id) { CAFFE_ENFORCE(num_shards >= 1); CAFFE_ENFORCE(shard_id >= 0); CAFFE_ENFORCE(shard_id < num_shards); num_shards_ = num_shards; shard_id_ = shard_id; cursor_ = db_->NewCursor(); SeekToFirst(); } void MoveToBeginning() const { cursor_->SeekToFirst(); for (uint32_t s = 0; s < shard_id_; s++) { cursor_->Next(); CAFFE_ENFORCE( cursor_->Valid(), "Db has fewer rows than shard id: ", s, shard_id_); } } string db_type_; string source_; unique_ptr db_; unique_ptr cursor_; mutable std::mutex reader_mutex_; uint32_t num_shards_{}; uint32_t shard_id_{}; C10_DISABLE_COPY_AND_ASSIGN(DBReader); }; class CAFFE2_API DBReaderSerializer : public BlobSerializerBase { public: /** * Serializes a DBReader. Note that this blob has to contain DBReader, * otherwise this function produces a fatal error. */ void Serialize( const void* pointer, TypeMeta typeMeta, const string& name, BlobSerializerBase::SerializationAcceptor acceptor) override; }; class CAFFE2_API DBReaderDeserializer : public BlobDeserializerBase { public: void Deserialize(const BlobProto& proto, Blob* blob) override; }; } // namespace db } // namespace caffe2 #endif // CAFFE2_CORE_DB_H_