reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-11 bdf3ad71583fb4ef100d3819ecdae8fd9f70083e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
#pragma once
 
#include <torch/detail/static.h>
#include <torch/nn/module.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
 
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/utils/variadic.h>
 
#include <ATen/Device.h>
 
#include <memory>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>
 
namespace torch {
namespace nn {
 
/// Stores a type erased `Module`.
///
/// The PyTorch C++ API does not impose an interface on the signature of
/// `forward()` in `Module` subclasses. This gives you complete freedom to
/// design your `forward()` methods to your liking. However, this also means
/// there is no unified base type you could store in order to call `forward()`
/// polymorphically for any module. This is where the `AnyModule` comes in.
/// Instead of inheritance, it relies on type erasure for polymorphism.
///
/// An `AnyModule` can store any `nn::Module` subclass that provides a
/// `forward()` method. This `forward()` may accept any types and return any
/// type. Once stored in an `AnyModule`, you can invoke the underlying module's
/// `forward()` by calling `AnyModule::forward()` with the arguments you would
/// supply to the stored module (though see one important limitation below).
/// Example:
///
/// \rst
/// .. code-block:: cpp
///
///   struct GenericTrainer {
///     torch::nn::AnyModule module;
///
///     void train(torch::Tensor input) {
///       module.forward(input);
///     }
///   };
///
///   GenericTrainer trainer1{torch::nn::Linear(3, 4)};
///   GenericTrainer trainer2{torch::nn::Conv2d(3, 4, 2)};
/// \endrst
///
/// As `AnyModule` erases the static type of the stored module (and its
/// `forward()` method) to achieve polymorphism, type checking of arguments is
/// moved to runtime. That is, passing an argument with an incorrect type to an
/// `AnyModule` will compile, but throw an exception at runtime:
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
///   // Linear takes a tensor as input, but we are passing an integer.
///   // This will compile, but throw a `torch::Error` exception at runtime.
///   module.forward(123);
/// \endrst
///
/// \rst
/// .. attention::
///   One noteworthy limitation of `AnyModule` is that its `forward()` method
///   does not support implicit conversion of argument types. For example, if
///   the stored module's `forward()` method accepts a `float` and you call
///   `any_module.forward(3.4)` (where `3.4` is a `double`), this will throw
///   an exception.
/// \endrst
///
/// The return type of the `AnyModule`'s `forward()` method is controlled via
/// the first template argument to `AnyModule::forward()`. It defaults to
/// `torch::Tensor`. To change it, you can write `any_module.forward<int>()`,
/// for example.
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
///   auto output = module.forward(torch::ones({2, 3}));
///
///   struct IntModule {
///     int forward(int x) { return x; }
///   };
///   torch::nn::AnyModule module(IntModule{});
///   int output = module.forward<int>(5);
/// \endrst
///
/// The only other method an `AnyModule` provides access to on the stored
/// module is `clone()`. However, you may acquire a handle on the module via
/// `.ptr()`, which returns a `shared_ptr<nn::Module>`. Further, if you know
/// the concrete type of the stored module, you can get a concrete handle to it
/// using `.get<T>()` where `T` is the concrete module type.
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
///   std::shared_ptr<nn::Module> ptr = module.ptr();
///   torch::nn::Linear linear(module.get<torch::nn::Linear>());
/// \endrst
class AnyModule {
 public:
  /// A type-erased value.
  class Value;
 
  /// A default-constructed `AnyModule` is in an empty state.
  AnyModule() = default;
 
  /// Constructs an `AnyModule` from a `shared_ptr` to concrete module object.
  template <typename ModuleType>
  explicit AnyModule(std::shared_ptr<ModuleType> module);
 
  /// Constructs an `AnyModule` from a concrete module object.
  template <
      typename ModuleType,
      typename = torch::detail::enable_if_module_t<ModuleType>>
  explicit AnyModule(ModuleType&& module);
 
  /// Constructs an `AnyModule` from a module holder.
  template <typename ModuleType>
  explicit AnyModule(const ModuleHolder<ModuleType>& module_holder);
 
