#pragma once #include #include #include #include #include #include #include #include namespace torch { namespace detail { // Dump all the template metaprogramming in this file. #include } // namespace detail namespace nn { /// A `ModuleHolder` is essentially a wrapper around `std::shared_ptr` where /// `M` is an `nn::Module` subclass, with convenient constructors defined for /// the kind of constructions we want to allow for our modules. template class ModuleHolder : torch::detail::ModuleHolderIndicator { protected: /// The module pointer this class wraps. /// NOTE: Must be placed at the top of the class so that we can use it with /// trailing return types below. std::shared_ptr impl_; public: using ContainedType = Contained; /// Default constructs the contained module if if has a default constructor, /// else produces a static error. /// /// NOTE: This uses the behavior of template /// classes in C++ that constructors (or any methods) are only compiled when /// actually used. ModuleHolder() : impl_(default_construct()) { static_assert( std::is_default_constructible::value, "You are trying to default construct a module which has " "no default constructor. Use = nullptr to give it the empty state " "(e.g. `Linear linear = nullptr;` instead of `Linear linear;`)."); } /// Constructs the `ModuleHolder` with an empty contained value. Access to /// the underlying module is not permitted and will throw an exception, until /// a value is assigned. /* implicit */ ModuleHolder(std::nullptr_t) : impl_(nullptr) {} /// Constructs the `ModuleHolder` with a contained module, forwarding all /// arguments to its constructor. template < typename Head, typename... Tail, typename = typename std::enable_if< !(torch::detail::is_module_holder_of::value && (sizeof...(Tail) == 0))>::type> explicit ModuleHolder(Head&& head, Tail&&... tail) : impl_(new Contained( std::forward(head), std::forward(tail)...)) {} /// Constructs the `ModuleHolder` from a pointer to the contained type. /// Example: `Linear(std::make_shared(...))`. /* implicit */ ModuleHolder(std::shared_ptr module) : impl_(std::move(module)) {} /// Returns true if the `ModuleHolder` contains a module, or false if it is /// `nullptr`. explicit operator bool() const noexcept { return !is_empty(); } /// Forwards to the contained module. Contained* operator->() { return get(); } /// Forwards to the contained module. const Contained* operator->() const { return get(); } /// Returns a reference to the contained module. Contained& operator*() { return *get(); } /// Returns a const reference to the contained module. const Contained& operator*() const { return *get(); } /// Returns a shared pointer to the underlying module. const std::shared_ptr& ptr() const { TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder"); return impl_; } /// Returns a pointer to the underlying module. Contained* get() { TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder"); return impl_.get(); } /// Returns a const pointer to the underlying module. const Contained* get() const { TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder"); return impl_.get(); } /// Calls the `forward()` method of the contained module. template auto operator()(Args&&... args) -> torch::detail::return_type_of_forward_t { // This will not compile if the module does not have a `forward()` method // (as expected). // NOTE: `std::forward` is qualified to prevent VS2017 emitting // error C2872: 'std': ambiguous symbol return impl_->forward(::std::forward(args)...); } /// Forwards to the subscript operator of the contained module. /// NOTE: std::forward is qualified to prevent VS2017 emitting /// error C2872: 'std': ambiguous symbol template auto operator[](Arg&& arg) -> decltype((*impl_)[::std::forward(arg)]) { return (*impl_)[::std::forward(arg)]; } /// Returns true if the `ModuleHolder` does not contain a module. bool is_empty() const noexcept { return impl_ == nullptr; } private: /// In C++17, the two methods below could be written as the following: /// if constexpr (std::is_default_constructible_v) { /// return std::make_shared(); /// } else { /// return nullptr; /// } /// In C++11, we use SFINAE instead of `if constexpr`. template < typename T = Contained, typename = torch::enable_if_t::value>> std::shared_ptr default_construct() { return std::make_shared(); } template torch::disable_if_t< std::is_default_constructible::value, std::shared_ptr> default_construct() { return nullptr; } }; /// Pretty prints the given `Module` into the `ostream`. template std::ostream& operator<<( std::ostream& stream, const nn::ModuleHolder& module) { return stream << *module; } /// Serializes a `ModuleHolder` into an `OutputArchive`. template serialize::OutputArchive& operator<<( serialize::OutputArchive& archive, const nn::ModuleHolder& module) { return archive << module.ptr(); } /// Deserializes a `ModuleHolder` from an `InputArchive`. template serialize::InputArchive& operator>>( serialize::InputArchive& archive, nn::ModuleHolder& module) { return archive >> module.ptr(); } } // namespace nn } // namespace torch /// Defines a class `Name` which inherits from `nn::ModuleHolder` to provide a /// wrapper over a `std::shared_ptr`. #define TORCH_MODULE_IMPL(Name, Impl) \ class Name : public torch::nn::ModuleHolder { /* NOLINT */ \ public: \ using torch::nn::ModuleHolder::ModuleHolder; \ } /// Like `TORCH_MODULE_IMPL`, but defaults the `Impl` name to `Impl`. #define TORCH_MODULE(Name) TORCH_MODULE_IMPL(Name, Name##Impl)