#pragma once

#include <torch/nn/module.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>

#include <memory>
#include <utility>

namespace torch {
namespace nn {

/// The `clone()` method in the base `Module` class does not have knowledge of
/// the concrete runtime type of its subclasses. Therefore, `clone()` must
/// either be called from within the subclass, or from a base class that has
/// knowledge of the concrete type. `Cloneable` uses the CRTP to gain
/// knowledge of the subclass's static type and provide an implementation of
/// the `clone()` method. We do not want to use this pattern in the base class,
/// because then storing a module would always require templatizing it.
template <typename Derived>
class Cloneable : public virtual Module {
 public:
  using Module::Module;

  /// `reset()` must perform initialization of all members with reference
  /// semantics, most importantly parameters, buffers and submodules.
  virtual void reset() = 0;

  /// Performs a recursive "deep copy" of the `Module`, such that all
  /// parameters and submodules in the cloned module are different from those
  /// in the original module.
  std::shared_ptr<Module> clone(
      const optional<Device>& device = nullopt) const override {
    NoGradGuard no_grad;

    const auto& self = static_cast<const Derived&>(*this);
    auto copy = std::make_shared<Derived>(self);
    copy->parameters_.clear();
    copy->buffers_.clear();
    copy->children_.clear();
    copy->reset();
    TORCH_CHECK(
        copy->parameters_.size() == parameters_.size(),
        "The cloned module does not have the same number of "
        "parameters as the original module after calling reset(). "
        "Are you sure you called register_parameter() inside reset() "
        "and not the constructor?");
    for (const auto& parameter : parameters_) {
      auto& tensor = *parameter;
      auto data = device && tensor.device() != *device
          ? tensor.to(*device)
          : autograd::Variable(tensor).clone();
      copy->parameters_[parameter.key()].set_data(data);
    }
    TORCH_CHECK(
        copy->buffers_.size() == buffers_.size(),
        "The cloned module does not have the same number of "
        "buffers as the original module after calling reset(). "
        "Are you sure you called register_buffer() inside reset() "
        "and not the constructor?");
    for (const auto& buffer : buffers_) {
      auto& tensor = *buffer;
      auto data = device && tensor.device() != *device
          ? tensor.to(*device)
          : autograd::Variable(tensor).clone();
      copy->buffers_[buffer.key()].set_data(data);
    }
    TORCH_CHECK(
        copy->children_.size() == children_.size(),
        "The cloned module does not have the same number of "
        "child modules as the original module after calling reset(). "
        "Are you sure you called register_module() inside reset() "
        "and not the constructor?");
    for (const auto& child : children_) {
      copy->children_[child.key()]->clone_(*child.value(), device);
    }
    return copy;
  }

 private:
  void clone_(Module& other, const optional<Device>& device) final {
    // Here we are *pretty* certain that `other`'s type is `Derived` (because
    // it was registered under the same name as `this`), but you never know
    // what crazy things `reset()` does, so `dynamic_cast` just to be safe.
    auto clone = std::dynamic_pointer_cast<Derived>(other.clone(device));
    TORCH_CHECK(
        clone != nullptr,
        "Attempted to clone submodule, but it is of a "
        "different type than the submodule it was to be cloned into");
    static_cast<Derived&>(*this) = std::move(*clone);
  }
};

} // namespace nn
} // namespace torch
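
// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the original header): a minimal module that
// derives from `Cloneable` via CRTP. The names `MyModule`, `weight`, and `fc`
// are hypothetical. Note that all parameters and submodules are registered
// inside `reset()` rather than in the constructor, so that `clone()` can
// re-create them on the fresh copy before deep-copying the values over.
//
//   struct MyModule : torch::nn::Cloneable<MyModule> {
//     MyModule() { reset(); }
//     void reset() override {
//       weight = register_parameter("weight", torch::randn({4, 4}));
//       fc = register_module("fc", torch::nn::Linear(4, 4));
//     }
//     torch::Tensor weight;
//     torch::nn::Linear fc{nullptr};
//   };
//
//   // `clone()` returns a recursive deep copy: the copy's parameters and
//   // buffers are distinct tensors from the original's, optionally placed on
//   // another device.
//   auto module = std::make_shared<MyModule>();
//   auto copy = module->clone();
// ---------------------------------------------------------------------------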