  /// Move construction and assignment is allowed, and follows the default
  /// behavior of move for `std::unique_ptr`.
  AnyModule(AnyModule&&) = default;
  AnyModule& operator=(AnyModule&&) = default;
 
  /// Creates a shallow copy of an `AnyModule`.
  AnyModule(const AnyModule& other);
  AnyModule& operator=(const AnyModule& other);
 
  /// Creates a deep copy of an `AnyModule` if it contains a module, else an
  /// empty `AnyModule` if it is empty.
  AnyModule clone(optional<Device> device = nullopt) const;
 
  /// Assigns a module to the `AnyModule` (to circumvent the explicit
  /// constructor).
  template <typename ModuleType>
  AnyModule& operator=(std::shared_ptr<ModuleType> module);
 
  /// Invokes `forward()` on the contained module with the given arguments, and
  /// returns the return value as an `Value`. Use this method when chaining
  /// `AnyModule`s in a loop.
  template <typename... ArgumentTypes>
  Value any_forward(ArgumentTypes&&... arguments);
 
  /// Invokes `forward()` on the contained module with the given arguments, and
  /// casts the returned `Value` to the supplied `ReturnType` (which defaults to
  /// `torch::Tensor`).
  template <typename ReturnType = torch::Tensor, typename... ArgumentTypes>
  ReturnType forward(ArgumentTypes&&... arguments);
 
  /// Attempts to cast the underlying module to the given module type. Throws an
  /// exception if the types do not match.
  template <typename T, typename = torch::detail::enable_if_module_t<T>>
  T& get();
 
  /// Attempts to cast the underlying module to the given module type. Throws an
  /// exception if the types do not match.
  template <typename T, typename = torch::detail::enable_if_module_t<T>>
  const T& get() const;
 
  /// Returns the contained module in a `nn::ModuleHolder` subclass if possible
  /// (i.e. if `T` has a constructor for the underlying module type).
  template <typename T, typename ContainedType = typename T::ContainedType>
  T get() const;
 
  /// Returns a `std::shared_ptr` whose dynamic type is that of the underlying
  /// module.
  std::shared_ptr<Module> ptr() const;
 
  /// Like `ptr()`, but casts the pointer to the given type.
  template <typename T, typename = torch::detail::enable_if_module_t<T>>
  std::shared_ptr<T> ptr() const;
 
  /// Returns the `type_info` object of the contained value.
  const std::type_info& type_info() const;
 
  /// Returns true if the `AnyModule` does not contain a module.
  bool is_empty() const noexcept;
 
 private:
  /// \internal
  /// The static type of the object we store in the `AnyModule`, which erases
  /// the actual type, but allows us to call `forward()` on the underlying
  /// module.
  struct Placeholder;
 
  /// \internal
  /// The dynamic type of the object stored in the `AnyModule`. It contains the
  /// concrete instance to which all calls are forwarded. It is parameterized
  /// over the concrete type of the module, and the types of the arguments the
  /// module takes in its `forward()` method.
  template <typename ModuleType, typename... ArgumentTypes>
  struct Holder;
 
  /// Creates a `unique_ptr<Placeholder>` pointing to a `Holder` of the correct
  /// type. This method is used to deduce the arguments of the module's
  /// `forward()` method.
  template <
      typename ModuleType,
      typename Class,
      typename ReturnType,
      typename... ArgumentTypes>
  std::unique_ptr<Placeholder> make_holder(
      std::shared_ptr<ModuleType>&& module,
      ReturnType (Class::*)(ArgumentTypes...));
 
  /// Helper method invoked by const and non-const `get()`.
  template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
  ModuleType& get_(ReturnType (ModuleType::*)(ArgumentTypes...)) const;
 
  /// Helper method invoked by const and non-const `get()`.
  template <typename ModuleType>
  ModuleType& get_() const;
 
