reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-20 455a7bb3e42582a62a02d7baf1b1d4495bf6107c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
#pragma once
 
#include <torch/arg.h>
#include <torch/csrc/utils/memory.h>
#include <torch/data/datasets/stateful.h>
#include <torch/data/samplers.h>
#include <queue>
#include <thread>
 
#include <torch/serialize.h>
 
namespace torch {
namespace data {
namespace datasets {
 
/// Interface for chunk reader, which performs data chunking and reading of
/// entire chunks.
///
/// A chunk could be an entire file, such as an audio data file or an image,
/// or part of a file in the case of a large text-file split based on seek
/// positions.
template <typename ExampleType_, typename ChunkType_ = std::vector<ExampleType_>>
class ChunkDataReader {
 public:
  using ChunkType = ChunkType_;
  using ExampleType = ExampleType_;
 
  /// Read an entire chunk.
  virtual ChunkType read_chunk(size_t chunk_index) = 0;
 
  /// Returns the number of chunks available in this reader.
  virtual size_t chunk_count() = 0;
 
  /// This will clear any internal state associate with this reader.
  virtual void reset() = 0;
};
 
namespace detail {
/// BatchDataBuffer manages a queue of UnwrappedBatchData. After a new chunk is
/// loaded, BatchDataBuffer splits it into small batches and push them into the
/// queue. When get_batch is called from data loader, it pops cached batches and
/// return. If the cache is empty, it either waits to load more chunks or return
/// null if all chunks are loaded.
template <
    typename UnwrappedBatch,
    typename ExampleSampler = samplers::RandomSampler>
class BatchDataBuffer {
 public:
  using UnwrappedBatchType = UnwrappedBatch;
  using BatchType = torch::optional<UnwrappedBatchType>;
  using BatchRequestType = typename ExampleSampler::BatchRequestType;
 
  BatchDataBuffer(
      size_t batch_size,
      ExampleSampler& example_sampler,
      size_t queue_capacity)
      : batch_size_(batch_size),
        example_sampler_(example_sampler),
        queue_capacity_(queue_capacity) {}
 
  /// Return batch data from the queue. Called from the ChunkDataset main
  /// thread.
  BatchType get_batch() {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_read_.wait(lock, [this] {
      // wait till there is available data in the queue or if all chunks are
      // loaded (i.e. the dataset is exhausted for this epoch)
      return (
          this->total_example_count_in_queue_ >= batch_size_ ||
          this->stop_);
    });
    if (batch_queue_.empty()) {
      AT_ASSERT(stop_);
      // All batches have been retrieved. Return an empty batch.
      return nullopt;
    }
 
    UnwrappedBatchData batch = std::move(batch_queue_.front());
    batch_queue_.pop();
    if (batch.exception) {
      throw WorkerException(batch.exception);
    }
 
    total_example_count_in_queue_ -= batch.batch_data.size();
    lock.unlock();
    cv_write_.notify_all();
 
    return batch.batch_data;
  }
 
  /// Push preloaded chunks to batch queue. Called from the ChunkDataset worker
  /// threads.
  void add_chunk_data(UnwrappedBatchType data) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      // stop loading if we have preloaded enough data.
      return this->total_example_count_in_queue_ < this->queue_capacity_ ||
          this->stop_;
    });
    if (stop_) {
      // When stop_ is true, it means no further chunk loading is necessary.
      // Return without any further processing.
      return;
    }
 
    auto data_size = data.size();
    auto remaining_size = data_size;
    example_sampler_.reset(data_size);
 
    auto fill_batch = [&](size_t example_count, UnwrappedBatchType& batch) {
      auto batch_example_indices = this->example_sampler_.next(example_count);
      AT_ASSERT(
          batch_example_indices &&
          batch_example_indices.value().size() == example_count);
      BatchRequestType& indices = batch_example_indices.value();
      for (size_t i : indices) {
        TORCH_CHECK(i < data_size, "Index out of range");
        batch.emplace_back(std::move(data[i]));
      }
      remaining_size -= example_count;
    };
 
    if (!batch_queue_.empty()) {
      // if the queue has existing data, and the last batch doesn't have enough
      // examples to fill a batch_size batch, add more example to this batch first.
      auto& batch = batch_queue_.back();
      size_t current_count = batch.batch_data.size();
      if (current_count < batch_size_) {
        auto example_count =
            std::min(remaining_size, batch_size_ - current_count);
        fill_batch(example_count, batch.batch_data);
      }
    }
 
