#ifndef CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_
|
#define CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_
|
|
#include <string>
|
#include <unordered_map>
|
#include <unordered_set>
|
#include <vector>
|
|
#include "caffe2/core/context.h"
|
#include "caffe2/core/logging.h"
|
#include "caffe2/core/operator.h"
|
#include "caffe2/proto/caffe2_pb.h"
|
|
C10_DECLARE_bool(caffe2_workspace_stack_debug);
|
|
namespace caffe2 {
|
namespace detail {
|
|
/*
|
* Keeps track of forward and backward gradient workspaces in stack,
|
* reuses previously created workspaces, non-thread safe
|
*/
|
class CAFFE2_API WorkspaceStack {
|
public:
|
explicit WorkspaceStack() : parent_ws_(nullptr), top_(-1) {}
|
|
std::shared_ptr<Workspace> pushForwardWorkspace(Workspace* parent_ws) {
|
return pushForwardWorkspace(
|
parent_ws, std::unordered_map<std::string, std::string>());
|
}
|
|
std::shared_ptr<Workspace> pushForwardWorkspace(
|
Workspace* parent_ws,
|
const std::unordered_map<std::string, std::string>& blob_bindings) {
|
checkStack();
|
if (FLAGS_caffe2_workspace_stack_debug) {
|
if (parent_ws_) {
|
CAFFE_ENFORCE_EQ(parent_ws_, parent_ws, "Parent workspace mismatch");
|
} else {
|
parent_ws_ = parent_ws;
|
}
|
if (!blob_bindings_.empty()) {
|
checkBindingsMatch(blob_bindings_, blob_bindings);
|
} else {
|
blob_bindings_ = blob_bindings;
|
}
|
}
|
|
if (top_ == workspaces_.size() - 1) {
|
workspaces_.push_back(
|
std::make_shared<Workspace>(parent_ws, blob_bindings));
|
} else {
|
// when reusing workspace, make sure copies of external blobs are
|
// removed and blob bindings are set
|
auto& workspace = workspaces_[top_ + 1];
|
const auto& local_blobs = workspace->LocalBlobs();
|
std::unordered_set<std::string> local_blobs_set;
|
local_blobs_set.insert(local_blobs.begin(), local_blobs.end());
|
bool found_local_copy = false;
|
for (const auto& blob_pair : blob_bindings) {
|
if (local_blobs_set.count(blob_pair.first)) {
|
workspace->RemoveBlob(blob_pair.first);
|
found_local_copy = true;
|
}
|
}
|
if (found_local_copy) {
|
workspace->AddBlobMapping(parent_ws, blob_bindings);
|
}
|
}
|
|
return workspaces_[++top_];
|
}
|
|
std::shared_ptr<Workspace> popGradientWorkspace(
|
Workspace* parent_ws,
|
const std::unordered_map<std::string, std::string>& grad_blob_bindings) {
|
checkStack();
|
if (FLAGS_caffe2_workspace_stack_debug) {
|
if (parent_ws_) {
|
CAFFE_ENFORCE_EQ(parent_ws_, parent_ws, "Parent workspace mismatch");
|
} else {
|
parent_ws_ = parent_ws;
|
}
|
if (!grad_blob_bindings_.empty()) {
|
checkBindingsMatch(grad_blob_bindings_, grad_blob_bindings);
|
} else {
|
grad_blob_bindings_ = grad_blob_bindings;
|
}
|
}
|
|
if (top_ < 0) {
|
return nullptr;
|
}
|
auto& grad_workspace = workspaces_[top_];
|
grad_workspace->AddBlobMapping(parent_ws, grad_blob_bindings, true);
|
--top_;
|
return grad_workspace;
|
}
|
|
std::shared_ptr<Workspace> reuseLastForwardWorkspace(Workspace* parent_ws) {
|
return reuseLastForwardWorkspace(
|
parent_ws, std::unordered_map<std::string, std::string>());
|
}
|
|
std::shared_ptr<Workspace> reuseLastForwardWorkspace(
|
Workspace* parent_ws,
|
const std::unordered_map<std::string, std::string>& blob_bindings) {
|
checkStack();
|
if (top_ < 0) {
|
return nullptr;
|
}
|
workspaces_[top_]->AddBlobMapping(parent_ws, blob_bindings);
|
return workspaces_[top_];
|
}
|
|
void clear() {
|
checkStack();
|
top_ = -1;
|
}
|
|
bool empty() const {
|
return top_ < 0;
|
}
|
|
private:
|
void checkStack() const {
|
CAFFE_ENFORCE_GT(
|
(int)workspaces_.size(), top_, "Corrupted workspaces stack");
|
}
|
|
void checkBindingsMatch(
|
const std::unordered_map<std::string, std::string>& bindings,
|
const std::unordered_map<std::string, std::string>& test_bindings) const {
|
CAFFE_ENFORCE_EQ(
|
bindings.size(), test_bindings.size(), "Blob bindings mismatch");
|
for (const auto& blob_binding : bindings) {
|
CAFFE_ENFORCE(
|
test_bindings.count(blob_binding.first), "Blob bindings mismatch");
|
CAFFE_ENFORCE_EQ(
|
test_bindings.at(blob_binding.first),
|
blob_binding.second,
|
"Blob bindings mismatch");
|
}
|
}
|
|
std::unordered_map<std::string, std::string> blob_bindings_;
|
std::unordered_map<std::string, std::string> grad_blob_bindings_;
|
Workspace* parent_ws_;
|
int top_;
|
std::vector<std::shared_ptr<Workspace>> workspaces_;
|
};
|
}
|
|
template <class Context>
|
class CreateScopeOp final : public Operator<Context> {
|
public:
|
template <class... Args>
|
explicit CreateScopeOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...) {}
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
bool RunOnDevice() override;
|
};
|
|
template <class Context>
|
class HasScopeOp final : public Operator<Context> {
|
public:
|
template <class... Args>
|
explicit HasScopeOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...) {}
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
bool RunOnDevice() override;
|
};
|
|
} // namespace caffe2
|
|
#endif // CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_
|