  /// The type erased module.
  std::unique_ptr<Placeholder> content_;
};
 
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule::Value ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
/// A simplified implementation of `std::any` which stores
/// a type erased object, whose concrete value can be retrieved at runtime by
/// checking if the `typeid()` of a requested type matches the `typeid()` of
/// the object stored. It is simplified in that it does not handle copying, as
/// we do not require it for our use cases. Moves are sufficient.
class AnyModule::Value {
 public:
  /// Move construction and assignment is allowed, and follows the default
  /// behavior of move for `std::unique_ptr`.
  Value(Value&&) = default;
  Value& operator=(Value&&) = default;
 
  /// Copy is disallowed, because we don't need it.
  Value(const Value& other) = delete;
  Value& operator=(const Value& other) = delete;
 
  /// Returns a pointer to the value contained in the `Value` if the type passed
  /// as template parameter matches the type of the value stored, and returns a
  /// null pointer otherwise.
  template <typename T>
  T* try_get() {
    static_assert(
        !std::is_reference<T>::value,
        "Value stores decayed types, you cannot cast it to a reference type");
    static_assert(
        !std::is_array<T>::value,
        "Value stores decayed types, you must cast it to T* instead of T[]");
    if (typeid(T).hash_code() == type_info().hash_code()) {
      return &static_cast<Holder<T>&>(*content_).value;
    }
    return nullptr;
  }
 
  /// Returns the value contained in the `Value` if the type passed as template
  /// parameter matches the type of the value stored, and throws an exception
  /// otherwise.
  template <typename T>
  T get() {
    if (auto* maybe_value = try_get<T>()) {
      return *maybe_value;
    }
    AT_ERROR(
        "Attempted to cast Value to ",
        c10::demangle(typeid(T).name()),
        ", but its actual type is ",
        c10::demangle(type_info().name()));
  }
 
  /// Returns the `type_info` object of the contained value.
  const std::type_info& type_info() const noexcept {
    return content_->type_info;
  }
 
 private:
  friend class AnyModule;
  friend struct TestValue;
 
  /// Constructs the `Value` from value type.
  template <
      typename T,
      typename =
          torch::disable_if_t<std::is_same<autograd::Variable, T>::value>>
  explicit Value(T&& value)
      : content_(
            torch::make_unique<Holder<decay_t<T>>>(std::forward<T>(value))) {}
 
  /// Constructs the `Value` from an `autograd::Variable`, first converting it
  /// to a `torch::Tensor`.
  explicit Value(autograd::Variable variable)
      : Value(Tensor(std::move(variable))) {}
 
  /// \internal
  /// The static type of the object we store in the `Value`, which erases the
  /// actual object's type, allowing us only to check the `type_info` of the
  /// type stored in the dynamic type.
  struct Placeholder {
    explicit Placeholder(const std::type_info& type_info_) noexcept
        : type_info(type_info_) {}
    virtual ~Placeholder() = default;
    const std::type_info& type_info;
  };
 
  /// \internal
  /// The dynamic type of the object we store in the `Value`, which hides the
  /// actual object we have erased in this `Value`.
  template <typename T>
  struct Holder : public Placeholder {
    /// A template because T&& would not be universal reference here.
    template <typename U>
    explicit Holder(U&& value_) noexcept
        : Placeholder(typeid(T)), value(std::forward<U>(value_)) {}
    T value;
  };
 
  /// The type erased object.
  std::unique_ptr<Placeholder> content_;
};
 
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule::Placeholder ~~~~~~~~~~~~~~~~~~~~~~~~~~
 
struct AnyModule::Placeholder : public AnyModule::Value::Placeholder {
  using AnyModule::Value::Placeholder::Placeholder;
 
  /// The "erased" `forward()` method.
  virtual Value forward(std::vector<Value>&& arguments) = 0;
 
  /// Returns std::shared_ptr<Module> pointing to the erased module.
  virtual std::shared_ptr<Module> ptr() = 0;
 
  /// Returns a `Placeholder` with a shallow copy of this `AnyModule`.
  virtual std::unique_ptr<Placeholder> copy() const = 0;
 
