#ifndef CAFFE2_OPERATORS_ONNX_WHILE_OP_H_ #define CAFFE2_OPERATORS_ONNX_WHILE_OP_H_ #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/operators/create_scope_op.h" namespace caffe2 { template class ONNXWhileOp final : public Operator { public: explicit ONNXWhileOp(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), parent_ws_(ws), has_trip_count_( this->template GetSingleArgument("has_trip_count", 0)), has_cond_(this->template GetSingleArgument("has_cond", 0)), save_scopes_( this->template GetSingleArgument("save_scopes", 0)), disable_scopes_( this->template GetSingleArgument("disable_scopes", 0)), num_loop_carried_deps_(this->template GetSingleArgument( "num_loop_carried_deps", -1)) { CAFFE_ENFORCE( this->template HasSingleArgumentOfType("body"), "body net must be specified in ONNXWhile operator"); if (disable_scopes_) { CAFFE_ENFORCE(!save_scopes_, "Cannot save scopes when disable_scopes=True"); } body_net_def_ = this->template GetSingleArgument("body", NetDef()); static int64_t counter = -1; if (!body_net_def_.has_name()) { if (counter == -1) { ++counter; body_net_def_.set_name("loop_net"); } else { ++counter; body_net_def_.set_name("loop_net." + c10::to_string(counter)); } } } USE_OPERATOR_CONTEXT_FUNCTIONS; bool RunOnDevice() { return DispatchHelper>::call(this, Input(1)); } // Operator // Inputs: max trip count, condition, initial loop-carried dependencies // Outputs: Final loop-carried dependencies, scan_outputs // Body // Inputs: iteration number, condition, loop-carried dependencies // Outputs: condition, loop-carried dependencies, scan_outputs template bool DoRunWithType() { // Clear workspaces from the previous invocations of the loop // and setup a local scope for the first iteration ws_stack_.clear(); auto loop_ws = !disable_scopes_ ? ws_stack_.pushForwardWorkspace(parent_ws_).get() : parent_ws_; constexpr int64_t num_inputs_before_lcds = 2; // First input is the maximumt trip count. Second input is the condition // variable (for the first iteration). The rest of the inputs are // loop-carried dependencies. int64_t num_loop_carried_deps; if (num_loop_carried_deps_ != -1) { num_loop_carried_deps = num_loop_carried_deps_; } else { num_loop_carried_deps = InputSize() - num_inputs_before_lcds; } int64_t max_trip_count = *Input(0).template data(); const bool first_iter_condition = *Input(1).template data(); scope_ = std::make_shared(loop_ws, body_net_def_, num_loop_carried_deps); // Body graph has 1+N+K outputs: recalculated condition variable, N // loop-carried dependencies, and K scan_outputs int num_scan_outputs = scope_->net()->external_output().size() - num_loop_carried_deps - 1; CAFFE_ENFORCE_GE( num_scan_outputs, 0, "Body graph must have N+K outputs, where N is the number " "of loop-carried dependencies and K is the number of scan " "outputs"); // Copy initial loop-carried dependencies for (int i = 0; i < num_loop_carried_deps; ++i) { scope_->lcd_tensor(i)->CopyFrom(Input(i + num_inputs_before_lcds)); } // Initialize iteration variable scope_->set_iteration(0ll); // Initialize input condition variable scope_->template set_input_condition(first_iter_condition); auto valid_iter_num = [this, max_trip_count](int64_t i) { if (has_trip_count_) { return i < max_trip_count; } else { return true; } }; auto condition_true = [this, first_iter_condition](int64_t i, bool cond_value) { if (has_cond_) { if (i == 0) { return (bool)first_iter_condition; } else { return cond_value; } } else { return true; } }; // Allocate scan_outputs for zero-iteration case for (int i = 0; i < num_scan_outputs; ++i) { Output(i + num_loop_carried_deps)->Resize(0); Output(i + num_loop_carried_deps)->template mutable_data(); } // Use this to keep track of the sizes of the scan outputs and validate // they're the same across iterations. std::vector> scan_outputs_sizes; Workspace *cur_ws = nullptr; bool cur_output_condition = false; while (true) { int64_t itr = scope_->iteration(); if (valid_iter_num(itr) && condition_true(itr, cur_output_condition)) { if (!scope_->net()->Run()) { return false; } cur_ws = scope_->workspace(); cur_output_condition = scope_->template output_condition(); if (save_scopes_) { loop_ws = ws_stack_.pushForwardWorkspace(parent_ws_).get(); scope_ = std::make_shared(loop_ws, body_net_def_, num_loop_carried_deps); } // Copy forward loop-carried dependencies for (int i = 0; i < num_loop_carried_deps; ++i) { Blob* b = cur_ws->GetBlob( scope_->net()->external_output()[i + 1]); const Tensor& t = b->template Get(); scope_->lcd_tensor(i)->CopyFrom(t); } // Copy out scan_outputs for (int i = 0; i < num_scan_outputs; ++i) { int net_output_idx = i + 1 + num_loop_carried_deps; const Tensor& scan_output = cur_ws->GetBlob(scope_->net()->external_output()[net_output_idx]) ->template Get(); auto* scan_output_target = Output(i + num_loop_carried_deps); if (itr == 0) { auto dims = scan_output.sizes().vec(); scan_outputs_sizes.push_back(dims); dims.insert(dims.begin(), 1); scan_output_target->Resize(dims); scan_output_target->CopyFrom(scan_output); } else { auto dims = scan_output.sizes().vec(); CAFFE_ENFORCE_EQ( dims, scan_outputs_sizes[i], "Size of scan output changed across iterations"); dims.insert(dims.begin(), itr); scan_output_target->Extend(1, 100); int64_t timestep_size = 1; for (const int64_t t : scan_outputs_sizes[i]) { timestep_size *= t; } const void* src_data = scan_output.raw_data(); auto& sot_meta = scan_output_target->dtype(); void* dst_data = (char*)scan_output_target->raw_mutable_data(sot_meta) + timestep_size * scan_output.itemsize() * itr; memcpy(dst_data, src_data, timestep_size * scan_output.itemsize()); } } scope_->set_iteration(itr + 1ll); scope_->template set_input_condition(cur_output_condition); } else { break; } } // Copy out final loop-carried dependencies for (int i = 0; i < num_loop_carried_deps; ++i) { Output(i)->CopyFrom(*scope_->lcd_tensor(i)); } return true; } private: class LocalScope { public: LocalScope( Workspace *loop_ws, const NetDef& body_net_def, size_t num_lcds) : loop_ws_(loop_ws){ CAFFE_ENFORCE(loop_ws_, "Failed to initialize local loop workspace"); // Create loop-carried deps in Workspace lcd_tensors_.clear(); for (int i = 2; i < num_lcds + 2; ++i) { Blob* b = loop_ws_->CreateBlob(body_net_def.external_input(i)); Tensor* t = BlobGetMutableTensor(b, Context::GetDeviceType()); lcd_tensors_.push_back(t); } // First output is the iteration variable auto* iteration_var_blob = loop_ws_->CreateBlob( body_net_def.external_input(0)); iteration_var_ = BlobGetMutableTensor(iteration_var_blob, Context::GetDeviceType()); input_condition_var_ = BlobGetMutableTensor( loop_ws_->CreateBlob(body_net_def.external_input(1)), Context::GetDeviceType()); auto* condition_var_blob = loop_ws_->CreateBlob(body_net_def.external_output(0)); condition_var_ = BlobGetMutableTensor(condition_var_blob, Context::GetDeviceType()); condition_var_->Resize(1); condition_var_->template mutable_data(); body_net_ = loop_ws_->GetNet(body_net_def.name()); if (!body_net_) { body_net_ = loop_ws_->CreateNet(body_net_def, true); } CAFFE_ENFORCE(body_net_, "Failed to initialize loop subnet"); } NetBase* net() const { return body_net_; } Workspace* workspace() const { return loop_ws_; } int64_t iteration() const { auto* iteration_var_ptr = iteration_var_->template mutable_data(); return *iteration_var_ptr; } Tensor* lcd_tensor(int idx) { return lcd_tensors_[idx]; } void set_iteration(int64_t itr) { iteration_var_->Resize(); auto* iteration_var_ptr = iteration_var_->template mutable_data(); *iteration_var_ptr = itr; } template void set_input_condition(bool cond_value) { input_condition_var_->Resize(1); auto* input_condition_var_ptr = input_condition_var_->template mutable_data(); *input_condition_var_ptr = cond_value; } template bool output_condition() const { auto* condition_var_ptr = condition_var_->template mutable_data(); return *condition_var_ptr; } private: Workspace *loop_ws_; NetBase* body_net_; // owned by a workspace Tensor* iteration_var_; Tensor* input_condition_var_; Tensor* condition_var_; std::vector lcd_tensors_; }; NetDef body_net_def_; Workspace* parent_ws_; detail::WorkspaceStack ws_stack_; bool has_trip_count_; bool has_cond_; bool save_scopes_; bool disable_scopes_; int64_t num_loop_carried_deps_; std::shared_ptr scope_; }; } // namespace caffe2 #endif // CAFFE2_OPERATORS_ONNX_WHILE_OP_H