reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-10 c3765bd24fe73747688a0ec2a550f219c9acb384
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#pragma once
 
#include <torch/nn/module.h>
#include <torch/types.h>
#include <torch/utils.h>
 
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
 
#include <memory>
#include <utility>
 
namespace torch {
namespace nn {
/// The `clone()` method in the base `Module` class does not have knowledge of
/// the concrete runtime type of its subclasses. Therefore, `clone()` must
/// either be called from within the subclass, or from a base class that has
/// knowledge of the concrete type. `Cloneable` uses the CRTP to gain
/// knowledge of the subclass' static type and provide an implementation of the
/// `clone()` method. We do not want to use this pattern in the base class,
/// because then storing a module would always require templatizing it.
template <typename Derived>
class Cloneable : public virtual Module {
 public:
  using Module::Module;
 
  /// `reset()` must perform initialization of all members with reference
  /// semantics, most importantly parameters, buffers and submodules.
  virtual void reset() = 0;
 
  /// Performs a recursive "deep copy" of the `Module`, such that all parameters
  /// and submodules in the cloned module are different from those in the
  /// original module.
  std::shared_ptr<Module> clone(
      const optional<Device>& device = nullopt) const override {
    NoGradGuard no_grad;
 
    const auto& self = static_cast<const Derived&>(*this);
    auto copy = std::make_shared<Derived>(self);
    copy->parameters_.clear();
    copy->buffers_.clear();
    copy->children_.clear();
    copy->reset();
    TORCH_CHECK(
        copy->parameters_.size() == parameters_.size(),
        "The cloned module does not have the same number of "
        "parameters as the original module after calling reset(). "
        "Are you sure you called register_parameter() inside reset() "
        "and not the constructor?");
    for (const auto& parameter : parameters_) {
      auto& tensor = *parameter;
      auto data = device && tensor.device() != *device ?
          tensor.to(*device) : autograd::Variable(tensor).clone();
      copy->parameters_[parameter.key()].set_data(data);
    }
    TORCH_CHECK(
        copy->buffers_.size() == buffers_.size(),
        "The cloned module does not have the same number of "
        "buffers as the original module after calling reset(). "
        "Are you sure you called register_buffer() inside reset() "
        "and not the constructor?");
    for (const auto& buffer : buffers_) {
      auto& tensor = *buffer;
      auto data = device && tensor.device() != *device ?
          tensor.to(*device) : autograd::Variable(tensor).clone();
      copy->buffers_[buffer.key()].set_data(data);
    }
    TORCH_CHECK(
        copy->children_.size() == children_.size(),
        "The cloned module does not have the same number of "
        "child modules as the original module after calling reset(). "
        "Are you sure you called register_module() inside reset() "
        "and not the constructor?");
    for (const auto& child : children_) {
      copy->children_[child.key()]->clone_(*child.value(), device);
    }
    return copy;
  }
 
 private:
  void clone_(Module& other, const optional<Device>& device) final {
    // Here we are *pretty* certain that `other's` type is `Derived` (because it
    // was registered under the same name as `this`), but you never know what
    // crazy things `reset()` does, so `dynamic_cast` just to be safe.
    auto clone = std::dynamic_pointer_cast<Derived>(other.clone(device));
    TORCH_CHECK(
        clone != nullptr,
        "Attempted to clone submodule, but it is of a "
        "different type than the submodule it was to be cloned into");
    static_cast<Derived&>(*this) = std::move(*clone);
  }
};
 
} // namespace nn
} // namespace torch