    // If we still have data remaining after filling the last pushed batch, add
    // them to the queue too.
    while (remaining_size > 0) {
      UnwrappedBatchType current_batch;
 
      // Allocate the batch memory ahead of time.
      current_batch.reserve(batch_size_);
 
      auto example_count = std::min(remaining_size, batch_size_);
      fill_batch(example_count, current_batch);
      batch_queue_.emplace(std::move(current_batch));
    }
    total_example_count_in_queue_ += data_size;
    lock.unlock();
    cv_read_.notify_all();
  }
 
  /// Push exceptions thrown during preloading into batch queue. Called from
  /// the ChunkDataset worker threads.
  void add_chunk_data(std::exception_ptr e_ptr) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      // stop loading if we have preloaded enough data.
      return (
        this->total_example_count_in_queue_ < this->queue_capacity_ ||
        this->stop_);
    });
    if (stop_){
      // When stop_ is true, it means this current thread needs to be tore down,
      // the batch buffer will be discarded, so no need to enqueue any new
      // exceptions.
      return;
    }
 
    batch_queue_.emplace(e_ptr);
    lock.unlock();
    cv_read_.notify_all();
  }
 
  void stop(){
    {
      // Hold the lock before changing stop_ to prevent a race condition which can
      // cause a deadlock.
      // To be more specific, conditional variable cv_write_ waits on predicate
      // stop_ in add_chunk_data(). The wait happens in two steps: 1) while still
      // holding the lock, check if predicate is true; 2) if it is true, proceeds,
      // otherwise, release the lock and wait until notified. Without holding a
      // lock, cv_write_'s notification can happen in between step 1) and 2). In
      // that case, as cv_write_ is not in waiting status yet, so the notification
      // is lost and cv_write_ will sleep forever.
      // By taking a lock before changing predicate stop_, it is ensured updating
      // and evaluating stop_ always happen in a synchronized way
      std::lock_guard<std::mutex> lock(queue_mutex_);
      stop_ = true;
    }
 
    // notify all writers, wake them from wait to exit current method.
    cv_write_.notify_all();
    // notify all readers too.
    cv_read_.notify_all();
  }
  /// The batch size is needed to create batches from the chunk data. Similar to
  /// regular dataloader where the batches are created with prefetches,
  /// BatchDataBuffer perform the batch creation using the provided batch size.
  size_t batch_size_ = 0;
 
  /// count of total example stored in the queue
  size_t total_example_count_in_queue_ = 0;
 
  /// struct that contains a raw unwrapped batch unit. An unwrapped batch unit is
  /// the raw data without 'optional' wrapper. It can be a collection of images,
  /// utterances, e.t.c.
  struct UnwrappedBatchData {
    explicit UnwrappedBatchData(UnwrappedBatchType data) : batch_data(std::move(data)) {}
 
    explicit UnwrappedBatchData(std::exception_ptr e) : exception(e) {}
 
    /// batch data to return
    UnwrappedBatchType batch_data;
 
    /// exception pointer which captures any abnormal exceptions while creating the
    /// batch.
    std::exception_ptr exception;
  };
 
  /// local cache to store example batches from loaded chunk
  std::queue<UnwrappedBatchData> batch_queue_;
 
  // sync batch_queue_ update.
  std::mutex queue_mutex_;
 
  std::condition_variable cv_read_;
  std::condition_variable cv_write_;
 
  ExampleSampler& example_sampler_;
 
  // configurable maximun number of elements the queue can hold at one time.
  size_t queue_capacity_;
 
  // When set to true, it wakes the writer threads from the wait and exit current
  // function call. This is needed when ChunkDataSet.Reset is called while the
  // previous epoch is not exhausted yet. When ChunkDataset is waiting its
  // preloader to finish previous work before tearing down the thread, the
  // preloader could be still waiting for the conditional variable, thus cause
  // the program to hang. This boolean is used to break this waiting condition.
  bool stop_ = false;
};
} // namespace detail
 
/// Options to configure a `ChunkDataset`.
struct ChunkDatasetOptions {
  ChunkDatasetOptions() = delete;
  ChunkDatasetOptions(
      size_t preloader_count,
      size_t batch_size,
      size_t cache_size = 2048,
      size_t cross_chunk_shuffle_count = 1)
      : preloader_count_(preloader_count),
        batch_size_(batch_size),
        cache_size_(cache_size),
        cross_chunk_shuffle_count_(cross_chunk_shuffle_count) {
    TORCH_CHECK(
        preloader_count_ > 0,
        "Preloader count is 0. At least one preloader needs to be specified.");
    TORCH_CHECK(
        batch_size_ > 0,
        "Batch size is 0. A positive batch size needs to be specified.");
    TORCH_CHECK(
        cache_size_ > 0,
        "Cache size is 0. A positive cache size needs to be specified.");
    TORCH_CHECK(
        cache_size_ >= batch_size_,
        "Cache size is less than batch size. Cache needs to be large enough to "
        "hold at least one batch.");
    TORCH_CHECK(
        cross_chunk_shuffle_count_ > 0,
        "cross_chunk_shuffle_count needs to be greater than 0.");
  }
 
