reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-11 bdf3ad71583fb4ef100d3819ecdae8fd9f70083e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#pragma once
 
#include <torch/data/datasets/base.h>
 
#include <memory>
#include <utility>
 
namespace torch {
namespace data {
namespace datasets {
 
/// A dataset that wraps another dataset in a shared pointer and implements the
/// `BatchDataset` API, delegating all calls to the shared instance. This is
/// useful when you want all worker threads in the dataloader to access the same
/// dataset instance. The dataset must take care of synchronization and
/// thread-safe access itself.
///
/// Use `torch::data::datasets::make_shared_dataset()` to create a new
/// `SharedBatchDataset` like you would a `std::shared_ptr`.
template <typename UnderlyingDataset>
class SharedBatchDataset : public BatchDataset<
                               SharedBatchDataset<UnderlyingDataset>,
                               typename UnderlyingDataset::BatchType,
                               typename UnderlyingDataset::BatchRequestType> {
 public:
  using BatchType = typename UnderlyingDataset::BatchType;
  using BatchRequestType = typename UnderlyingDataset::BatchRequestType;
 
  /// Constructs a new `SharedBatchDataset` from a `shared_ptr` to the
  /// `UnderlyingDataset`.
  /* implicit */ SharedBatchDataset(
      std::shared_ptr<UnderlyingDataset> shared_dataset)
      : dataset_(std::move(shared_dataset)) {}
 
  /// Calls `get_batch` on the underlying dataset.
  BatchType get_batch(BatchRequestType request) override {
    return dataset_->get_batch(std::move(request));
  }
 
  /// Returns the `size` from the underlying dataset.
  optional<size_t> size() const override {
    return dataset_->size();
  }
 
  /// Accesses the underlying dataset.
  UnderlyingDataset& operator*() {
    return *dataset_;
  }
 
  /// Accesses the underlying dataset.
  const UnderlyingDataset& operator*() const {
    return *dataset_;
  }
 
  /// Accesses the underlying dataset.
  UnderlyingDataset* operator->() {
    return dataset_.get();
  }
 
  /// Accesses the underlying dataset.
  const UnderlyingDataset* operator->() const {
    return dataset_.get();
  }
 
  /// Calls `reset()` on the underlying dataset.
  void reset() {
    dataset_->reset();
  }
 
 private:
  std::shared_ptr<UnderlyingDataset> dataset_;
};
 
/// Constructs a new `SharedBatchDataset` by creating a
/// `shared_ptr<UnderlyingDatase>`. All arguments are forwarded to
/// `make_shared<UnderlyingDataset>`.
template <typename UnderlyingDataset, typename... Args>
SharedBatchDataset<UnderlyingDataset> make_shared_dataset(Args&&... args) {
  return std::make_shared<UnderlyingDataset>(std::forward<Args>(args)...);
}
} // namespace datasets
} // namespace data
} // namespace torch