  /// Returns a `Placeholder` with a deep copy of this `AnyModule`.
  virtual std::unique_ptr<Placeholder> clone(optional<Device> device) const = 0;
};
 
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule::Holder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
template <typename ModuleType, typename... ArgumentTypes>
struct AnyModule::Holder : public AnyModule::Placeholder {
  /// \internal
  struct CheckedGetter {
    template <typename T>
    decay_t<T>&& operator()(size_t index) {
      AT_ASSERT(index < arguments_.size());
      auto& value = arguments_[index];
      if (auto* maybe_value = value.template try_get<decay_t<T>>()) {
        return std::move(*maybe_value);
      }
      AT_ERROR(
          "Expected argument #",
          index,
          " to be of type ",
          c10::demangle(typeid(T).name()),
          ", but received value of type ",
          c10::demangle(value.type_info().name()));
    }
    std::vector<Value>& arguments_;
  };
 
  /// \internal
  struct InvokeForward {
    template <typename... Ts>
    Value operator()(Ts&&... ts) {
      return Value(module_->forward(std::forward<Ts>(ts)...));
    }
    std::shared_ptr<ModuleType>& module_;
  };
 
  /// Constructs the `Holder` from a concrete module.
  explicit Holder(std::shared_ptr<ModuleType>&& module_)
      : Placeholder(typeid(ModuleType)), module(std::move(module_)) {}
 
  /// Calls `forward()` on the underlying module, casting each `Value` in the
  /// argument vector to a concrete value.
  Value forward(std::vector<Value>&& arguments) override {
    TORCH_CHECK(
        arguments.size() == sizeof...(ArgumentTypes),
        c10::demangle(type_info.name()),
        "'s forward() method expects ",
        sizeof...(ArgumentTypes),
        " arguments, but received ",
        arguments.size());
    // FYI: During invocation of a module's `forward()` method, the values live
    // in the `arguments` vector inside this function.
    return torch::unpack<Value, ArgumentTypes...>(
        InvokeForward{module}, CheckedGetter{arguments});
  }
 
  std::shared_ptr<Module> ptr() override {
    return module;
  }
 
  std::unique_ptr<Placeholder> copy() const override {
    return torch::make_unique<Holder>(*this);
  }
 
