reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-10 c3765bd24fe73747688a0ec2a550f219c9acb384
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#pragma once
 
#include <torch/csrc/utils/variadic.h>
#include <torch/types.h>
 
#include <c10/util/Exception.h>
 
#include <functional>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>
 
namespace torch {
namespace data {
namespace detail {
// For increased safety and more separated logic, this implementation of
// `Iterator` consists of a `ValidIterator` and a `SentinelIterator`. A
// `ValidIterator` yields new batches until the `DataLoader` is exhausted. While
// the `DataLoader` is not exhausted, `ValidIterator`s compare equal if they are
// the same object. When the `ValidIterator` becomes exhauted, it compares equal
// to the `SentinelIterator`, but not before. Half the code here is to implement
// double dispatch for the comparison. Got damnit, C++.
 
template <typename Batch>
struct ValidIterator;
 
template <typename Batch>
struct SentinelIterator;
 
/// Base class for the `ValidIterator` and `SentinelIterator`
template <typename Batch>
struct IteratorImpl {
  virtual ~IteratorImpl() = default;
  virtual void next() = 0;
  virtual Batch& get() = 0;
  virtual bool operator==(const IteratorImpl& other) const = 0;
  virtual bool operator==(const ValidIterator<Batch>& other) const = 0;
  virtual bool operator==(const SentinelIterator<Batch>& other) const = 0;
};
 
template <typename Batch>
struct ValidIterator : public IteratorImpl<Batch> {
  using BatchProducer = std::function<optional<Batch>()>;
 
  explicit ValidIterator(BatchProducer next_batch)
      : next_batch_(std::move(next_batch)) {}
 
  /// Fetches the next batch.
  void next() override {
    // If we didn't get the very first batch yet, get it now.
    lazy_initialize();
    TORCH_CHECK(
        batch_.has_value(), "Attempted to increment iterator past the end");
    // Increment to the next batch.
    batch_ = next_batch_();
  }
 
  /// Returns the current batch. The precondition for this operation to not
  /// throw an exception is that it has been compared to the `SentinelIterator`
  /// and did not compare equal.
  Batch& get() override {
    // If we didn't get the very first batch yet, get it now.
    lazy_initialize();
    TORCH_CHECK(
        batch_.has_value(),
        "Attempted to dereference iterator that was past the end");
    return batch_.value();
  }
 
  /// Does double dispatch.
  bool operator==(const IteratorImpl<Batch>& other) const override {
    return other == *this;
  }
 
  /// A `ValidIterator` is equal to the `SentinelIterator` iff. the
  /// `ValidIterator` has reached the end of the dataloader.
  bool operator==(const SentinelIterator<Batch>& /* unused */) const override {
    lazy_initialize();
    return !batch_;
  }
 
  /// Returns true if the memory address of `other` equals that of `this`.
  bool operator==(const ValidIterator<Batch>& other) const override {
    return &other == this;
  }
 
  /// Gets the very first batch if it has not yet been fetched.
  void lazy_initialize() const {
    if (!initialized_) {
      batch_ = next_batch_();
      initialized_ = true;
    }
  }
 
  BatchProducer next_batch_;
  mutable optional<Batch> batch_;
  mutable bool initialized_ = false;
};
 
template <typename Batch>
struct SentinelIterator : public IteratorImpl<Batch> {
  void next() override {
    AT_ERROR(
        "Incrementing the DataLoader's past-the-end iterator is not allowed");
  }
 
  Batch& get() override {
    AT_ERROR(
        "Dereferencing the DataLoader's past-the-end iterator is not allowed");
  }
 
  /// Does double dispatch.
  bool operator==(const IteratorImpl<Batch>& other) const override {
    return other == *this;
  }
 
  /// Calls the comparison operator between `ValidIterator` and
  /// `SentinelIterator`.
  bool operator==(const ValidIterator<Batch>& other) const override {
    return other == *this;
  }
 
  /// Sentinel iterators always compare equal.
  bool operator==(const SentinelIterator<Batch>& other) const override {
    return true;
  }
};
} // namespace detail
 
template <typename Batch>
class Iterator {
 public:
  // Type aliases to make the class recognized as a proper iterator.
  using difference_type = std::ptrdiff_t;
  using value_type = Batch;
  using pointer = Batch*;
  using reference = Batch&;
  using iterator_category = std::input_iterator_tag;
 
  explicit Iterator(std::unique_ptr<detail::IteratorImpl<Batch>> impl)
      : impl_(std::move(impl)) {}
 
  /// Increments the iterator.
  /// Only permitted for valid iterators (not past the end).
  Iterator& operator++() {
    impl_->next();
    return *this;
  }
 
  /// Returns the current batch.
  /// Only permitted for valid iterators (not past the end).
  Batch& operator*() {
    return impl_->get();
  }
 
  /// Returns a pointer to the current batch.
  /// Only permitted for valid iterators (not past the end).
  Batch* operator->() {
    return &impl_->get();
  }
 
  /// Compares two iterators for equality.
  bool operator==(const Iterator& other) const {
    return *impl_ == *other.impl_;
  }
 
  /// Compares two iterators for inequality.
  bool operator!=(const Iterator& other) const {
    return !(*this == other);
  }
 
 private:
  /// Points either to a `ValidIterator` or to a `SentinelIterator`.
  std::shared_ptr<detail::IteratorImpl<Batch>> impl_;
};
} // namespace data
} // namespace torch