  /// The number of worker thread to preload chunk data.
  TORCH_ARG(size_t, preloader_count);
 
  /// The size of each batch.
  TORCH_ARG(size_t, batch_size);
 
  /// The capacity of the queue for batch caching.
  TORCH_ARG(size_t, cache_size) = 2048;
 
  // The number of chunks to perfrom cross-chunk shuffling. Default to 1 meaning
  // no cross-chunk shuffling. When it is equal to n (n > 1), n random
  // chunks will be loaded at once and example shuffling will be performed
  // across all those n chunks.
  // Note: Usually the default config (1 chunk shuffle + example shuffle) is
  // good enough to generate random distributed data. Use this parameter only if
  // you know cross-shuffle is needed in your case. Also there is a performance
  // penalty when this value is greater than 1, as we need to do extra merge
  // between multiple chunks before performing example sampling.
  TORCH_ARG(size_t, cross_chunk_shuffle_count) = 1;
};
 
/// A stateful dataset that support hierarchical sampling and prefetching of
/// entre chunks.
///
/// Unlike regular dataset, chunk dataset require two samplers to operate and
/// keeps an internal state. `ChunkSampler` selects, which chunk to load next,
/// while the `ExampleSampler` determins the order of Examples that are returned
/// in each `get_batch` call. The hierarchical sampling approach used here is
/// inspired by this paper http://martin.zinkevich.org/publications/nips2010.pdf
template <
    typename ChunkReader,
    typename ChunkSampler = samplers::RandomSampler,
    typename ExampleSampler = samplers::RandomSampler>
