reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-02-28 27bef7116852ea5e165bfe454b86345bd57a16ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
#pragma once
 
#include <torch/data/dataloader_options.h>
#include <torch/data/detail/data_shuttle.h>
#include <torch/data/detail/sequencers.h>
#include <torch/data/iterator.h>
#include <torch/data/samplers/random.h>
#include <torch/data/worker_exception.h>
#include <torch/types.h>
 
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/utils/variadic.h>
 
#include <c10/util/Exception.h>
 
#include <cstddef>
#include <exception>
#include <memory>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
 
namespace torch {
namespace data {
template <typename Dataset, typename Batch, typename BatchRequest>
class DataLoaderBase {
 public:
  using BatchType = Batch;
  using BatchRequestType = BatchRequest;
 
  /// Constructs a new DataLoader from a `dataset` to sample from, `options`
  /// to configure the DataLoader with, and a `sampler` that specifies the
  /// sampling strategy.
  DataLoaderBase(
      DataLoaderOptions options,
      std::unique_ptr<Dataset> main_thread_dataset = nullptr)
      : options_(std::move(options)),
        main_thread_dataset_(std::move(main_thread_dataset)),
        sequencer_(new_sequencer()) {}
 
  virtual ~DataLoaderBase() {
    join();
  }
 
  /// Returns an iterator into the DataLoader. The lifetime of the iterator is
  /// bound to the DataLoader. In C++ standards language, the category of the
  /// iterator is `OutputIterator`. See
  /// https://en.cppreference.com/w/cpp/named_req/OutputIterator for what this
  /// means. In short: you may increment the iterator and dereference it, but
  /// cannot go back, or step forward more than one position at a time. When the
  /// DataLoader is exhausted, it will compare equal with the special
  /// "sentinel" iterator returned by `DataLoader::end()`. Most of the time, you
  /// should only use range-for loops to loop over the DataLoader, but
  /// standard algorithms like `std::copy(dataloader.begin(), dataloader.end(),
  /// output_iterator)`  are supported too.
  Iterator<Batch> begin() {
    TORCH_CHECK(
        shuttle_.in_flight_jobs() == 0,
        "Attempted to get a new DataLoader iterator "
        "while another iterator is not yet exhausted");
    reset();
    return Iterator<Batch>(torch::make_unique<detail::ValidIterator<Batch>>(
        [this] { return this->next(); }));
  }
 
  /// Returns a special "sentinel" iterator that compares equal with a
  /// non-sentinel iterator once the DataLoader is exhausted.
  Iterator<Batch> end() {
    return Iterator<Batch>(
        torch::make_unique<detail::SentinelIterator<Batch>>());
  }
 
  /// Joins the DataLoader's worker threads and drains internal queues.
  /// This function may only be invoked from the main thread (in which the
  /// DataLoader lives).
  void join() {
    if (joined_) {
      return;
    }
    shuttle_.drain();
    // Send one 'quit' message per worker. Since a worker dies (exits its
    // thread) after receiving this message, each `QuitWorker()` message will be
    // read by exactly one worker.
    for (size_t w = 0; w < options_.workers; ++w) {
      push_job(QuitWorker());
    }
    for (auto& worker : workers_) {
      worker.join();
    }
    joined_ = true;
  }
 
  /// Returns the options with which the DataLoader was configured.
  const FullDataLoaderOptions& options() const noexcept {
    return options_;
  }
 
 protected:
  /// Simple mix-in to give something a sequence number.
  struct Sequenced {
    Sequenced() = default;
    Sequenced(size_t sqn) : sequence_number(sqn) {}
    size_t sequence_number;
  };
 
  struct QuitWorker {};
 
  /// A `Job` is either a `BatchRequest` (new indices to fetch data at) or a
  /// `QuitWorker` object, to indicate the worker should shut down.
  struct Job : Sequenced {
    Job() = default;
    Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {}
    Job(BatchRequest&& i, size_t sqn)
        : Sequenced(sqn), batch_request(std::move(i)) {}
    optional<QuitWorker> quit;
    optional<BatchRequest> batch_request;
  };
 
