reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
#pragma once
 
#include <torch/detail/static.h>
#include <torch/nn/module.h>
#include <torch/ordered_dict.h>
#include <torch/types.h>
 
#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
 
#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
 
namespace torch {
namespace python {
namespace detail {
inline Device py_object_to_device(py::object object) {
  PyObject* obj = object.ptr();
  if (THPDevice_Check(obj)) {
    return reinterpret_cast<THPDevice*>(obj)->device;
  }
  throw TypeError("Expected device");
}
 
inline Dtype py_object_to_dtype(py::object object) {
  PyObject* obj = object.ptr();
  if (THPDtype_Check(obj)) {
    return reinterpret_cast<THPDtype*>(obj)->scalar_type;
  }
  throw TypeError("Expected dtype");
}
 
template <typename ModuleType>
using PyModuleClass =
    py::class_<ModuleType, torch::nn::Module, std::shared_ptr<ModuleType>>;
 
/// Dynamically creates a subclass of `torch.nn.cpp.ModuleWrapper` that is also
/// a subclass of `torch.nn.Module`, and passes it the user-provided C++ module
/// to which it delegates all calls.
template <typename ModuleType>
void bind_cpp_module_wrapper(
    py::module module,
    PyModuleClass<ModuleType> cpp_class,
    const char* name) {
  // Grab the `torch.nn.cpp.ModuleWrapper` class, which we'll subclass
  // with a dynamically created class below.
  py::object cpp_module =
      py::module::import("torch.nn.cpp").attr("ModuleWrapper");
 
  // Grab the `type` class which we'll use as a metaclass to create a new class
  // dynamically.
  py::object type_metaclass =
      py::reinterpret_borrow<py::object>((PyObject*)&PyType_Type);
 
  // The `ModuleWrapper` constructor copies all functions to its own `__dict__`
  // in its constructor, but we do need to give our dynamic class a constructor.
  // Inside, we construct an instance of the original C++ module we're binding
  // (the `torch::nn::Module` subclass), and then forward it to the
  // `ModuleWrapper` constructor.
  py::dict attributes;
 
  // `type()` always needs a `str`, but pybind11's `str()` method always creates
  // a `unicode` object.
#if PY_MAJOR_VERSION < 3
  py::object name_str =
      py::reinterpret_steal<py::object>(PyString_FromString(name));
#else
  py::object name_str = py::str(name);
#endif
 
  // Dynamically create the subclass of `ModuleWrapper`, which is a subclass of
  // `torch.nn.Module`, and will delegate all calls to the C++ module we're
  // binding.
  py::object wrapper_class =
      type_metaclass(name_str, py::make_tuple(cpp_module), attributes);
 
  // The constructor of the dynamic class calls `ModuleWrapper.__init__()`,
  // which replaces its methods with those of the C++ module.
  wrapper_class.attr("__init__") = py::cpp_function(
      [cpp_module, cpp_class](
          py::object self, py::args args, py::kwargs kwargs) {
        cpp_module.attr("__init__")(self, cpp_class(*args, **kwargs));
      },
      py::is_method(wrapper_class));
 
  // Calling `my_module.my_class` now means that `my_class` is a subclass of
  // `ModuleWrapper`, and whose methods call into the C++ module we're binding.
  module.attr(name) = wrapper_class;
}
} // namespace detail
 
/// Adds method bindings for a pybind11 `class_` that binds an `nn::Module`
/// subclass.
///
/// Say you have a pybind11 class object created with `py::class_<Net>(m,
/// "Net")`. This function will add all the necessary `.def()` calls to bind the
/// `nn::Module` base class' methods, such as `train()`, `eval()` etc. into
/// Python.
///
/// Users should prefer to use `bind_module` if possible.
template <typename ModuleType, typename... Extra>
py::class_<ModuleType, Extra...> add_module_bindings(
    py::class_<ModuleType, Extra...> module) {
  // clang-format off
  return module
      .def("train",
          [](ModuleType& module, bool mode) { module.train(mode); },
          py::arg("mode") = true)
      .def("eval", [](ModuleType& module) { module.eval(); })
      .def("clone", [](ModuleType& module) { return module.clone(); })
      .def_property_readonly(
          "training", [](ModuleType& module) { return module.is_training(); })
      .def("zero_grad", [](ModuleType& module) { module.zero_grad(); })
      .def_property_readonly( "_parameters", [](ModuleType& module) {
            return module.named_parameters(/*recurse=*/false);
          })
      .def("parameters", [](ModuleType& module, bool recurse) {
            return module.parameters(recurse);
          },
          py::arg("recurse") = true)
      .def("named_parameters", [](ModuleType& module, bool recurse) {
            return module.named_parameters(recurse);
          },
          py::arg("recurse") = true)
      .def_property_readonly("_buffers", [](ModuleType& module) {
            return module.named_buffers(/*recurse=*/false);
          })
      .def("buffers", [](ModuleType& module, bool recurse) {
            return module.buffers(recurse); },
          py::arg("recurse") = true)
      .def("named_buffers", [](ModuleType& module, bool recurse) {
            return module.named_buffers(recurse);
          },
          py::arg("recurse") = true)
      .def_property_readonly(
        "_modules", [](ModuleType& module) { return module.named_children(); })
      .def("modules", [](ModuleType& module) { return module.modules(); })
      .def("named_modules",
          [](ModuleType& module, py::object /* unused */, std::string prefix) {
            return module.named_modules(std::move(prefix));
          },
          py::arg("memo") = py::none(),
          py::arg("prefix") = std::string())
      .def("children", [](ModuleType& module) { return module.children(); })
      .def("named_children",
          [](ModuleType& module) { return module.named_children(); })
      .def("to", [](ModuleType& module, py::object object, bool non_blocking) {
            if (THPDevice_Check(object.ptr())) {
              module.to(
                  reinterpret_cast<THPDevice*>(object.ptr())->device,
                  non_blocking);
            } else {
              module.to(detail::py_object_to_dtype(object), non_blocking);
            }
          },
          py::arg("dtype_or_device"),
          py::arg("non_blocking") = false)
      .def("to",
          [](ModuleType& module,
             py::object device,
             py::object dtype,
             bool non_blocking) {
              if (device.is_none()) {
                module.to(detail::py_object_to_dtype(dtype), non_blocking);
              } else if (dtype.is_none()) {
                module.to(detail::py_object_to_device(device), non_blocking);
              } else {
                module.to(
                    detail::py_object_to_device(device),
                    detail::py_object_to_dtype(dtype),
                    non_blocking);
              }
          },
          py::arg("device"),
          py::arg("dtype"),
          py::arg("non_blocking") = false)
      .def("cuda", [](ModuleType& module) { module.to(kCUDA); })
      .def("cpu", [](ModuleType& module) { module.to(kCPU); })
      .def("float", [](ModuleType& module) { module.to(kFloat32); })
      .def("double", [](ModuleType& module) { module.to(kFloat64); })
      .def("half", [](ModuleType& module) { module.to(kFloat16); })
      .def("__str__", [](ModuleType& module) { return module.name(); })
      .def("__repr__", [](ModuleType& module) { return module.name(); });
  // clang-format on
}
 
/// Creates a pybind11 class object for an `nn::Module` subclass type and adds
/// default bindings.
///
/// After adding the default bindings, the class object is returned, such that
/// you can add more bindings.
///
/// Example usage:
/// \rst
/// .. code-block:: cpp
///
///   struct Net : torch::nn::Module {
///     Net(int in, int out) { }
///     torch::Tensor forward(torch::Tensor x) { return x; }
///   };
///
///   PYBIND11_MODULE(my_module, m) {
///     torch::python::bind_module<Net>(m, "Net")
///       .def(py::init<int, int>())
///       .def("forward", &Net::forward);
///  }
/// \endrst
template <typename ModuleType, bool force_enable = false>
torch::disable_if_t<
    torch::detail::has_forward<ModuleType>::value && !force_enable,
    detail::PyModuleClass<ModuleType>>
bind_module(py::module module, const char* name) {
  py::module cpp = module.def_submodule("cpp");
  auto cpp_class =
      add_module_bindings(detail::PyModuleClass<ModuleType>(cpp, name));
  detail::bind_cpp_module_wrapper(module, cpp_class, name);
  return cpp_class;
}
 
/// Creates a pybind11 class object for an `nn::Module` subclass type and adds
/// default bindings.
///
/// After adding the default bindings, the class object is returned, such that
/// you can add more bindings.
///
/// If the class has a `forward()` method, it is automatically exposed as
/// `forward()` and `__call__` in Python.
///
/// Example usage:
/// \rst
/// .. code-block:: cpp
///
///   struct Net : torch::nn::Module {
///     Net(int in, int out) { }
///     torch::Tensor forward(torch::Tensor x) { return x; }
///   };
///
///   PYBIND11_MODULE(my_module, m) {
///     torch::python::bind_module<Net>(m, "Net")
///       .def(py::init<int, int>())
///       .def("forward", &Net::forward);
///  }
/// \endrst
template <
    typename ModuleType,
    typename =
        torch::enable_if_t<torch::detail::has_forward<ModuleType>::value>>
detail::PyModuleClass<ModuleType> bind_module(
    py::module module,
    const char* name) {
  return bind_module<ModuleType, /*force_enable=*/true>(module, name)
      .def("forward", &ModuleType::forward)
      .def("__call__", &ModuleType::forward);
}
} // namespace python
} // namespace torch