#ifndef CAFFE2_NET_ASYNC_TASK_GRAPH_H #define CAFFE2_NET_ASYNC_TASK_GRAPH_H #include "caffe2/core/net_async_base.h" #include "caffe2/core/net_async_task.h" #include "caffe2/core/net_async_task_future.h" #include "caffe2/core/operator.h" namespace caffe2 { // AsyncTaskGraph represents an execution of a net, it owns the tasks and // associated futures, sets up future callbacks and propagates errors. // Usage steps: // - Adding graph nodes and edges through CreateNode/AddDependency; // - Freezing the graph (FreezeGraph), after the freezing a future // can be obtained using GetFuture; // - Execution of the graph is scheduled through ExecuteGraph, after each // execution Reset must be called to prepare the graph for the next run class AsyncTaskGraphBase { public: virtual bool CreateNode( int node_id, const std::vector& ops) = 0; virtual bool AddDependency( int child_node_id, const std::vector& parent_node_ids) = 0; virtual void FreezeGraph() = 0; virtual AsyncTaskFuture* ExecuteGraph() = 0; virtual AsyncTaskFuture* GetFuture() = 0; virtual void Reset() = 0; virtual ~AsyncTaskGraphBase() noexcept {} }; class AsyncTaskGraph : public AsyncTaskGraphBase { public: AsyncTaskGraph(ExecutorHelper* helper, const ExecutionOptions& options); bool CreateNode(int node_id, const std::vector& ops) override; bool AddDependency(int child_node_id, const std::vector& parent_node_ids) override; void FreezeGraph() override; AsyncTaskFuture* ExecuteGraph() override; AsyncTaskFuture* GetFuture() override; void Reset() override; private: // used to, e.g., get access to executor's thread pools // TODO: pass tracer and counters through ExecutorHelper ExecutorHelper* helper_; ExecutionOptions options_; bool frozen_; std::unordered_map> nodes_; std::unordered_map> parents_; std::unordered_map> children_; std::vector> edge_futures_; std::vector root_tasks_; std::unique_ptr run_future_; }; } // namespace caffe2 #endif // CAFFE2_NET_ASYNC_TASK_GRAPH_H