  /// The finished result of a job.
  struct Result : Sequenced {
    Result() = default;
    Result(optional<Batch>&& b, size_t sqn)
        : Sequenced(sqn), batch(std::move(b)) {}
    Result(std::exception_ptr exception, size_t sqn)
        : Sequenced(sqn), exception(std::move(exception)) {}
    optional<Batch> batch;
    std::exception_ptr exception;
  };
 
  /// Subclass hook for getting the next batch request. The stateless case will
  /// ask the sampler for a new batch request (e.g. a vector of indices), while
  /// the stateful one will simply return the batch size.
  virtual optional<BatchRequestType> get_batch_request() = 0;
 
  /// Resets the internal state of the DataLoader, optionally pre-fetching
  /// new jobs.
  virtual void reset() {
    shuttle_.drain();
    sequence_number_ = 0;
    sequencer_ = new_sequencer();
    prefetch();
  }
 
  /// Schedules `requested_jobs` many new batches to be fetched. The actual
  /// number of jobs scheduled may be less if the DataLoader exhausts.
  void prefetch(size_t requested_jobs) {
    for (size_t r = 0; r < requested_jobs; ++r) {
      if (auto batch_request = get_batch_request()) {
        this->push_job(std::move(*batch_request));
      } else {
        break;
      }
    }
  }
 
  /// Schedules the maximum number of jobs (based on the `max_jobs` option).
  void prefetch() {
    prefetch(options_.max_jobs);
  }
 
  /// Returns the next batch of data, or an empty `optional` if the DataLoader
  /// is exhausted. This operation will block until a batch is available if one
  /// is still expected.
  optional<BatchType> next() {
    if (options_.workers > 0) {
      while (optional<Result> result = this->pop_result()) {
        if (result->exception) {
          throw WorkerException(result->exception);
        } else if (result->batch) {
          prefetch(1);
          return std::move(result->batch);
        }
      }
    } else if (auto batch_request = get_batch_request()) {
      return this->main_thread_dataset_->get_batch(std::move(*batch_request));
    }
    return nullopt;
  }
 
  /// The function that worker threads run.
  void worker_thread(Dataset& dataset) {
    while (true) {
      auto job = shuttle_.pop_job();
      if (job.quit) {
        break;
      }
      try {
        auto batch = dataset.get_batch(std::move(*job.batch_request));
        shuttle_.push_result({std::move(batch), job.sequence_number});
      } catch (...) {
        shuttle_.push_result({std::current_exception(), job.sequence_number});
      }
    }
  }
 
  /// Convenience method that calls `shuttle_.push_job()` with the next sequence
  /// number.
  template <typename T>
  void push_job(T value) {
    shuttle_.push_job({std::move(value), sequence_number_++});
  }
 
  /// Convenience method that gets the next result from the sequencer.
  optional<Result> pop_result() {
    return sequencer_->next(
        [this] { return this->shuttle_.pop_result(this->options_.timeout); });
  }
 
  /// Convenience method that creates a new sequencer based on the
  /// `enforce_ordering` option.
  std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer() {
    if (options_.enforce_ordering) {
      return torch::make_unique<detail::sequencers::OrderedSequencer<Result>>(
          options_.max_jobs);
    }
    return torch::make_unique<detail::sequencers::NoSequencer<Result>>();
  }
 
  /// The options the DataLoader was configured with.
  const FullDataLoaderOptions options_;
 
  /// The dataset for the main thread, only has a value if the number of
  /// worker threads was configured as zero, meaning the main thread has to do
  /// all the work (synchronously). NOTE: Really want this to be on the heap
  /// when empty, therefore `unique_ptr` and not `optional`.
  std::unique_ptr<Dataset> main_thread_dataset_;
 
  /// The sequence number for the *next* batch to be retrieved from the
  /// dataset.
  size_t sequence_number_ = 0;
 
  /// The worker threads, running the `worker_thread()` method.
  std::vector<std::thread> workers_;
 
  /// The `DataShuttle` which takes care of the life cycle of a job.
  detail::DataShuttle<Job, Result> shuttle_;
 
  /// The `Sequencer`, which handles optional ordering of batches.
  std::unique_ptr<detail::sequencers::Sequencer<Result>> sequencer_;
 
  /// True if the DataLoader has joined its worker threads.
  bool joined_ = false;
};
} // namespace data
} // namespace torch