#ifndef CAFFE2_CORE_NET_PARALLEL_H #define CAFFE2_CORE_NET_PARALLEL_H #include "caffe2/core/net_async_base.h" #include "caffe2/core/net_async_task_graph.h" C10_DECLARE_string(caffe2_task_graph_engine); namespace caffe2 { class ParallelNetExecutorHelper; class CAFFE2_API ParallelNet : public NetBase { public: ParallelNet(const std::shared_ptr& net_def, Workspace* ws); bool RunAsync() override; void Wait() override; bool SupportsAsync() override; std::vector GetOperators() const override; TaskThreadPoolBase* Pool(const DeviceOption& device_option); protected: bool handleRunError() override; virtual void finishRun(); virtual void reset(); ExecutionOptions options_; int num_workers_; std::unique_ptr helper_; std::shared_ptr task_graph_; AsyncTaskFuture* run_future_; std::vector operator_nodes_; std::vector operators_; std::mutex pools_mutex_; typedef std::unordered_map< int, std::unordered_map>> PoolsMap; PoolsMap cpu_pools_; PoolsMap gpu_pools_; TaskThreadPoolBase* poolGetter(PoolsMap& pools, int device_type, int device_id, int pool_size); friend class ParallelNetExecutorHelper; C10_DISABLE_COPY_AND_ASSIGN(ParallelNet); }; C10_DECLARE_SHARED_REGISTRY( TaskGraphRegistry, AsyncTaskGraphBase, ExecutorHelper*, const ExecutionOptions&); std::shared_ptr GetAsyncTaskGraph( ExecutorHelper* helper, const ExecutionOptions& options); class ParallelNetExecutorHelper : public ExecutorHelper { public: explicit ParallelNetExecutorHelper(ParallelNet* net) : net_(net) {} TaskThreadPoolBase* GetPool(const DeviceOption& option) const override { return net_->Pool(option); } std::vector GetOperators() const override { return net_->GetOperators(); } int GetNumWorkers() const override { return net_->num_workers_; } private: ParallelNet* net_; }; } // namespace caffe2 #endif // CAFFE2_CORE_NET_PARALLEL_H