class ChunkDataset final
    : public StatefulDataset<
          ChunkDataset<ChunkReader, ChunkSampler, ExampleSampler>,
          typename ChunkReader::BatchType,
          size_t> {
 public:
  using BatchType = torch::optional<typename ChunkReader::BatchType>;
  using UnwrappedBatchType = typename ChunkReader::BatchType;
  using BatchRequestType = size_t;
  using ChunkSamplerType = ChunkSampler;
  using ExampleSamplerType = ExampleSampler;
 
  ChunkDataset(
      ChunkReader chunk_reader,
      ChunkSampler chunk_sampler,
      ExampleSampler example_sampler,
      ChunkDatasetOptions options,
      std::function<void(UnwrappedBatchType&)> preprocessing_policy =
          std::function<void(UnwrappedBatchType&)>())
      : chunk_reader_(std::move(chunk_reader)),
        chunk_sampler_(std::move(chunk_sampler)),
        example_sampler_(std::move(example_sampler)),
        options_(std::move(options)),
        preprocessing_policy_(preprocessing_policy),
        quit_worker_(false),
        running_preloaders_(0),
        load_checkpoint_(false) {}
 
  virtual ~ChunkDataset() {
    // stop batch buffer first.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    free_workers();
  }
 
  /// Default get_batch method of BatchDataset. This method returns
  /// Example batches created from the preloaded chunks. The implemenation
  /// is dataset agnostic and does not need overriding in different chunk
  /// datasets.
  BatchType get_batch(size_t batch_size) override {
    TORCH_CHECK(
      batch_buffer_ != nullptr,
      "Dataset needs to call reset() before calling get_batch().");
 
    TORCH_CHECK(
      batch_size == options_.batch_size(),
      "The requested batch size does not match with the initialized batch size.\n"
      " The requested batch size is ", batch_size,
      ", while the dataset is created with batch size equal to ", options_.batch_size());
    return batch_buffer_->get_batch();
  }
 
  /// Helper method around get_batch as `batch_size` is not strictly necessary
  BatchType get_batch() {
    return get_batch(options_.batch_size());
  }
 
  /// This will clear any internal state and starts the internal prefetching
  /// mechanism for the chunk dataset.
  void reset() override {
    // We need this to support partial data reads via dataloader iterator.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    // free workers from previous reset if there is any.
    free_workers();
    preload_threads_.clear();
 
    if (!load_checkpoint_){
      chunk_reader_.reset();
      chunk_sampler_.reset(chunk_reader_.chunk_count());
      load_checkpoint_ = false;
    }
 
    // Throw out any existing cached batch in the buffer and re-creates a new
    // chunk buffer.
    batch_buffer_ = torch::make_unique<
        detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>(
        options_.batch_size(),
        example_sampler_,
        options_.cache_size());
 
    // create new workers for this new epoch.
    quit_worker_ = false;
 
    AT_ASSERT(running_preloaders_ == 0);
    running_preloaders_ = options_.preloader_count();
    for (size_t i = 0; i < options_.preloader_count(); ++i) {
      preload_threads_.emplace_back([this, i]() { this->preloader(i); });
    }
  }
 
  /// size is not used for chunk dataset.
  optional<size_t> size() const override {
    return torch::nullopt;
  }
 
  // provide a references to chunk sampler. Used mainly in distributed data
  // loading to set the epoch number for the sampler.
  ChunkSamplerType& chunk_sampler() {
    return chunk_sampler_;
  }
 
  void save(serialize::OutputArchive& archive) const override {
    std::lock_guard<std::mutex> lock(chunk_index_guard_);
    chunk_sampler_.save(archive);
  }
 
  void load(serialize::InputArchive& archive) override{
    std::lock_guard<std::mutex> lock(chunk_index_guard_);
    chunk_sampler_.load(archive);
    load_checkpoint_ = true;
  }
 
 private:
  /// running on worker thread to preload chunk data.
  void preloader(size_t id) {
    while (!quit_worker_.load()) {
      try {
        std::vector<size_t> chunk_idx;
        {
          std::lock_guard<std::mutex> lock(chunk_index_guard_);
          if (auto chunk_sampler_result = chunk_sampler_.next(this->options_.cross_chunk_shuffle_count())) {
            chunk_idx = chunk_sampler_result.value();
          } else {
            break;
          }
        }
        UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_idx[0]);
        for (size_t i = 1; i < chunk_idx.size(); ++i) {
          auto chunk_data = chunk_reader_.read_chunk(chunk_idx[i]);
          std::move(
              chunk_data.begin(), chunk_data.end(), std::back_inserter(data));
        }
        if (preprocessing_policy_) {
          preprocessing_policy_(data);
        }
        if (!data.empty()) { // skip empty chunks.
          batch_buffer_->add_chunk_data(std::move(data));
        }
      } catch (...) {
        batch_buffer_->add_chunk_data(std::current_exception());
      }
    }
    AT_ASSERT(running_preloaders_.load() > 0);
    --running_preloaders_;
    if (running_preloaders_.load() == 0) {
      // all preloaders are completed, so we can notify the batch_buffer.
      batch_buffer_->stop();
    }
  }
 
  /// Block the current thread until the workers finish execution and exit.
  void free_workers() {
    if (!quit_worker_.load()) {
      quit_worker_ = true;
      for (auto& worker_thread : preload_threads_) {
        worker_thread.join();
      }
    }
  }
 
 private:
  // Templated class that defines what is a chunk and how to read chunk data.
  // When a chunk is returned by chunk_reader_, ChunkDataset split it into
  // batches and caches them in batch_buffer_.
  ChunkReader chunk_reader_;
 
  // chunk sampler to shuffle different chunks
  ChunkSamplerType chunk_sampler_;
 
  // example sampler to shuffle examples in a specific chunk
  ExampleSamplerType example_sampler_;
 
  // batch data buffer which holds chunk data from preloading thread.
  std::shared_ptr<detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>
      batch_buffer_;
 
  // worker thread pool
  std::vector<std::thread> preload_threads_;
 
  /// The options the Dataset was configured with.
  const ChunkDatasetOptions options_;
 
  // function pointer wrapper to apply custom processing over chunk data. This is
  // considered an advanced parameter for developers who want to apply a
  // pre-process to the chunk data before sampling into minibatch.
  // Different than the collate function, this policy is applied on the chunk
  // level, instead of minibatch level. When a chunk of data is loaded (multiple
  // chunks if cross_chunk_shuffle_count_ is greater than 1), this policy is
  // applied to the full loaded data. It is useful if developers want to
  // perform pre-processing (like bucketing) to the chunk data before
  // example sampler samples the data. By default it's an empty pointer and no
  // action will be taken.
  std::function<void(UnwrappedBatchType&)> preprocessing_policy_;
 
  // indicate whether the worker thread can be teared down
  std::atomic<bool> quit_worker_;
 
  // keep track of running preloaders to notify batch buffer. A value 0
  // indicates that the chunk loading is completed.
  std::atomic<size_t> running_preloaders_;
 
  // mutex to synchronize chunk sampler next() call.
  mutable std::mutex chunk_index_guard_;
 
  // boolean value to indicate whether we need to load the checkpoint for chunk_sampler_.
  bool load_checkpoint_;
};
} // namespace datasets
} // namespace data
} // namespace torch