reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-02-28 27bef7116852ea5e165bfe454b86345bd57a16ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#pragma once
 
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/data/samplers/base.h>
 
#include <cstddef>
#include <vector>
 
namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
 
namespace torch {
namespace data {
namespace samplers {
 
/// A `Sampler` that selects a subset of indices to sample from and defines a
/// sampling behavior. In a distributed setting, this selects a subset of the
/// indices depending on the provided num_replicas and rank parameters. The
/// `Sampler` performs a rounding operation based on the `allow_duplicates`
/// parameter to decide the local sample count.
template <typename BatchRequest = std::vector<size_t>>
class DistributedSampler : public Sampler<BatchRequest> {
 public:
  DistributedSampler(
      size_t size,
      size_t num_replicas = 1,
      size_t rank = 0,
      bool allow_duplicates = true)
      : size_(size),
        num_replicas_(num_replicas),
        rank_(rank),
        epoch_(0),
        allow_duplicates_(allow_duplicates) {}
 
  /// Set the epoch for the current enumeration. This can be used to alter the
  /// sample selection and shuffling behavior.
  void set_epoch(size_t epoch) {
    epoch_ = epoch;
  }
 
  size_t epoch() const {
    return epoch_;
  }
 
 protected:
  size_t local_sample_count() {
    if (allow_duplicates_) {
      return (size_ + num_replicas_ - 1) / num_replicas_;
    } else {
      return size_ / num_replicas_;
    }
  }
 
  size_t size_;
  size_t num_replicas_;
  size_t rank_;
  size_t epoch_;
  bool allow_duplicates_;
};
 
/// Select samples randomly. The sampling order is shuffled at each `reset()`
/// call.
class TORCH_API DistributedRandomSampler : public DistributedSampler<> {
 public:
  DistributedRandomSampler(
      size_t size,
      size_t num_replicas = 1,
      size_t rank = 0,
      bool allow_duplicates = true);
 
  /// Resets the `DistributedRandomSampler` to a new set of indices.
  void reset(optional<size_t> new_size = nullopt) override;
 
  /// Returns the next batch of indices.
  optional<std::vector<size_t>> next(size_t batch_size) override;
 
  /// Serializes the `DistributedRandomSampler` to the `archive`.
  void save(serialize::OutputArchive& archive) const override;
 
  /// Deserializes the `DistributedRandomSampler` from the `archive`.
  void load(serialize::InputArchive& archive) override;
 
  /// Returns the current index of the `DistributedRandomSampler`.
  size_t index() const noexcept;
 
 private:
  void populate_indices();
 
  size_t begin_index_;
  size_t end_index_;
  size_t sample_index_;
  std::vector<size_t> all_indices_;
};
 
/// Select samples sequentially.
class TORCH_API DistributedSequentialSampler : public DistributedSampler<> {
 public:
  DistributedSequentialSampler(
      size_t size,
      size_t num_replicas = 1,
      size_t rank = 0,
      bool allow_duplicates = true);
 
  /// Resets the `DistributedSequentialSampler` to a new set of indices.
  void reset(optional<size_t> new_size = nullopt) override;
 
  /// Returns the next batch of indices.
  optional<std::vector<size_t>> next(size_t batch_size) override;
 
  /// Serializes the `DistributedSequentialSampler` to the `archive`.
  void save(serialize::OutputArchive& archive) const override;
 
  /// Deserializes the `DistributedSequentialSampler` from the `archive`.
  void load(serialize::InputArchive& archive) override;
 
  /// Returns the current index of the `DistributedSequentialSampler`.
  size_t index() const noexcept;
 
 private:
  void populate_indices();
 
  size_t begin_index_;
  size_t end_index_;
  size_t sample_index_;
  std::vector<size_t> all_indices_;
};
 
} // namespace samplers
} // namespace data
} // namespace torch