#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch { namespace nn { /// Stores a type erased `Module` with name. /// /// The `NamedAnyModule` class and the `modules_ordered_dict(...)` function enables /// the following API for constructing `nn::Sequential` with named submodules: /// \rst /// .. code-block:: cpp /// /// struct M : torch::nn::Module { /// explicit M(int value_) : value(value_) {} /// int value; /// int forward() { /// return value; /// } /// }; /// /// Sequential sequential(modules_ordered_dict({ /// {"m1", std::make_shared(1)}, // shared pointer to `Module` is supported /// {std::string("m2"), M(2)}, // `Module` is supported /// {"linear1", Linear(10, 3)} // `ModuleHolder` is supported /// })); /// \endrst /// /// Specifically, we design the signature of `modules_ordered_dict(...)` to be /// `modules_ordered_dict(std::initializer_list named_modules)`, as /// a result of evaluating the following possible approaches: /// /// Approach 1: /// `modules_ordered_dict(std::initializer_list< /// torch::OrderedDict::Item> named_modules)` /// /// Why it doens't work: /// When we pass in a braced-init list such as /// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, at the template argument /// deduction step the compiler is not able to deduce the type of `ModuleType` to /// the type of `M(1)` or `M(2)`, since the compiler doesn't actually look into the /// braced-init list `{"m1", M(1)}` and figure out what the types of its elements are. /// /// Approach 2: /// `modules_ordered_dict(std::initializer_list< /// std::pair named_modules)` /// /// Why it doens't work: /// When we pass in a braced-init list such as /// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, the compiler is not able to /// match `std::initializer_list>` to the nested /// braced-init list `{{"m1", M(1)}, {"m2", M(2)}}`, and results in a "could not /// convert" error. /// /// Approach 3: /// `modules_ordered_dict(std::initializer_list named_modules)` /// /// Why it works: /// When we pass in a braced-init list such as /// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, the compiler is passing the /// braced-init lists {"m1", M(1)} and {"m2", M(2)} to the `NamedAnyModule` /// constructors, and the constructors are able to figure out the types of the /// braced-init lists' elements and match to the correct module type. class NamedAnyModule { public: /// Creates a `NamedAnyModule` from a (boxed) `Module`. template NamedAnyModule(std::string name, std::shared_ptr module_ptr) : NamedAnyModule(std::move(name), AnyModule(std::move(module_ptr))) {} /// Creates a `NamedAnyModule` from a `Module`, moving or copying it /// into a `shared_ptr` internally. // NOTE: We need to use `std::remove_reference::type` to get rid of // any reference components for make_unique. template > NamedAnyModule(std::string name, M&& module) : NamedAnyModule( std::move(name), std::make_shared::type>( std::forward(module))) {} /// Creates a `NamedAnyModule` from a `Module` that is unwrapped from /// a `ModuleHolder`. template NamedAnyModule(std::string name, const ModuleHolder& module_holder) : NamedAnyModule(std::move(name), module_holder.ptr()) {} /// Returns a reference to the name. const std::string& name() const noexcept { return name_; } /// Returns a reference to the module. AnyModule& module() noexcept { return module_; } private: /// Creates a `NamedAnyModule` from a type-erased `AnyModule`. NamedAnyModule(std::string name, AnyModule any_module) : name_(std::move(name)), module_(std::move(any_module)) {} std::string name_; AnyModule module_; }; TORCH_API torch::OrderedDict modules_ordered_dict( std::initializer_list named_modules); } // namespace nn } // namespace torch