reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 1bb4d137919cae4f57f95a2572ee612dcabb3b3d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#pragma once
 
#include <random>
#include <thread>
 
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/jit/pybind_utils.h>
 
namespace torch {
namespace throughput_benchmark {
namespace detail {
 
template <class Input, class Output, class Model>
BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
    const BenchmarkConfig& config) const {
  CHECK(initialized_);
  TORCH_CHECK(
      config.num_worker_threads == 1,
      "Only parallelization by callers is supported");
 
  // We pre-generate inputs here for each of the threads. This allows us to
  // safely move inputs out for each of the threads independently and thus avoid
  // overhead from the benchmark runner itself
  std::vector<std::vector<Input>> thread_inputs(config.num_calling_threads);
  std::vector<size_t> input_iters(config.num_calling_threads);
  {
    std::random_device seeder;
    std::mt19937 engine(seeder());
    TORCH_CHECK(
        !inputs_.empty(),
        "Please provide benchmark inptus."
        "Did you forget to call add_input()? ");
    std::uniform_int_distribution<int> dist(0, inputs_.size() - 1);
 
    for (int thread_id = 0; thread_id < config.num_calling_threads;
         ++thread_id) {
      // Just in case we generate num_iters inputs for each of the threads
      // This was if one thread does all the work we will be fine
      for (int i = 0; i < config.num_iters + config.num_warmup_iters; ++i) {
        thread_inputs[thread_id].push_back(cloneInput(inputs_[dist(engine)]));
      }
      input_iters[thread_id] = 0;
    }
  }
 
  std::mutex m;
  std::condition_variable worker_main_cv;
  std::condition_variable main_worker_cv;
  // TODO: add GUARDED_BY once it is available
  int64_t initialized{0};
  int64_t finished{0};
  bool start{false};
  std::atomic<int64_t> num_attempted_iters{0};
  std::vector<std::thread> callers;
 
  for (auto thread_id = 0; thread_id < config.num_calling_threads;
       ++thread_id) {
    callers.emplace_back([&, thread_id]() {
      // We use conditional variable as a barrier to make sure each thread
      // performs required warmeup iterations before we start measuring
      for (auto j = 0; j < config.num_warmup_iters; ++j) {
        runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
        ++input_iters[thread_id];
      }
      {
        std::unique_lock<std::mutex> lock(m);
        ++initialized;
        worker_main_cv.notify_one();
        while (!start) {
          main_worker_cv.wait(lock);
        }
      }
      LOG(INFO) << "Starting forward thread " << thread_id;
      while (num_attempted_iters.fetch_add(1) < config.num_iters) {
        runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
        ++input_iters[thread_id];
      }
 
      {
        std::unique_lock<std::mutex> lock(m);
        ++finished;
        worker_main_cv.notify_one();
        LOG(INFO) << "Shutting down forward thread " << thread_id
                  << ". Total number of finished threads: " << finished;
      }
 
    });
  }
 
  using Clock = std::chrono::high_resolution_clock;
  using TimePoint = std::chrono::time_point<Clock>;
  TimePoint start_time;
 
  {
    std::unique_lock<std::mutex> lock(m);
    while (initialized != config.num_calling_threads) {
      worker_main_cv.wait(lock);
    }
    LOG(INFO) << "Starting threads";
    start = true;
    start_time = Clock::now();
  }
 
  main_worker_cv.notify_all();
  {
    std::unique_lock<std::mutex> lock(m);
    worker_main_cv.wait(
        lock, [&]() { return finished == config.num_calling_threads; });
  }
  auto end_time = std::chrono::high_resolution_clock::now();
  LOG(INFO) << "Finished benchmark";
 
  BenchmarkExecutionStats stats;
  float total_time_ms = std::chrono::duration_cast<std::chrono::nanoseconds>(
                            end_time - start_time)
                            .count() /
      1000.0 / 1000.0;
  // We use config.num_iters instead of num_attempted_iters as it is
  // repsesatative of the real work done. Last attempted iteration on each
  // calling threads doesn't represent the real work (i.e. running the model)
  stats.latency_avg_ms =
      total_time_ms * config.num_calling_threads / config.num_iters;
  stats.num_iters = config.num_iters;
 
  for (auto& t : callers) {
    t.join();
  }
  return stats;
}
 
} // namespace detail
} // namespace throughput_benchmark
} // namespace torch