reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-16 a47fccb11fa3470901aebcb27f861d242d0925e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
#pragma once
 
#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
 
#include <vector>
 
namespace torch {
namespace nn {
 
/// A list of `Module`s that registers its elements.
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::ModuleList mlist(
///     torch::nn::Linear(3, 4),
///     torch::nn::BatchNorm(4),
///     torch::nn::Dropout(0.5)
///   );
///
///   for (const auto &module : mlist) {
///     module.pretty_print();
///   }
///
/// \endrst
///
/// Why should you use `ModuleList` instead of a simple `std::vector`? The value
/// a `ModuleList` provides over manually calling a sequence of modules is that
/// it allows treating the whole container *as a single module*, such that
/// performing a transformation on the `ModuleList` applies to each of the
/// modules it stores (which are each a registered submodule of the
/// `ModuleList`). For example, calling
/// `.to(torch::kCUDA)` on a `ModuleList` will move each module in the list to
/// CUDA memory. For example:
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::ModuleList mlist(
///     torch::nn::Linear(3, 4),
///     torch::nn::BatchNorm(4),
///     torch::nn::Dropout(0.5)
///   );
///
///   // Convert all modules to CUDA.
///   mlist->to(torch::kCUDA);
///
/// \endrst
///
/// Finally, `ModuleList` provides a lightweight container API, such as allowing
/// iteration over submodules, positional access, adding a new module after
/// construction via `push_back`, as well as joining two `ModuleList`s via
/// `extend`.
class ModuleListImpl : public Cloneable<ModuleListImpl> {
 public:
  using Iterator = std::vector<std::shared_ptr<Module>>::iterator;
  using ConstIterator = std::vector<std::shared_ptr<Module>>::const_iterator;
 
  ModuleListImpl() = default;
 
  /// Constructs the `ModuleList` from a variadic list of modules.
  template <typename... Modules>
  explicit ModuleListImpl(Modules&&... modules) {
    modules_.reserve(sizeof...(Modules));
    push_back_var(std::forward<Modules>(modules)...);
  }
 
  /// Special cloning function for `ModuleList` because it does not use
  /// `reset()`.
  std::shared_ptr<Module> clone(
      const optional<Device>& device = nullopt) const override {
    auto clone = std::make_shared<ModuleListImpl>();
    for (const auto& module : modules_) {
      clone->push_back(module->clone(device));
    }
    return std::move(clone);
  }
 
  /// `reset()` is empty for `ModuleList`, since it does not have parameters of
  /// its own.
  void reset() override {}
 
  /// Pretty prints the `ModuleList` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::ModuleList";
  }
 
  void push_back(std::shared_ptr<Module> module) {
    modules_.push_back(std::move(module));
    const auto index = modules_.size() - 1;
    register_module(std::to_string(index), modules_[index]);
  }
 
  /// Adds a new `Module` to the `ModuleList` container, moving or copying
  /// it into a `shared_ptr` internally. This method allows passing value types,
  /// and letting the container deal with the boxing.
  template <typename M, typename = torch::detail::enable_if_module_t<M>>
  void push_back(M&& module) {
    using Type = typename std::remove_reference<M>::type;
    push_back(std::make_shared<Type>(std::forward<M>(module)));
  }
 
  /// Unwraps the contained module of a `ModuleHolder` and adds it to the
  /// `ModuleList`.
  template <typename M>
  void push_back(const ModuleHolder<M>& module_holder) {
    push_back(module_holder.ptr());
  }
 
  /// Iterates over the container and calls `push_back()` on each value.
  template <typename Container>
  void extend(const Container& container) {
    for (const auto& module : container) {
      push_back(module);
    }
  }
 
  /// Returns an iterator to the start of the `ModuleList`.
  Iterator begin() {
    return modules_.begin();
  }
 
  /// Returns a const iterator to the start of the `ModuleList`.
  ConstIterator begin() const {
    return modules_.begin();
  }
 
  /// Returns an iterator to the end of the `ModuleList`.
  Iterator end() {
    return modules_.end();
  }
 
  /// Returns a const iterator to the end of the `ModuleList`.
  ConstIterator end() const {
    return modules_.end();
  }
 
  /// Attempts to return the module at the given index as the requested type.
  /// Throws an exception if the index is out of bounds or the types do not
  /// match.
  template <typename T>
  T& at(size_t index) {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleList::at with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    return *modules_[index]->as<T>();
  }
 
  /// Attempts to return the module at the given index as the requested type.
  /// Throws an exception if the index is out of bounds or the types do not
  /// match.
  template <typename T>
  const T& at(size_t index) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleList::at with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    return *modules_[index]->as<T>();
  }
 
  /// Attempts to return a `std::shared_ptr` whose dynamic type is that of the
  /// underlying module at the given index. Throws an exception if the index is
  /// out of bounds.
  std::shared_ptr<Module> ptr(size_t index) const {
    TORCH_CHECK(index < size(), "Index out of range");
    return modules_[index];
  }
 
  /// Attempts to return a `std::shared_ptr` whose type is the one provided.
  /// Throws an exception if the index is out of bounds or the types do not
  /// match.
  template <typename T>
  std::shared_ptr<T> ptr(size_t index) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleList::ptr with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    return std::dynamic_pointer_cast<T>(modules_[index]);
  }
 
