#ifndef CAFFE2_CORE_WORKSPACE_H_ #define CAFFE2_CORE_WORKSPACE_H_ #include "caffe2/core/common.h" #include "caffe2/core/observer.h" #include #include #include #include #include #include #include "c10/util/Registry.h" #include "caffe2/core/blob.h" #include "caffe2/core/net.h" #include "caffe2/proto/caffe2_pb.h" #include "caffe2/utils/signal_handler.h" #include "caffe2/utils/threadpool/ThreadPool.h" C10_DECLARE_bool(caffe2_print_blob_sizes_at_exit); namespace caffe2 { class NetBase; struct CAFFE2_API StopOnSignal { StopOnSignal() : handler_(std::make_shared( SignalHandler::Action::STOP, SignalHandler::Action::STOP)) {} StopOnSignal(const StopOnSignal& other) : handler_(other.handler_) {} bool operator()(int /*iter*/) { return handler_->CheckForSignals() != SignalHandler::Action::STOP; } std::shared_ptr handler_; }; /** * Workspace is a class that holds all the related objects created during * runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of * all these objects and deals with the scaffolding logistics. */ class CAFFE2_API Workspace { public: typedef std::function ShouldContinue; typedef CaffeMap > BlobMap; typedef CaffeMap > NetMap; /** * Initializes an empty workspace. */ Workspace() : Workspace(".", nullptr) {} /** * Initializes an empty workspace with the given root folder. * * For any operators that are going to interface with the file system, such * as load operators, they will write things under this root folder given * by the workspace. */ explicit Workspace(const string& root_folder) : Workspace(root_folder, nullptr) {} /** * Initializes a workspace with a shared workspace. * * When we access a Blob, we will first try to access the blob that exists * in the local workspace, and if not, access the blob that exists in the * shared workspace. The caller keeps the ownership of the shared workspace * and is responsible for making sure that its lifetime is longer than the * created workspace. */ explicit Workspace(const Workspace* shared) : Workspace(".", shared) {} /** * Initializes workspace with parent workspace, blob name remapping * (new name -> parent blob name), no other blobs are inherited from * parent workspace */ Workspace( const Workspace* shared, const std::unordered_map& forwarded_blobs) : Workspace(".", nullptr) { CAFFE_ENFORCE(shared, "Parent workspace must be specified"); for (const auto& forwarded : forwarded_blobs) { CAFFE_ENFORCE( shared->HasBlob(forwarded.second), "Invalid parent workspace blob: ", forwarded.second); forwarded_blobs_[forwarded.first] = std::make_pair(shared, forwarded.second); } } /** * Initializes a workspace with a root folder and a shared workspace. */ Workspace(const string& root_folder, const Workspace* shared) : root_folder_(root_folder), shared_(shared), bookkeeper_(bookkeeper()) { std::lock_guard guard(bookkeeper_->wsmutex); bookkeeper_->workspaces.insert(this); } ~Workspace() { if (FLAGS_caffe2_print_blob_sizes_at_exit) { PrintBlobSizes(); } // This is why we have a bookkeeper_ shared_ptr instead of a naked static! A // naked static makes us vulnerable to out-of-order static destructor bugs. std::lock_guard guard(bookkeeper_->wsmutex); bookkeeper_->workspaces.erase(this); } /** * Adds blob mappings from workspace to the blobs from parent workspace. * Creates blobs under possibly new names that redirect read/write operations * to the blobs in the parent workspace. * Arguments: * parent - pointer to parent workspace * forwarded_blobs - map from new blob name to blob name in parent's * workspace skip_defined_blob - if set skips blobs with names that already * exist in the workspace, otherwise throws exception */ void AddBlobMapping( const Workspace* parent, const std::unordered_map& forwarded_blobs, bool skip_defined_blobs = false); /** * Converts previously mapped tensor blobs to local blobs, copies values from * parent workspace blobs into new local blobs. Ignores undefined blobs. */ template void CopyForwardedTensors(const std::unordered_set& blobs) { for (const auto& blob : blobs) { if (!forwarded_blobs_.count(blob)) { continue; } const auto& ws_blob = forwarded_blobs_[blob]; const auto* parent_ws = ws_blob.first; auto* from_blob = parent_ws->GetBlob(ws_blob.second); CAFFE_ENFORCE(from_blob); CAFFE_ENFORCE( from_blob->template IsType(), "Expected blob with tensor value", ws_blob.second); forwarded_blobs_.erase(blob); auto* to_blob = CreateBlob(blob); CAFFE_ENFORCE(to_blob); const auto& from_tensor = from_blob->template Get(); auto* to_tensor = BlobGetMutableTensor(to_blob, Context::GetDeviceType()); to_tensor->CopyFrom(from_tensor); } } /** * Return list of blobs owned by this Workspace, not including blobs * shared from parent workspace. */ vector LocalBlobs() const; /** * Return a list of blob names. This may be a bit slow since it will involve * creation of multiple temp variables. For best performance, simply use * HasBlob() and GetBlob(). */ vector Blobs() const; /** * Return the root folder of the workspace. */ const string& RootFolder() { return root_folder_; } /** * Checks if a blob with the given name is present in the current workspace. */ inline bool HasBlob(const string& name) const { // First, check the local workspace, // Then, check the forwarding map, then the parent workspace if (blob_map_.count(name)) { return true; } else if (forwarded_blobs_.count(name)) { const auto parent_ws = forwarded_blobs_.at(name).first; const auto& parent_name = forwarded_blobs_.at(name).second; return parent_ws->HasBlob(parent_name); } else if (shared_) { return shared_->HasBlob(name); } return false; } void PrintBlobSizes(); /** * Creates a blob of the given name. The pointer to the blob is returned, but * the workspace keeps ownership of the pointer. If a blob of the given name * already exists, the creation is skipped and the existing blob is returned. */ Blob* CreateBlob(const string& name); /** * Similar to CreateBlob(), but it creates a blob in the local workspace even * if another blob with the same name already exists in the parent workspace * -- in such case the new blob hides the blob in parent workspace. If a blob * of the given name already exists in the local workspace, the creation is * skipped and the existing blob is returned. */ Blob* CreateLocalBlob(const string& name); /** * Remove the blob of the given name. Return true if removed and false if * not exist. * Will NOT remove from the shared workspace. */ bool RemoveBlob(const string& name); /** * Gets the blob with the given name as a const pointer. If the blob does not * exist, a nullptr is returned. */ const Blob* GetBlob(const string& name) const; /** * Gets the blob with the given name as a mutable pointer. If the blob does * not exist, a nullptr is returned. */ Blob* GetBlob(const string& name); /** * Renames a local workspace blob. If blob is not found in the local blob list * or if the target name is already present in local or any parent blob list * the function will throw. */ Blob* RenameBlob(const string& old_name, const string& new_name); /** * Creates a network with the given NetDef, and returns the pointer to the * network. If there is anything wrong during the creation of the network, a * nullptr is returned. The Workspace keeps ownership of the pointer. * * If there is already a net created in the workspace with the given name, * CreateNet will overwrite it if overwrite=true is specified. Otherwise, an * exception is thrown. */ NetBase* CreateNet(const NetDef& net_def, bool overwrite = false); NetBase* CreateNet( const std::shared_ptr& net_def, bool overwrite = false); /** * Gets the pointer to a created net. The workspace keeps ownership of the * network. */ NetBase* GetNet(const string& net_name); /** * Deletes the instantiated network with the given name. */ void DeleteNet(const string& net_name); /** * Finds and runs the instantiated network with the given name. If the network * does not exist or there are errors running the network, the function * returns false. */ bool RunNet(const string& net_name); /** * Returns a list of names of the currently instantiated networks. */ vector Nets() const { vector names; for (auto& entry : net_map_) { names.push_back(entry.first); } return names; } /** * Runs a plan that has multiple nets and execution steps. */ bool RunPlan(const PlanDef& plan_def, ShouldContinue should_continue = StopOnSignal{}); /* * Returns a CPU threadpool instace for parallel execution of * work. The threadpool is created lazily; if no operators use it, * then no threadpool will be created. */ ThreadPool* GetThreadPool(); // RunOperatorOnce and RunNetOnce runs an operator or net once. The difference // between RunNet and RunNetOnce lies in the fact that RunNet allows you to // have a persistent net object, while RunNetOnce creates a net and discards // it on the fly - this may make things like database read and random number // generators repeat the same thing over multiple calls. bool RunOperatorOnce(const OperatorDef& op_def); bool RunNetOnce(const NetDef& net_def); /** * Applies a function f on each workspace that currently exists. * * This function is thread safe and there is no race condition between * workspaces being passed to f in this thread and destroyed in another. */ template static void ForEach(F f) { auto bk = bookkeeper(); std::lock_guard guard(bk->wsmutex); for (Workspace* ws : bk->workspaces) { f(ws); } } public: std::atomic last_failed_op_net_position{}; private: struct Bookkeeper { std::mutex wsmutex; std::unordered_set workspaces; }; static std::shared_ptr bookkeeper(); BlobMap blob_map_; const string root_folder_; const Workspace* shared_; std::unordered_map> forwarded_blobs_; std::unique_ptr thread_pool_; std::mutex thread_pool_creation_mutex_; std::shared_ptr bookkeeper_; NetMap net_map_; C10_DISABLE_COPY_AND_ASSIGN(Workspace); }; } // namespace caffe2 #endif // CAFFE2_CORE_WORKSPACE_H_