reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#ifndef CAFFE2_NET_ASYNC_TASK_GRAPH_H
#define CAFFE2_NET_ASYNC_TASK_GRAPH_H
 
#include "caffe2/core/net_async_base.h"
#include "caffe2/core/net_async_task.h"
#include "caffe2/core/net_async_task_future.h"
#include "caffe2/core/operator.h"
 
namespace caffe2 {
 
// AsyncTaskGraph represents an execution of a net, it owns the tasks and
// associated futures, sets up future callbacks and propagates errors.
// Usage steps:
// - Adding graph nodes and edges through CreateNode/AddDependency;
// - Freezing the graph (FreezeGraph), after the freezing a future
//   can be obtained using GetFuture;
// - Execution of the graph is scheduled through ExecuteGraph, after each
//   execution Reset must be called to prepare the graph for the next run
 
class AsyncTaskGraphBase {
 public:
  virtual bool CreateNode(
      int node_id,
      const std::vector<OperatorBase*>& ops) = 0;
 
  virtual bool AddDependency(
      int child_node_id,
      const std::vector<int>& parent_node_ids) = 0;
 
  virtual void FreezeGraph() = 0;
 
  virtual AsyncTaskFuture* ExecuteGraph() = 0;
 
  virtual AsyncTaskFuture* GetFuture() = 0;
 
  virtual void Reset() = 0;
 
  virtual ~AsyncTaskGraphBase() noexcept {}
};
 
class AsyncTaskGraph : public AsyncTaskGraphBase {
 public:
  AsyncTaskGraph(ExecutorHelper* helper, const ExecutionOptions& options);
 
  bool CreateNode(int node_id, const std::vector<OperatorBase*>& ops) override;
 
  bool AddDependency(int child_node_id, const std::vector<int>& parent_node_ids)
      override;
 
  void FreezeGraph() override;
 
  AsyncTaskFuture* ExecuteGraph() override;
 
  AsyncTaskFuture* GetFuture() override;
 
  void Reset() override;
 
 private:
  // used to, e.g., get access to executor's thread pools
  // TODO: pass tracer and counters through ExecutorHelper
  ExecutorHelper* helper_;
  ExecutionOptions options_;
 
  bool frozen_;
 
  std::unordered_map<int, std::unique_ptr<AsyncTask>> nodes_;
  std::unordered_map<int, std::unordered_set<int>> parents_;
  std::unordered_map<int, std::unordered_set<int>> children_;
  std::vector<std::unique_ptr<AsyncTaskFuture>> edge_futures_;
 
  std::vector<AsyncTask*> root_tasks_;
 
  std::unique_ptr<AsyncTaskFuture> run_future_;
};
 
} // namespace caffe2
 
#endif // CAFFE2_NET_ASYNC_TASK_GRAPH_H