reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-02-28 27bef7116852ea5e165bfe454b86345bd57a16ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#pragma once
 
#include <torch/data/dataloader/base.h>
#include <torch/data/worker_exception.h>
 
#include <torch/csrc/utils/memory.h>
 
#include <c10/util/Exception.h>
 
#include <cstddef>
#include <thread>
#include <utility>
 
namespace torch {
namespace data {
 
/// A dataloader for stateless datasets.
///
/// This dataloader follows the traditional PyTorch dataloader design, whereby a
/// (posssibly) stateful sampler produces *batch requests* for a stateless
/// dataset, which acts as a simple batch request to batch mapping. The batch
/// request will often be an array of indices, and if the dataset is a simple
/// image dataset, the dataset would produce the images at those indices.
template <typename Dataset, typename Sampler>
class StatelessDataLoader : public DataLoaderBase<
                                Dataset,
                                typename Dataset::BatchType,
                                typename Sampler::BatchRequestType> {
 public:
  using super = DataLoaderBase<
      Dataset,
      typename Dataset::BatchType,
      typename Sampler::BatchRequestType>;
  using typename super::BatchRequestType;
 
  /// Constructs the `StatelessDataLoader` from a `dataset`, a `sampler` and
  /// some `options`.
  StatelessDataLoader(
      Dataset dataset,
      Sampler sampler,
      DataLoaderOptions options)
      : super(std::move(options)), sampler_(std::move(sampler)) {
    for (size_t w = 0; w < this->options_.workers; ++w) {
      // Here we copy the dataset into the worker thread closure. Each worker
      // has its own copy of the dataset. This means the dataset must be
      // trivially copiable, or else we don't expect more than one worker to
      // be in use.
      this->workers_.emplace_back(
          [this, dataset]() mutable { this->worker_thread(dataset); });
    }
    if (this->options_.workers == 0) {
      this->main_thread_dataset_ =
          torch::make_unique<Dataset>(std::move(dataset));
    }
  }
 
 private:
  /// Resets the internal state of the dataloader and the sampler.
  void reset() override {
    sampler_.reset();
    // Call the base class method last because it calls `prefetch()`
    super::reset();
  }
 
  /// Queries the sampler for the next batch request (possibly progressing its
  /// internal state).
  optional<BatchRequestType> get_batch_request() override {
    auto indices = sampler_.next(this->options_.batch_size);
    if (!indices ||
        (indices->size() < this->options_.batch_size &&
         this->options_.drop_last)) {
      return nullopt;
    }
    AT_ASSERT(indices->size() > 0);
    return indices;
  }
 
  /// The `Sampler` used to produce batch requests.
  Sampler sampler_;
};
} // namespace data
} // namespace torch