reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-10 c3765bd24fe73747688a0ec2a550f219c9acb384
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#pragma once
 
#include <torch/data/dataloader/stateful.h>
#include <torch/data/dataloader/stateless.h>
 
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/utils/variadic.h>
 
#include <c10/util/Exception.h>
 
#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>
 
namespace torch {
namespace data {
 
/// Creates a `DataLoader` instance for a stateless `dataset`, a `sampler` and
/// some `options`.
template <typename Dataset, typename Sampler>
torch::disable_if_t<
    Dataset::is_stateful,
    std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) {
  return torch::make_unique<StatelessDataLoader<Dataset, Sampler>>(
      std::move(dataset), std::move(sampler), std::move(options));
}
 
/// Creates a `DataLoader` instance for a stateless `dataset` and some
/// `options`. A sampler (by default a `RandomSampler`) will be constructed from
/// the size of the dataset.
template <typename Sampler = samplers::RandomSampler, typename Dataset>
torch::disable_if_t<
    Dataset::is_stateful || !std::is_constructible<Sampler, size_t>::value,
    std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
make_data_loader(
    Dataset dataset,
    DataLoaderOptions options = DataLoaderOptions()) {
  const optional<size_t> size = dataset.size();
  TORCH_CHECK(
      size.has_value(),
      "Expected the dataset to be sized in "
      "order to construct the Sampler");
  return make_data_loader(
      std::move(dataset), Sampler(*size), std::move(options));
}
 
/// Creates a `DataLoader` for a stateful `dataset` and some `options`.
template <typename Dataset, typename = torch::enable_if_t<Dataset::is_stateful>>
std::unique_ptr<StatefulDataLoader<Dataset>> make_data_loader(
    Dataset dataset,
    DataLoaderOptions options = DataLoaderOptions()) {
  return torch::make_unique<StatefulDataLoader<Dataset>>(
      std::move(dataset), std::move(options));
}
} // namespace data
} // namespace torch