reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-16 a47fccb11fa3470901aebcb27f861d242d0925e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#pragma once
 
#include <torch/types.h>
 
#include <algorithm>
#include <cstddef>
#include <vector>
 
namespace torch {
namespace data {
namespace detail {
namespace sequencers {
namespace detail {
template<typename Result>
bool buffer_contains_result(const std::vector<optional<Result>>& buffer) {
  return std::any_of(
      buffer.begin(), buffer.end(), [](const optional<Result>& result) {
        return result.has_value();
      });
}
} // namespace detail
 
/// A `Sequencer` accepts a function that yields the next result of a
/// `DataLoader` and then has the opportunity to influence the order in which
/// these results are returned. The `NoSequencer` does not enforce any
/// sequencing and returns any result directly. The `OrderedSequencer` instead
/// buffers results internally to return them in order of their sequence number.
template <typename Result>
struct Sequencer {
  using ResultProducer = std::function<optional<Result>()>;
  virtual ~Sequencer() = default;
  virtual optional<Result> next(ResultProducer next_result) = 0;
};
 
/// A `Sequencer` that does not enforce any ordering. It is effectively the
/// identity function.
template <typename Result>
struct NoSequencer final : public Sequencer<Result> {
  using typename Sequencer<Result>::ResultProducer;
  optional<Result> next(ResultProducer next_result) override {
    return next_result();
  }
};
 
/// A `Sequencer` that buffers results and returns them in order of their
/// sequence number. The `OrderedSequencer` maintains an internal, monotonically
/// incrementing counter for the next sequence number it expects. If it receives
/// a result with a higher sequence number, it will buffer it for later (when
/// the sequence number reaches that of this result). Otherwise, if the sequence
/// numbers match, the result is returned.
///
/// Implementation note: The `OrderedSequencer` is implemented with a fixed-size
/// buffer. Let `m` be the maximum number of jobs in the data loader's queue and
/// `s` be the current sequence number. Assume `m` jobs are scheduled in the
/// `DataLoader`. Any new result is stored at index `job.sqn mod m` in the
/// `OrderedSequencer`. Why are we sure sequence numbers of new jobs will not
/// collide with sequence numbers of buffered jobs? The `OrderedSequencer` will
/// not return from `next()` until it receives the result with sqn `s`. This
/// means no new jobs can be scheduled in the `DataLoader` in the meantime,
/// which enforces that as long as sqn `s` has not been received, `s + m` (which
/// would cause a collision in the fixed-size buffer) will not yet be scheduled.
template <typename Result>
struct OrderedSequencer : public Sequencer<Result> {
  using typename Sequencer<Result>::ResultProducer;
 
  /// Constructs the `OrderedSequencer` with the maximum number of results it
  /// will ever hold at one point in time.
  explicit OrderedSequencer(size_t max_jobs) : buffer_(max_jobs) {}
 
  /// Buffers results until the next one in the expected order is received.
  optional<Result> next(ResultProducer next_result) override {
    // If we already have the result for the next sqn, return it.
    if (auto& maybe_result = buffer(next_sequence_number_)) {
      auto result = std::move(*maybe_result);
      buffer(next_sequence_number_++).reset();
      return result;
    }
    // Otherwise wait for the next result.
    while (true) {
      auto result = next_result();
      if (!result) {
        AT_ASSERT(!detail::buffer_contains_result(buffer_));
        break;
      }
      // If it was not nullopt and the sequence numbers match, return it
      // directly and bump the sequence number.
      if (result->sequence_number == next_sequence_number_) {
        ++next_sequence_number_;
        return result;
      }
      // Stash the result for later.
      AT_ASSERT(!buffer(result->sequence_number).has_value());
      buffer(result->sequence_number) = std::move(result);
    }
    // The result was an empty optional, so we are done with this epoch.
    return nullopt;
  }
 
  /// Accesses the buffer at the `index` modulo the buffer size.
  optional<Result>& buffer(size_t index) {
    return buffer_.at(index % buffer_.size());
  }
 
  /// The monotonically increasing sequence number we expect.
  size_t next_sequence_number_ = 0;
 
  /// A fixed-size buffer (after construction).
  std::vector<optional<Result>> buffer_;
};
} // namespace sequencers
} // namespace detail
} // namespace data
} // namespace torch