reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-10 c3765bd24fe73747688a0ec2a550f219c9acb384
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#pragma once
 
#include <torch/detail/static.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/container/any.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
 
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/utils/variadic.h>
 
#include <ATen/Device.h>
 
#include <initializer_list>
#include <memory>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>
 
namespace torch {
namespace nn {
 
/// Stores a type erased `Module` with name.
///
/// The `NamedAnyModule` class and the `modules_ordered_dict(...)` function enables
/// the following API for constructing `nn::Sequential` with named submodules:
/// \rst
/// .. code-block:: cpp
///
///   struct M : torch::nn::Module {
///     explicit M(int value_) : value(value_) {}
///     int value;
///     int forward() {
///       return value;
///     }
///   };
///
///   Sequential sequential(modules_ordered_dict({
///     {"m1", std::make_shared<M>(1)},  // shared pointer to `Module` is supported
///     {std::string("m2"), M(2)},  // `Module` is supported
///     {"linear1", Linear(10, 3)}  // `ModuleHolder` is supported
///   }));
/// \endrst
///
/// Specifically, we design the signature of `modules_ordered_dict(...)` to be
/// `modules_ordered_dict(std::initializer_list<NamedAnyModule> named_modules)`, as
/// a result of evaluating the following possible approaches:
///
/// Approach 1:
/// `modules_ordered_dict(std::initializer_list<
///   torch::OrderedDict<std::string, ModuleType>::Item> named_modules)`
///
/// Why it doens't work:
/// When we pass in a braced-init list such as
/// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, at the template argument
/// deduction step the compiler is not able to deduce the type of `ModuleType` to
/// the type of `M(1)` or `M(2)`, since the compiler doesn't actually look into the
/// braced-init list `{"m1", M(1)}` and figure out what the types of its elements are.
///
/// Approach 2:
/// `modules_ordered_dict(std::initializer_list<
///   std::pair<std::string, AnyModule> named_modules)`
///
/// Why it doens't work:
/// When we pass in a braced-init list such as
/// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, the compiler is not able to
/// match `std::initializer_list<std::pair<std::string, AnyModule>>` to the nested
/// braced-init list `{{"m1", M(1)}, {"m2", M(2)}}`, and results in a "could not
/// convert" error.
///
/// Approach 3:
/// `modules_ordered_dict(std::initializer_list<NamedAnyModule> named_modules)`
///
/// Why it works:
/// When we pass in a braced-init list such as
/// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, the compiler is passing the
/// braced-init lists {"m1", M(1)} and {"m2", M(2)} to the `NamedAnyModule`
/// constructors, and the constructors are able to figure out the types of the
/// braced-init lists' elements and match to the correct module type.
 
class NamedAnyModule {
 public:
  /// Creates a `NamedAnyModule` from a (boxed) `Module`.
  template <typename ModuleType>
  NamedAnyModule(std::string name, std::shared_ptr<ModuleType> module_ptr)
      : NamedAnyModule(std::move(name), AnyModule(std::move(module_ptr))) {}
 
  /// Creates a `NamedAnyModule` from a `Module`, moving or copying it
  /// into a `shared_ptr` internally.
  // NOTE: We need to use `std::remove_reference<M>::type` to get rid of
  // any reference components for make_unique.
  template <typename M, typename = torch::detail::enable_if_module_t<M>>
  NamedAnyModule(std::string name, M&& module)
      : NamedAnyModule(
          std::move(name),
          std::make_shared<typename std::remove_reference<M>::type>(
            std::forward<M>(module))) {}
 
  /// Creates a `NamedAnyModule` from a `Module` that is unwrapped from
  /// a `ModuleHolder`.
  template <typename M>
  NamedAnyModule(std::string name, const ModuleHolder<M>& module_holder)
      : NamedAnyModule(std::move(name), module_holder.ptr()) {}
 
  /// Returns a reference to the name.
  const std::string& name() const noexcept {
    return name_;
  }
 
  /// Returns a reference to the module.
  AnyModule& module() noexcept {
    return module_;
  }
 
 private:
  /// Creates a `NamedAnyModule` from a type-erased `AnyModule`.
  NamedAnyModule(std::string name, AnyModule any_module)
    : name_(std::move(name)), module_(std::move(any_module)) {}
 
  std::string name_;
  AnyModule module_;
};
 
TORCH_API torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(
  std::initializer_list<NamedAnyModule> named_modules);
 
} // namespace nn
} // namespace torch