#ifndef CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_ #define CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_ #include #include #include #include #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/proto/caffe2_pb.h" C10_DECLARE_bool(caffe2_workspace_stack_debug); namespace caffe2 { namespace detail { /* * Keeps track of forward and backward gradient workspaces in stack, * reuses previously created workspaces, non-thread safe */ class CAFFE2_API WorkspaceStack { public: explicit WorkspaceStack() : parent_ws_(nullptr), top_(-1) {} std::shared_ptr pushForwardWorkspace(Workspace* parent_ws) { return pushForwardWorkspace( parent_ws, std::unordered_map()); } std::shared_ptr pushForwardWorkspace( Workspace* parent_ws, const std::unordered_map& blob_bindings) { checkStack(); if (FLAGS_caffe2_workspace_stack_debug) { if (parent_ws_) { CAFFE_ENFORCE_EQ(parent_ws_, parent_ws, "Parent workspace mismatch"); } else { parent_ws_ = parent_ws; } if (!blob_bindings_.empty()) { checkBindingsMatch(blob_bindings_, blob_bindings); } else { blob_bindings_ = blob_bindings; } } if (top_ == workspaces_.size() - 1) { workspaces_.push_back( std::make_shared(parent_ws, blob_bindings)); } else { // when reusing workspace, make sure copies of external blobs are // removed and blob bindings are set auto& workspace = workspaces_[top_ + 1]; const auto& local_blobs = workspace->LocalBlobs(); std::unordered_set local_blobs_set; local_blobs_set.insert(local_blobs.begin(), local_blobs.end()); bool found_local_copy = false; for (const auto& blob_pair : blob_bindings) { if (local_blobs_set.count(blob_pair.first)) { workspace->RemoveBlob(blob_pair.first); found_local_copy = true; } } if (found_local_copy) { workspace->AddBlobMapping(parent_ws, blob_bindings); } } return workspaces_[++top_]; } std::shared_ptr popGradientWorkspace( Workspace* parent_ws, const std::unordered_map& grad_blob_bindings) { checkStack(); if (FLAGS_caffe2_workspace_stack_debug) { if (parent_ws_) { CAFFE_ENFORCE_EQ(parent_ws_, parent_ws, "Parent workspace mismatch"); } else { parent_ws_ = parent_ws; } if (!grad_blob_bindings_.empty()) { checkBindingsMatch(grad_blob_bindings_, grad_blob_bindings); } else { grad_blob_bindings_ = grad_blob_bindings; } } if (top_ < 0) { return nullptr; } auto& grad_workspace = workspaces_[top_]; grad_workspace->AddBlobMapping(parent_ws, grad_blob_bindings, true); --top_; return grad_workspace; } std::shared_ptr reuseLastForwardWorkspace(Workspace* parent_ws) { return reuseLastForwardWorkspace( parent_ws, std::unordered_map()); } std::shared_ptr reuseLastForwardWorkspace( Workspace* parent_ws, const std::unordered_map& blob_bindings) { checkStack(); if (top_ < 0) { return nullptr; } workspaces_[top_]->AddBlobMapping(parent_ws, blob_bindings); return workspaces_[top_]; } void clear() { checkStack(); top_ = -1; } bool empty() const { return top_ < 0; } private: void checkStack() const { CAFFE_ENFORCE_GT( (int)workspaces_.size(), top_, "Corrupted workspaces stack"); } void checkBindingsMatch( const std::unordered_map& bindings, const std::unordered_map& test_bindings) const { CAFFE_ENFORCE_EQ( bindings.size(), test_bindings.size(), "Blob bindings mismatch"); for (const auto& blob_binding : bindings) { CAFFE_ENFORCE( test_bindings.count(blob_binding.first), "Blob bindings mismatch"); CAFFE_ENFORCE_EQ( test_bindings.at(blob_binding.first), blob_binding.second, "Blob bindings mismatch"); } } std::unordered_map blob_bindings_; std::unordered_map grad_blob_bindings_; Workspace* parent_ws_; int top_; std::vector> workspaces_; }; } template class CreateScopeOp final : public Operator { public: template explicit CreateScopeOp(Args&&... args) : Operator(std::forward(args)...) {} USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override; }; template class HasScopeOp final : public Operator { public: template explicit HasScopeOp(Args&&... args) : Operator(std::forward(args)...) {} USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() override; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_