reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-10 c3765bd24fe73747688a0ec2a550f219c9acb384
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#pragma once
 
// Engine implements backpropagation from output variables and their gradients
// to "root" variables (variables created by the user with requires_grad=True).
 
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/anomaly_mode.h>
 
#include <deque>
#include <exception>
#include <functional>
#include <memory>
#include <queue>
#include <unordered_map>
#include <utility>
#include <vector>
#include <thread>
 
namespace torch { namespace autograd {
struct ReadyQueue;
struct NodeTask;
struct GraphTask;
}} // namespace torch::autograd
 
namespace torch { namespace autograd {
// A single instance of this struct should be created through the whole process lifetime.
// The worker thread creation logic and Engine's destructor rely on this.
struct TORCH_API Engine {
  /// Returns a reference to a static `Engine` instance.
  static Engine& get_default_engine();
 
  Engine();
  virtual ~Engine();
 
  using ready_queue_type = std::deque<std::pair<std::shared_ptr<Node>, InputBuffer>>;
  using dependencies_type = std::unordered_map<Node*, int>;
 
  // Given a list of (Node, input number) pairs computes the value of the graph
  // by following next_edge references.
  virtual variable_list execute(
      const edge_list& roots,
      const variable_list& inputs,
      bool keep_graph,
      bool create_graph,
      const edge_list& outputs = {});
  virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
    return nullptr;
  }
 
  void queue_callback(std::function<void()> callback);
 
  bool is_checkpoint_valid();
 
protected:
  void compute_dependencies(Node* root, GraphTask& task);
  void evaluate_function(NodeTask& task);
  ReadyQueue& ready_queue(at::Device device);
  ReadyQueue& ready_queue_by_index(int device_index);
  void start_threads();
  virtual void thread_init(int device);
  virtual void thread_main(GraphTask *graph_task);
  virtual void thread_on_exception(NodeTask& task, std::exception& e);
  void reentrant_thread_init();
  void add_thread_pool_task(GraphTask *graph_task);
  void set_device(int device);
 
  // Ensures ready_queues_ are initialized only once
  std::once_flag start_threads_flag_;
  // Safe to read ready_queues_ without synchronization after intialization
  std::vector<std::shared_ptr<ReadyQueue>> ready_queues_;
  std::vector<std::function<void()>> final_callbacks_;
  // To protect reads and writes to final_callbacks_
  std::mutex post_callbacks_lock_;
  // How many nested reentrant calls are allowed until a new thread is used
  int max_recursion_depth_;
 
  struct ThreadPoolShared {
    // Data structures used by the threads for executing reentrant backwards
    // tasks. See Note [Reentrant backwards]
    // Number of available threads for processing new GraphTasks.
    unsigned int num_workers_;
    // The threads will wait on work_ to be notified of GraphTasks
    std::condition_variable work_;
    // To protect reads and writes to graphtask_queue_ and num_workers_
    // and for synchronizing creating new threads when needed
    std::mutex mutex_;
    // Workers will process the GraphTasks added to this queue. A GraphTask is
    // allocated inside Engine::execute and lives for the duration of execute
    std::queue<GraphTask*> graphtasks_queue_;
 
    ThreadPoolShared() : num_workers_(0) {}
 };
 
 // Temporary workaround until shutting down threads is done
 // We need shared ownership of all these objects because the threads are leaked
 // when Engine shuts down, so there may be threads waiting on work_
 // for the graphtasks_queue_ to be nonempty.
 std::shared_ptr<ThreadPoolShared> thread_pool_shared_;
};
 
// allow python_engine to override the default engine when it loads
using EngineStub = Engine& (*)();
TORCH_API void set_default_engine_stub(EngineStub stub);
 
}} // namespace torch::autograd