  std::unique_ptr<Placeholder> clone(optional<Device> device) const override {
    return torch::make_unique<Holder>(
        std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
  }
 
  /// The actual concrete module instance.
  std::shared_ptr<ModuleType> module;
};
 
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
template <typename ModuleType>
AnyModule::AnyModule(std::shared_ptr<ModuleType> module)
    : content_(make_holder(
          std::move(module),
          &std::remove_reference<ModuleType>::type::forward)) {
  // `AnyModule` can only store an `nn::Module` subclass object that provides
  // a `forward()` method that has a non-templatized return type.
  // (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s
  // `forward()` method has a templatized return type.)
  static_assert(
      torch::detail::is_module<ModuleType>::value,
      "Can only store object derived from nn::Module into AnyModule");
  static_assert(
      torch::detail::has_forward<ModuleType>::value,
      "Can only store module with a forward() method that has a non-templatized"
      " argument type and return type into AnyModule (e.g. we cannot store nn::Sequential"
      "into AnyModule, because its forward() method's argument type and return type are templatized."
      " If you need to use nn::Sequentials inside each other you can subclass "
      "nn::Sequential and write a non-templatized forward function for it. You can checkout "
      "https://github.com/pytorch/vision/blob/2f46070f3cb1ea894d82578f3dc5677f82f34958/torchvision/csrc/models/mnasnet.cpp#L59 "
      "for an example on how to do this.).");
}
 
template <typename ModuleType, typename>
AnyModule::AnyModule(ModuleType&& module)
    : AnyModule(
          std::make_shared<ModuleType>(std::forward<ModuleType>(module))) {}
 
template <typename ModuleType>
AnyModule::AnyModule(const ModuleHolder<ModuleType>& module_holder)
    : AnyModule(module_holder.ptr()) {}
 
inline AnyModule::AnyModule(const AnyModule& other)
    : content_(other.content_ ? other.content_->copy() : nullptr) {}
 
inline AnyModule& AnyModule::operator=(const AnyModule& other) {
  if (this != &other) {
    content_ = other.content_ ? other.content_->copy() : nullptr;
  }
  return *this;
}
 
inline AnyModule AnyModule::clone(optional<Device> device) const {
  AnyModule clone;
  clone.content_ = content_ ? content_->clone(device) : nullptr;
  return clone;
}
 
template <typename ModuleType>
AnyModule& AnyModule::operator=(std::shared_ptr<ModuleType> module) {
  return (*this = AnyModule(std::move(module)));
}
 
template <typename... ArgumentTypes>
AnyModule::Value AnyModule::any_forward(ArgumentTypes&&... arguments) {
  TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");
  std::vector<Value> values;
  values.reserve(sizeof...(ArgumentTypes));
  torch::apply(
      [&values](Value&& value) { values.push_back(std::move(value)); },
      Value(std::forward<ArgumentTypes>(arguments))...);
  return content_->forward(std::move(values));
}
 
template <typename ReturnType, typename... ArgumentTypes>
ReturnType AnyModule::forward(ArgumentTypes&&... arguments) {
  return any_forward(std::forward<ArgumentTypes>(arguments)...)
      .template get<ReturnType>();
}
 
template <typename T, typename>
T& AnyModule::get() {
  TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
  return get_<T>();
}
 
template <typename T, typename>
const T& AnyModule::get() const {
  TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
  return get_<T>();
}
 
template <typename T, typename ContainedType>
T AnyModule::get() const {
  return T(ptr<ContainedType>());
}
 
inline std::shared_ptr<Module> AnyModule::ptr() const {
  TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
  return content_->ptr();
}
 
template <typename T, typename>
std::shared_ptr<T> AnyModule::ptr() const {
  TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
  // Call get() but discard the value, just to do the type checking.
  get_<T>();
  return std::dynamic_pointer_cast<T>(ptr());
}
 
inline const std::type_info& AnyModule::type_info() const {
  TORCH_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule");
  return content_->type_info;
}
 
inline bool AnyModule::is_empty() const noexcept {
  return content_ == nullptr;
}
 
// Private Methods
 
template <
    typename ModuleType,
    typename Class,
    typename ReturnType,
    typename... ArgumentTypes>
std::unique_ptr<AnyModule::Placeholder> AnyModule::make_holder(
    std::shared_ptr<ModuleType>&& module,
    ReturnType (Class::*)(ArgumentTypes...)) {
  static_assert(
      torch::detail::check_not_lvalue_references<ArgumentTypes...>(),
      "Modules stored inside AnyModule must not take references. "
      "Use pointers instead.");
  static_assert(
      !std::is_void<ReturnType>::value,
      "AnyModule cannot store modules that return void "
      "(you can return a dummy value).");
  return torch::make_unique<Holder<decay_t<ModuleType>, ArgumentTypes...>>(
      std::move(module));
}
 
template <typename ModuleType>
ModuleType& AnyModule::get_() const {
  using M = typename std::remove_reference<ModuleType>::type;
  static_assert(
      torch::detail::has_forward<M>::value,
      "Can only call AnyModule::get<T> with a type T that has a forward method");
  return get_(&M::forward);
}
 
template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
ModuleType& AnyModule::get_(
    ReturnType (ModuleType::*)(ArgumentTypes...)) const {
  if (typeid(ModuleType).hash_code() == type_info().hash_code()) {
    return *static_cast<Holder<ModuleType, ArgumentTypes...>&>(*content_)
                .module;
  }
  AT_ERROR(
      "Attempted to cast module of type ",
      c10::demangle(type_info().name()),
      " to type ",
      c10::demangle(typeid(ModuleType).name()));
}
 
} // namespace nn
} // namespace torch