reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-14 1e0d1e8caa7790c036b36a7ca62261f3625bb09c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#pragma once
 
#include <torch/data/dataloader/base.h>
 
#include <cstddef>
#include <thread>
#include <utility>
 
namespace torch {
namespace data {
 
/// A dataloader for stateful datasets.
///
/// A dataloader for stateful datatasets differs from one for stateless
/// datasets one in that the dataset is shared among worker threads, and that
/// this dataset is itself responsible for producing batches rather than
/// depending on a sampler. The statefulness here actually refers to the
/// dataset. The StatefulDataLoader simply alters the data loading algorithm to
/// accomodate the stateful, shared nature of the dataset. Note that the dataset
/// must be thread safe if more than one worker thread is used.
///
/// A stateful dataloader is created by calling `make_data_loader` with a
/// stateful dataset.
template <typename Dataset>
class StatefulDataLoader : public DataLoaderBase<
                               Dataset,
                               typename Dataset::BatchType::value_type,
                               typename Dataset::BatchRequestType> {
 public:
  using super = DataLoaderBase<
      Dataset,
      typename Dataset::BatchType::value_type,
      typename Dataset::BatchRequestType>;
  using typename super::BatchRequestType;
 
  /// Constructs the `StatefulDataLoader` from a `dataset` and some `options`.
  StatefulDataLoader(Dataset dataset, DataLoaderOptions options)
      : super(
            std::move(options),
            torch::make_unique<Dataset>(std::move(dataset))) {
    for (size_t w = 0; w < this->options_.workers; ++w) {
      // As opposed to the stateless case, here all worker threads access the
      // same underlying dataset.
      this->workers_.emplace_back(
          [this] { this->worker_thread(*this->main_thread_dataset_); });
    }
  }
 
 private:
  /// Resets the internal state of the dataloader and the dataset.
  void reset() override {
    this->main_thread_dataset_->reset();
    // Call the base class method last because it calls `prefetch()`
    super::reset();
  }
 
  /// For stateful datasets, the batch request is always the batch size. The
  /// dataset is responsible for determining what goes into the batch next.
  optional<BatchRequestType> get_batch_request() override {
    return this->options_.batch_size;
  }
};
} // namespace data
} // namespace torch