  /// Like `ptr(index)`.
  std::shared_ptr<Module> operator[](size_t index) const {
    // This is the only method we can call without a type.
    return ptr(index);
  }
 
  /// The current size of the `ModuleList` container.
  size_t size() const noexcept {
    return modules_.size();
  }
 
  /// True if there are no modules in the `ModuleList`.
  bool is_empty() const noexcept {
    return size() == 0;
  }
 
  void insert(size_t index, std::shared_ptr<Module> module) {
    TORCH_CHECK(index <= size(), "Index out of range");
 
    if (index == size())
      push_back(module);
    else {
      modules_.insert(
          modules_.begin() + Iterator::difference_type(index),
          std::move(module));
 
      for (size_t i = index; i < size() - 1; ++i)
        replace_module(std::to_string(index), modules_[index]);
      register_module(std::to_string(size() - 1), modules_.back());
    }
  }
 
  /// Unwraps the contained module of a `ModuleHolder` and inserts it in the
  /// `ModuleList`.
  template <typename M>
  void insert(size_t index, const ModuleHolder<M>& module_holder) {
    insert(index, module_holder.ptr());
  }
 
  /// inserts a new `Module` to the `ModuleList` container, moving or copying
  /// it into a `shared_ptr` internally. This method allows passing value types,
  /// and letting the container deal with the boxing.
  template <typename M, typename = torch::detail::enable_if_module_t<M>>
  void insert(size_t index, M&& module) {
    using Type = typename std::remove_reference<M>::type;
    insert(index, std::make_shared<Type>(std::forward<M>(module)));
  }
 
 private:
  template <typename Head, typename... Tail>
  void push_back_var(Head&& head, Tail&&... tail) {
    push_back(std::forward<Head>(head));
    // Recursively calls this method, until the parameter pack only thas this
    // entry left. Then calls `push_back()` a final time (above).
    push_back_var(std::forward<Tail>(tail)...);
  }
 
  /// The base case, when the list of modules is empty.
  void push_back_var() {}
 
  // Box the AnyModules to give ModuleList reference semantics, like the rest of
  // the API. Note that this is not required otherwise, this could just be a
  // `vector<AnyModule>`.
  std::vector<std::shared_ptr<Module>> modules_;
};
 
/// A `ModuleHolder` subclass for `ModuleListImpl`.
/// See the documentation for `ModuleListImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(ModuleList);
 
} // namespace nn
} // namespace torch