#ifndef CAFFE2_OPERATORS_LOAD_SAVE_OP_H_
#define CAFFE2_OPERATORS_LOAD_SAVE_OP_H_

#include <cstdio>
#include <map>
#include <unordered_map>

#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/context.h"
#include "caffe2/core/db.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

namespace {
struct BlobState {
  int64_t total_size;
  int64_t current_size;
  bool is_tensor;
  std::set<int32_t> seen_chunks_ids;

  explicit BlobState(
      int64_t total_size = 0,
      int64_t current_size = 0,
      bool is_tensor = false)
      : total_size(total_size),
        current_size(current_size),
        is_tensor(is_tensor) {}
};
} // namespace

using db::Cursor;
using db::DB;
using db::Transaction;

template <class Context>
class DBExistsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  explicit DBExistsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            this->template GetSingleArgument<bool>("absolute_path", false)),
        db_name_(this->template GetSingleArgument<string>("db_name", "")),
        db_type_(this->template GetSingleArgument<string>("db_type", "")) {}

  bool RunOnDevice() override {
    string full_db_name =
        absolute_path_ ? db_name_ : (ws_->RootFolder() + "/" + db_name_);
    auto* output = Output(0);
    output->Resize();
    bool* exists = output->template mutable_data<bool>();
    *exists = caffe2::db::DBExists(db_type_, full_db_name);
    return true;
  }

 private:
  Workspace* ws_;
  bool absolute_path_;
  std::string db_name_;
  std::string db_type_;
};
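
// A minimal usage sketch for DBExistsOp (illustrative only; the db path and
// blob name below are hypothetical). The op takes no inputs and writes a
// scalar bool blob telling whether the db exists:
//
//   Workspace ws;
//   OperatorDef def = CreateOperatorDef(
//       "DBExists",
//       "",
//       std::vector<std::string>{},
//       std::vector<std::string>{"db_exists"},
//       std::vector<Argument>{
//           MakeArgument<std::string>("db_name", "/tmp/my_db"),
//           MakeArgument<std::string>("db_type", "leveldb")});
//   CreateOperator(def, &ws)->Run();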
template <class Context>
class LoadOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  explicit LoadOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            this->template GetSingleArgument<bool>("absolute_path", false)),
        add_prefix_(this->template GetSingleArgument<string>("add_prefix", "")),
        strip_prefix_(
            this->template GetSingleArgument<string>("strip_prefix", "")),
        db_name_(this->template GetSingleArgument<string>("db", "")),
        db_names_(this->template GetRepeatedArgument<string>("dbs")),
        db_type_(this->template GetSingleArgument<string>("db_type", "")),
        keep_device_(this->template GetSingleArgument<int>("keep_device", 0)),
        load_all_(this->template GetSingleArgument<int>("load_all", 0)),
        allow_incomplete_(
            this->template GetSingleArgument<bool>("allow_incomplete", false)),
        blob_names_(
            this->template GetRepeatedArgument<string>("source_blob_names")),
        shape_(this->template GetRepeatedArgument<int64_t>("shape")) {
    if (InputSize() == 0) {
      CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
      if (db_names_.empty()) {
        CAFFE_ENFORCE_GT(db_name_.size(), 0, "Must specify a db name.");
        db_names_.push_back(db_name_);
        db_name_ = "";
      } else {
        std::set<string> db_name_set;
        for (const string& db_name : db_names_) {
          CAFFE_ENFORCE_GT(db_name.size(), 0, "Db name should not be empty.");
          CAFFE_ENFORCE(
              db_name_set.insert(db_name).second,
              "Duplicated db name: ",
              db_name);
        }
        db_name_ = "";
      }
    }
    CAFFE_ENFORCE(
        blob_names_.empty() || blob_names_.size() == OutputSize(),
        "Number of output blobs and source_blob_names mismatch.");
    CAFFE_ENFORCE(
        blob_names_.empty() || strip_prefix_.empty(),
        "strip_prefix and source_blob_names are mutually exclusive.");
    CAFFE_ENFORCE(
        blob_names_.empty() || !load_all_,
        "cannot load_all_ while using source_blob_names.");
    if (!load_all_) {
      // blob_names_ will be filled with the "source blob names" found in the
      // file/db. If the argument source_blob_names is not given, blob_names_
      // is inferred from the operator outputs instead.
      if (blob_names_.empty()) {
        for (const string& name : operator_def.output()) {
          blob_names_.push_back(name);
        }
      }
      int idx = 0;
      std::set<string> name_set;
      for (const string& name : blob_names_) {
        CAFFE_ENFORCE(
            name_set.insert(name).second,
            "Duplicated source blob name: ",
            name);
        output_indices_[name] = idx++;
      }
    }
  }

  void SetCurrentDevice(BlobProto* proto);

  bool RunOnDevice() override {
    int total_loaded_blobs = 0;
    std::unordered_map<string, BlobState> blob_states;
    if (InputSize() > 0) {
      for (int i = 0; i < InputSize(); ++i) {
        const db::DBReader& reader = this->template Input<db::DBReader>(i);
        extract(i, reader.cursor(), &blob_states, &total_loaded_blobs);
      }
    } else {
      for (int i = 0; i < db_names_.size(); ++i) {
        string full_db_name = absolute_path_
            ? db_names_[i]
            : (ws_->RootFolder() + "/" + db_names_[i]);
        std::unique_ptr<DB> in_db(
            caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::READ));
        CAFFE_ENFORCE(
            in_db.get(),
            "Cannot find db implementation of type ",
            db_type_,
            " (while trying to open ",
            full_db_name,
            ")");
        std::unique_ptr<Cursor> cursor(in_db->NewCursor());
        extract(i, cursor.get(), &blob_states, &total_loaded_blobs);
      }
    }

    validateBlobStates(blob_states);
    // Loaded all the needed blobs.
    if (!load_all_ && total_loaded_blobs == OutputSize()) {
      VLOG(1) << "Loaded " << total_loaded_blobs << " blobs fully from db(s)";
      return true;
    }

    if (load_all_) {
      for (const string& name : this->debug_def().output()) {
        CAFFE_ENFORCE(
            blob_states.count(name),
            "Output blob name ",
            name,
            " does not exist in the db(s).");
      }
      return true;
    }

    // Only loaded a subset of the blobs.
    if (allow_incomplete_) {
      VLOG(1) << "Loaded " << total_loaded_blobs << " blobs out of "
              << OutputSize() << " blobs from db(s).";
    } else {
      for (const string& output_name : this->debug_def().output()) {
        if (blob_states.count(output_name) == 0) {
          LOG(ERROR) << "Failed to load blob: " << output_name;
        }
      }
      CAFFE_THROW(
          "Expected to load ",
          OutputSize(),
          " blobs, got ",
          total_loaded_blobs,
          " only.\n");
    }
    return true;
  }

 private:
  void extract(
      int db_id,
      Cursor* cursor,
      std::unordered_map<string, BlobState>* blob_states,
      int* total_loaded_blobs) {
    if (load_all_) {
      extractAll(db_id, cursor, blob_states, total_loaded_blobs);
    } else {
      extractFrom(
          db_id,
          cursor,
          OperatorBase::Outputs(),
          blob_states,
          total_loaded_blobs);
    }
  }

  void extractAll(
      int db_id,
      Cursor* cursor,
      std::unordered_map<string, BlobState>* blob_states,
      int* total_loaded_blobs) {
    CAFFE_ENFORCE(cursor, "cursor is not valid");
    int loaded_blobs = 0;
    for (; cursor->Valid(); cursor->Next()) {
      const auto key = buildBlobNameFromDbKey(cursor->key());
      if (key_to_dbid_.count(key) && key_to_dbid_[key] != db_id) {
        CAFFE_THROW("Duplicate Key ", key, " is found!\n");
      } else {
        key_to_dbid_[key] = db_id;
      }

      BlobProto proto;
      CAFFE_ENFORCE(
          proto.ParseFromString(cursor->value()), "Couldn't parse Proto");
      if (!keep_device_) {
        // If we are not keeping the device as the one specified in the
        // proto, we will set the current device.
        SetCurrentDevice(&proto);
      }
      Blob* blob = ws_->CreateBlob(key);
      ProcessBlob(blob, proto, blob_states, key, &loaded_blobs);
    }
    *total_loaded_blobs += loaded_blobs;
  }

  void extractFrom(
      int db_id,
      Cursor* cursor,
      const vector<Blob*>& outputs,
      std::unordered_map<string, BlobState>* blob_states,
      int* total_loaded_blobs) {
    CAFFE_ENFORCE(cursor);
    int loaded_blobs = 0;
    for (; cursor->Valid(); cursor->Next()) {
      const auto key = buildBlobNameFromDbKey(cursor->key());
      if (!output_indices_.count(key)) {
        VLOG(1) << "Key " << key << " not used. Skipping.";
      } else {
        if (key_to_dbid_.count(key) && key_to_dbid_[key] != db_id) {
          CAFFE_THROW("Duplicate Key ", key, " is found!\n");
        } else {
          key_to_dbid_[key] = db_id;
        }

        VLOG(2) << "Deserializing blob " << key;
        BlobProto proto;
        CAFFE_ENFORCE(proto.ParseFromString(cursor->value()));
        if (!keep_device_) {
          // If we are not keeping the device as the one specified in the
          // proto, we will set the current device.
          SetCurrentDevice(&proto);
        }
        auto blobIndex = output_indices_[key];
        Blob* blob = outputs.at(blobIndex);
        ProcessBlob(blob, proto, blob_states, key, &loaded_blobs);

        if (*total_loaded_blobs + loaded_blobs == OutputSize()) {
          break;
        }
      }
    }
    *total_loaded_blobs += loaded_blobs;
  }

  string buildBlobNameFromDbKey(const string& dbKey) {
    string key = dbKey.substr(0, dbKey.find(kChunkIdSeparator));
    if (!strip_prefix_.empty()) {
      auto match_pos = key.find(strip_prefix_);
      if (match_pos != string::npos) {
        key = key.substr(match_pos + strip_prefix_.size());
      }
    }
    key = add_prefix_ + key;
    return key;
  }

 private:
  // We are tracking sizes of already read tensor parts while reading data
  // chunks. This way we can make sure that all chunks were loaded in the end.
  void ProcessBlob(
      Blob* blob,
      const BlobProto& proto,
      std::unordered_map<string, BlobState>* blob_states_ptr,
      const string& key,
      int* loaded_blobs) {
    auto& blob_states = *blob_states_ptr;
    if (blob_states.count(key) == 0) {
      // We reset the blob so that any existing content is destroyed. This
      // is to guarantee correct device placement: if we are deserializing
      // into a TensorCUDA, without explicit Reset we might be loading data
      // into an existing TensorCUDA that has pre-allocated memory on a
      // different GPU.
      blob->Reset();
    }
    DeserializeBlob(proto, blob);
    if (proto.has_content_num_chunks()) {
      if (!blob_states.count(key)) {
        blob_states[key] = BlobState(proto.content_num_chunks());
      }
      CAFFE_ENFORCE(
          blob_states[key]
              .seen_chunks_ids.insert(proto.content_chunk_id())
              .second,
          "Chunk with the same id has occurred twice for: ",
          key);
      CAFFE_ENFORCE(
          proto.content_chunk_id() >= 0 &&
              proto.content_chunk_id() < blob_states[key].total_size,
          "Chunk id has to be not less than 0 and "
          "less than content_num_chunks for key: ",
          key);
      blob_states[key].current_size++;
      CAFFE_ENFORCE(
          !blob_states[key].is_tensor,
          "Proto with content_chunks can not store tensor: ",
          key);
      CAFFE_ENFORCE(
          blob_states[key].current_size <= blob_states[key].total_size,
          "Found an extra part for an already filled blob: ",
          key);
      if (blob_states[key].current_size == blob_states[key].total_size) {
        (*loaded_blobs)++;
      }
      return;
    }
    if (!proto.has_tensor()) {
      // If a blob is divided into chunks, the content_num_chunks field has to
      // be set; otherwise only tensors can be seen multiple times as chunks.
      CAFFE_ENFORCE(blob_states.count(key) == 0, "Blob duplicated: ", key);
      blob_states[key] = BlobState();
      (*loaded_blobs)++;
      return;
    }
    CAFFE_ENFORCE(proto.has_tensor());
    if (blob_states.count(key)) {
      CAFFE_ENFORCE(blob_states[key].is_tensor, "Must be tensor ", key);
      CAFFE_ENFORCE(
          blob_states[key].current_size < blob_states[key].total_size,
          "Found an extra part for an already filled tensor: ",
          key);
      CAFFE_ENFORCE(
          proto.tensor().has_segment(),
          "Partial tensor must have a segment: ",
          key);
      blob_states[key].current_size +=
          proto.tensor().segment().end() - proto.tensor().segment().begin();
      CAFFE_ENFORCE(
          blob_states[key].current_size <= blob_states[key].total_size,
          "Tensor parts are bigger than target size for tensor: ",
          key);
    } else {
      const auto& dims = proto.tensor().dims();
      int64_t total_size = 1;
      for (const auto& dim : dims) {
        total_size *= dim;
      }
      auto current_size = total_size;
      if (proto.tensor().has_segment()) {
        current_size =
            proto.tensor().segment().end() - proto.tensor().segment().begin();
      }
      blob_states[key] =
          BlobState(total_size, current_size, true /* is_tensor */);
    }

    if (blob_states[key].current_size == blob_states[key].total_size) {
      (*loaded_blobs)++;
    }
  }

  void validateBlobStates(
      const std::unordered_map<string, BlobState>& blob_states) {
    for (const auto& iter : blob_states) {
      const BlobState& blob_state = iter.second;
      CAFFE_ENFORCE(
          blob_state.current_size == blob_state.total_size,
          "Data size mismatch for blob ",
          iter.first,
          ". Expected: ",
          blob_state.total_size,
          " Read: ",
          blob_state.current_size);
    }
  }

  Workspace* ws_;
  bool absolute_path_;
  string add_prefix_;
  string strip_prefix_;
  string db_name_;
  std::vector<std::string> db_names_;
  string db_type_;
  bool keep_device_;
  bool load_all_;
  bool allow_incomplete_;
  std::map<string, int> output_indices_;
  std::map<string, int> key_to_dbid_;
  std::vector<std::string> blob_names_;
  std::vector<int64_t> shape_;
};
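
// A minimal usage sketch for LoadOp (illustrative only; paths and blob names
// are hypothetical). Without load_all, the op loads exactly the blobs named
// by its outputs; with load_all = 1 it creates a workspace blob for every
// key found in the db:
//
//   OperatorDef def = CreateOperatorDef(
//       "Load",
//       "",
//       std::vector<std::string>{},
//       std::vector<std::string>{"fc_w", "fc_b"},
//       std::vector<Argument>{
//           MakeArgument<std::string>("db", "/tmp/my_checkpoint"),
//           MakeArgument<std::string>("db_type", "leveldb")});
//   CreateOperator(def, &ws)->Run();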
template <class Context>
class SaveOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  explicit SaveOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            this->template GetSingleArgument<bool>("absolute_path", false)),
        strip_prefix_(
            this->template GetSingleArgument<string>("strip_prefix", "")),
        db_name_(this->template GetSingleArgument<string>("db", "")),
        db_type_(this->template GetSingleArgument<string>("db_type", "")),
        blob_names_(
            this->template GetRepeatedArgument<string>("blob_name_overrides")),
        chunk_size_(this->template GetSingleArgument<int>(
            "chunk_size",
            kDefaultChunkSize)) {
    CAFFE_ENFORCE_GT(db_name_.size(), 0, "Must specify a db name.");
    CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
    CAFFE_ENFORCE(
        blob_names_.empty() ||
            blob_names_.size() == OperatorBase::Inputs().size(),
        "Number of blobs and blob_name_overrides mismatch.");
    CAFFE_ENFORCE(
        blob_names_.empty() || strip_prefix_.empty(),
        "strip_prefix and blob_name_overrides are mutually exclusive.");

    if (blob_names_.empty()) {
      std::set<std::string> input_names;
      blob_names_.resize(OperatorBase::Inputs().size());
      for (int i = 0; i < blob_names_.size(); ++i) {
        std::string name;
        if (strip_prefix_.empty()) {
          name = operator_def.input(i);
        } else {
          auto match_pos = operator_def.input(i).find(strip_prefix_);
          if (match_pos == string::npos) {
            name = operator_def.input(i);
          } else {
            name = operator_def.input(i).substr(
                match_pos + strip_prefix_.size(), string::npos);
          }
        }
        CAFFE_ENFORCE(
            input_names.insert(name).second, "Duplicated input: ", name);
        blob_names_[i] = name;
      }
    }
  }

  bool RunOnDevice() override {
    string full_db_name =
        absolute_path_ ? db_name_ : (ws_->RootFolder() + "/" + db_name_);
    std::unique_ptr<DB> out_db(
        caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::NEW));
    CAFFE_ENFORCE(
        out_db.get(),
        "Cannot find db implementation of type ",
        db_type_,
        " (while trying to open ",
        full_db_name,
        ")");

    BlobSerializerBase::SerializationAcceptor acceptor =
        [&](const std::string& blobName, const std::string& data) {
          // The transaction should take care of locking.
          VLOG(2) << "Sending " << blobName << " blob's data of size "
                  << data.size() << " to db";
          auto transaction = out_db->NewTransaction();
          transaction->Put(blobName, data);
          transaction->Commit();
        };

    const vector<const Blob*>& inputs = OperatorBase::Inputs();
    VLOG(0) << "Saving " << inputs.size() << " inputs to " << db_type_ << ": "
            << full_db_name;
    for (int i = 0; i < inputs.size(); ++i) {
      SerializeBlob(*inputs[i], blob_names_[i], acceptor, chunk_size_);
    }
    out_db->Close();
    return true;
  }

 private:
  Workspace* ws_;
  bool absolute_path_;
  string strip_prefix_;
  string db_name_;
  string db_type_;
  std::vector<std::string> blob_names_;
  int chunk_size_;
};
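
// A matching sketch for SaveOp (illustrative only; paths and blob names are
// hypothetical). Each input blob is serialized, split into chunks of at most
// chunk_size elements, and written to the db under its (possibly stripped or
// overridden) name:
//
//   OperatorDef def = CreateOperatorDef(
//       "Save",
//       "",
//       std::vector<std::string>{"fc_w", "fc_b"},
//       std::vector<std::string>{},
//       std::vector<Argument>{
//           MakeArgument<std::string>("db", "/tmp/my_checkpoint"),
//           MakeArgument<std::string>("db_type", "leveldb")});
//   CreateOperator(def, &ws)->Run();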
template <typename... Ts>
string FormatString(const string& pattern, Ts... values) {
  // Note(Yangqing): We believe that 1024 is enough, but who are we to assert
  // that? As a result, if things go wrong, we will just throw in the towel
  // and quit loudly.
  // Yeah, I know that there is snprintf, but it is not present in *some*
  // platforms unfortunately.
  char buffer[1024];
  int written = sprintf(buffer, pattern.c_str(), values...);
  if (written < 0 || written + 1 > 1024) {
    LOG(FATAL) << "FormatString fails: total bytes written " << written;
  }
  return string(buffer);
  /*
   * The following is the snprintf version that is safe; enable it one day?
  unsigned int required =
      std::snprintf(nullptr, 0, pattern.c_str(), values...) + 1;
  char bytes[required];
  std::snprintf(bytes, required, pattern.c_str(), values...);
  return string(bytes);
  */
}
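
// Example (hypothetical values): FormatString("checkpoint_at_%d.pb", 100)
// yields "checkpoint_at_100.pb". The usual printf rules apply: the argument
// types must match the conversion specifiers in the pattern.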
// CheckpointOp is a wrapper over a SaveFloatTensorOp that basically allows
// flexible naming over iterations.
// The file pattern in db_name should be a format string that can be passed
// into sprintf with an int argument specifying the current iteration. An
// example: "/path/to/my/checkpoint/checkpoint_at_%d.pb".
template <class Context>
class CheckpointOp final : public Operator<Context> {
 public:
  explicit CheckpointOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        db_pattern_(this->template GetSingleArgument<string>("db", "")),
        every_(this->template GetSingleArgument<int>("every", 1)),
        ws_(ws),
        save_op_def_(operator_def) {
    CAFFE_ENFORCE_GT(
        db_pattern_.size(), 0, "Must specify a checkpoint file pattern.");
    CAFFE_ENFORCE_GT(every_, 0, "Checkpoint interval should be positive.");
    if (every_ == 1) {
      // Just issue a warning, but it's totally legal so we don't do anything.
      LOG(WARNING) << "It seems that we are checkpointing every iteration. "
                   << "Is that intended?";
    }
    save_op_def_.set_type("Save");
  }

  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    int64_t iter =
        this->template Input<Tensor>(0, CPU).template data<int64_t>()[0];
    if (iter % every_ == 0) {
      GetMutableArgument("db", true, &save_op_def_)
          ->set_s(FormatString(db_pattern_, iter));
      SaveOp<Context> sub_op(save_op_def_, ws_);
      return sub_op.Run();
    } else {
      return true;
    }
  }

 private:
  string db_pattern_;
  int every_;
  Workspace* ws_;
  OperatorDef save_op_def_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LOAD_SAVE_OP_H_