#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch { namespace nn { /// Stores a type erased `Module`. /// /// The PyTorch C++ API does not impose an interface on the signature of /// `forward()` in `Module` subclasses. This gives you complete freedom to /// design your `forward()` methods to your liking. However, this also means /// there is no unified base type you could store in order to call `forward()` /// polymorphically for any module. This is where the `AnyModule` comes in. /// Instead of inheritance, it relies on type erasure for polymorphism. /// /// An `AnyModule` can store any `nn::Module` subclass that provides a /// `forward()` method. This `forward()` may accept any types and return any /// type. Once stored in an `AnyModule`, you can invoke the underlying module's /// `forward()` by calling `AnyModule::forward()` with the arguments you would /// supply to the stored module (though see one important limitation below). /// Example: /// /// \rst /// .. code-block:: cpp /// /// struct GenericTrainer { /// torch::nn::AnyModule module; /// /// void train(torch::Tensor input) { /// module.forward(input); /// } /// }; /// /// GenericTrainer trainer1{torch::nn::Linear(3, 4)}; /// GenericTrainer trainer2{torch::nn::Conv2d(3, 4, 2)}; /// \endrst /// /// As `AnyModule` erases the static type of the stored module (and its /// `forward()` method) to achieve polymorphism, type checking of arguments is /// moved to runtime. That is, passing an argument with an incorrect type to an /// `AnyModule` will compile, but throw an exception at runtime: /// /// \rst /// .. code-block:: cpp /// /// torch::nn::AnyModule module(torch::nn::Linear(3, 4)); /// // Linear takes a tensor as input, but we are passing an integer. /// // This will compile, but throw a `torch::Error` exception at runtime. /// module.forward(123); /// \endrst /// /// \rst /// .. attention:: /// One noteworthy limitation of `AnyModule` is that its `forward()` method /// does not support implicit conversion of argument types. For example, if /// the stored module's `forward()` method accepts a `float` and you call /// `any_module.forward(3.4)` (where `3.4` is a `double`), this will throw /// an exception. /// \endrst /// /// The return type of the `AnyModule`'s `forward()` method is controlled via /// the first template argument to `AnyModule::forward()`. It defaults to /// `torch::Tensor`. To change it, you can write `any_module.forward()`, /// for example. /// /// \rst /// .. code-block:: cpp /// /// torch::nn::AnyModule module(torch::nn::Linear(3, 4)); /// auto output = module.forward(torch::ones({2, 3})); /// /// struct IntModule { /// int forward(int x) { return x; } /// }; /// torch::nn::AnyModule module(IntModule{}); /// int output = module.forward(5); /// \endrst /// /// The only other method an `AnyModule` provides access to on the stored /// module is `clone()`. However, you may acquire a handle on the module via /// `.ptr()`, which returns a `shared_ptr`. Further, if you know /// the concrete type of the stored module, you can get a concrete handle to it /// using `.get()` where `T` is the concrete module type. /// /// \rst /// .. code-block:: cpp /// /// torch::nn::AnyModule module(torch::nn::Linear(3, 4)); /// std::shared_ptr ptr = module.ptr(); /// torch::nn::Linear linear(module.get()); /// \endrst class AnyModule { public: /// A type-erased value. class Value; /// A default-constructed `AnyModule` is in an empty state. AnyModule() = default; /// Constructs an `AnyModule` from a `shared_ptr` to concrete module object. template explicit AnyModule(std::shared_ptr module); /// Constructs an `AnyModule` from a concrete module object. template < typename ModuleType, typename = torch::detail::enable_if_module_t> explicit AnyModule(ModuleType&& module); /// Constructs an `AnyModule` from a module holder. template explicit AnyModule(const ModuleHolder& module_holder); /// Move construction and assignment is allowed, and follows the default /// behavior of move for `std::unique_ptr`. AnyModule(AnyModule&&) = default; AnyModule& operator=(AnyModule&&) = default; /// Creates a shallow copy of an `AnyModule`. AnyModule(const AnyModule& other); AnyModule& operator=(const AnyModule& other); /// Creates a deep copy of an `AnyModule` if it contains a module, else an /// empty `AnyModule` if it is empty. AnyModule clone(optional device = nullopt) const; /// Assigns a module to the `AnyModule` (to circumvent the explicit /// constructor). template AnyModule& operator=(std::shared_ptr module); /// Invokes `forward()` on the contained module with the given arguments, and /// returns the return value as an `Value`. Use this method when chaining /// `AnyModule`s in a loop. template Value any_forward(ArgumentTypes&&... arguments); /// Invokes `forward()` on the contained module with the given arguments, and /// casts the returned `Value` to the supplied `ReturnType` (which defaults to /// `torch::Tensor`). template ReturnType forward(ArgumentTypes&&... arguments); /// Attempts to cast the underlying module to the given module type. Throws an /// exception if the types do not match. template > T& get(); /// Attempts to cast the underlying module to the given module type. Throws an /// exception if the types do not match. template > const T& get() const; /// Returns the contained module in a `nn::ModuleHolder` subclass if possible /// (i.e. if `T` has a constructor for the underlying module type). template T get() const; /// Returns a `std::shared_ptr` whose dynamic type is that of the underlying /// module. std::shared_ptr ptr() const; /// Like `ptr()`, but casts the pointer to the given type. template > std::shared_ptr ptr() const; /// Returns the `type_info` object of the contained value. const std::type_info& type_info() const; /// Returns true if the `AnyModule` does not contain a module. bool is_empty() const noexcept; private: /// \internal /// The static type of the object we store in the `AnyModule`, which erases /// the actual type, but allows us to call `forward()` on the underlying /// module. struct Placeholder; /// \internal /// The dynamic type of the object stored in the `AnyModule`. It contains the /// concrete instance to which all calls are forwarded. It is parameterized /// over the concrete type of the module, and the types of the arguments the /// module takes in its `forward()` method. template struct Holder; /// Creates a `unique_ptr` pointing to a `Holder` of the correct /// type. This method is used to deduce the arguments of the module's /// `forward()` method. template < typename ModuleType, typename Class, typename ReturnType, typename... ArgumentTypes> std::unique_ptr make_holder( std::shared_ptr&& module, ReturnType (Class::*)(ArgumentTypes...)); /// Helper method invoked by const and non-const `get()`. template ModuleType& get_(ReturnType (ModuleType::*)(ArgumentTypes...)) const; /// Helper method invoked by const and non-const `get()`. template ModuleType& get_() const; /// The type erased module. std::unique_ptr content_; }; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule::Value ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /// A simplified implementation of `std::any` which stores /// a type erased object, whose concrete value can be retrieved at runtime by /// checking if the `typeid()` of a requested type matches the `typeid()` of /// the object stored. It is simplified in that it does not handle copying, as /// we do not require it for our use cases. Moves are sufficient. class AnyModule::Value { public: /// Move construction and assignment is allowed, and follows the default /// behavior of move for `std::unique_ptr`. Value(Value&&) = default; Value& operator=(Value&&) = default; /// Copy is disallowed, because we don't need it. Value(const Value& other) = delete; Value& operator=(const Value& other) = delete; /// Returns a pointer to the value contained in the `Value` if the type passed /// as template parameter matches the type of the value stored, and returns a /// null pointer otherwise. template T* try_get() { static_assert( !std::is_reference::value, "Value stores decayed types, you cannot cast it to a reference type"); static_assert( !std::is_array::value, "Value stores decayed types, you must cast it to T* instead of T[]"); if (typeid(T).hash_code() == type_info().hash_code()) { return &static_cast&>(*content_).value; } return nullptr; } /// Returns the value contained in the `Value` if the type passed as template /// parameter matches the type of the value stored, and throws an exception /// otherwise. template T get() { if (auto* maybe_value = try_get()) { return *maybe_value; } AT_ERROR( "Attempted to cast Value to ", c10::demangle(typeid(T).name()), ", but its actual type is ", c10::demangle(type_info().name())); } /// Returns the `type_info` object of the contained value. const std::type_info& type_info() const noexcept { return content_->type_info; } private: friend class AnyModule; friend struct TestValue; /// Constructs the `Value` from value type. template < typename T, typename = torch::disable_if_t::value>> explicit Value(T&& value) : content_( torch::make_unique>>(std::forward(value))) {} /// Constructs the `Value` from an `autograd::Variable`, first converting it /// to a `torch::Tensor`. explicit Value(autograd::Variable variable) : Value(Tensor(std::move(variable))) {} /// \internal /// The static type of the object we store in the `Value`, which erases the /// actual object's type, allowing us only to check the `type_info` of the /// type stored in the dynamic type. struct Placeholder { explicit Placeholder(const std::type_info& type_info_) noexcept : type_info(type_info_) {} virtual ~Placeholder() = default; const std::type_info& type_info; }; /// \internal /// The dynamic type of the object we store in the `Value`, which hides the /// actual object we have erased in this `Value`. template struct Holder : public Placeholder { /// A template because T&& would not be universal reference here. template explicit Holder(U&& value_) noexcept : Placeholder(typeid(T)), value(std::forward(value_)) {} T value; }; /// The type erased object. std::unique_ptr content_; }; // ~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule::Placeholder ~~~~~~~~~~~~~~~~~~~~~~~~~~ struct AnyModule::Placeholder : public AnyModule::Value::Placeholder { using AnyModule::Value::Placeholder::Placeholder; /// The "erased" `forward()` method. virtual Value forward(std::vector&& arguments) = 0; /// Returns std::shared_ptr pointing to the erased module. virtual std::shared_ptr ptr() = 0; /// Returns a `Placeholder` with a shallow copy of this `AnyModule`. virtual std::unique_ptr copy() const = 0; /// Returns a `Placeholder` with a deep copy of this `AnyModule`. virtual std::unique_ptr clone(optional device) const = 0; }; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule::Holder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template struct AnyModule::Holder : public AnyModule::Placeholder { /// \internal struct CheckedGetter { template decay_t&& operator()(size_t index) { AT_ASSERT(index < arguments_.size()); auto& value = arguments_[index]; if (auto* maybe_value = value.template try_get>()) { return std::move(*maybe_value); } AT_ERROR( "Expected argument #", index, " to be of type ", c10::demangle(typeid(T).name()), ", but received value of type ", c10::demangle(value.type_info().name())); } std::vector& arguments_; }; /// \internal struct InvokeForward { template Value operator()(Ts&&... ts) { return Value(module_->forward(std::forward(ts)...)); } std::shared_ptr& module_; }; /// Constructs the `Holder` from a concrete module. explicit Holder(std::shared_ptr&& module_) : Placeholder(typeid(ModuleType)), module(std::move(module_)) {} /// Calls `forward()` on the underlying module, casting each `Value` in the /// argument vector to a concrete value. Value forward(std::vector&& arguments) override { TORCH_CHECK( arguments.size() == sizeof...(ArgumentTypes), c10::demangle(type_info.name()), "'s forward() method expects ", sizeof...(ArgumentTypes), " arguments, but received ", arguments.size()); // FYI: During invocation of a module's `forward()` method, the values live // in the `arguments` vector inside this function. return torch::unpack( InvokeForward{module}, CheckedGetter{arguments}); } std::shared_ptr ptr() override { return module; } std::unique_ptr copy() const override { return torch::make_unique(*this); } std::unique_ptr clone(optional device) const override { return torch::make_unique( std::dynamic_pointer_cast(module->clone(device))); } /// The actual concrete module instance. std::shared_ptr module; }; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template AnyModule::AnyModule(std::shared_ptr module) : content_(make_holder( std::move(module), &std::remove_reference::type::forward)) { // `AnyModule` can only store an `nn::Module` subclass object that provides // a `forward()` method that has a non-templatized return type. // (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s // `forward()` method has a templatized return type.) static_assert( torch::detail::is_module::value, "Can only store object derived from nn::Module into AnyModule"); static_assert( torch::detail::has_forward::value, "Can only store module with a forward() method that has a non-templatized" " argument type and return type into AnyModule (e.g. we cannot store nn::Sequential" "into AnyModule, because its forward() method's argument type and return type are templatized." " If you need to use nn::Sequentials inside each other you can subclass " "nn::Sequential and write a non-templatized forward function for it. You can checkout " "https://github.com/pytorch/vision/blob/2f46070f3cb1ea894d82578f3dc5677f82f34958/torchvision/csrc/models/mnasnet.cpp#L59 " "for an example on how to do this.)."); } template AnyModule::AnyModule(ModuleType&& module) : AnyModule( std::make_shared(std::forward(module))) {} template AnyModule::AnyModule(const ModuleHolder& module_holder) : AnyModule(module_holder.ptr()) {} inline AnyModule::AnyModule(const AnyModule& other) : content_(other.content_ ? other.content_->copy() : nullptr) {} inline AnyModule& AnyModule::operator=(const AnyModule& other) { if (this != &other) { content_ = other.content_ ? other.content_->copy() : nullptr; } return *this; } inline AnyModule AnyModule::clone(optional device) const { AnyModule clone; clone.content_ = content_ ? content_->clone(device) : nullptr; return clone; } template AnyModule& AnyModule::operator=(std::shared_ptr module) { return (*this = AnyModule(std::move(module))); } template AnyModule::Value AnyModule::any_forward(ArgumentTypes&&... arguments) { TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule"); std::vector values; values.reserve(sizeof...(ArgumentTypes)); torch::apply( [&values](Value&& value) { values.push_back(std::move(value)); }, Value(std::forward(arguments))...); return content_->forward(std::move(values)); } template ReturnType AnyModule::forward(ArgumentTypes&&... arguments) { return any_forward(std::forward(arguments)...) .template get(); } template T& AnyModule::get() { TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule"); return get_(); } template const T& AnyModule::get() const { TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule"); return get_(); } template T AnyModule::get() const { return T(ptr()); } inline std::shared_ptr AnyModule::ptr() const { TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule"); return content_->ptr(); } template std::shared_ptr AnyModule::ptr() const { TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule"); // Call get() but discard the value, just to do the type checking. get_(); return std::dynamic_pointer_cast(ptr()); } inline const std::type_info& AnyModule::type_info() const { TORCH_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule"); return content_->type_info; } inline bool AnyModule::is_empty() const noexcept { return content_ == nullptr; } // Private Methods template < typename ModuleType, typename Class, typename ReturnType, typename... ArgumentTypes> std::unique_ptr AnyModule::make_holder( std::shared_ptr&& module, ReturnType (Class::*)(ArgumentTypes...)) { static_assert( torch::detail::check_not_lvalue_references(), "Modules stored inside AnyModule must not take references. " "Use pointers instead."); static_assert( !std::is_void::value, "AnyModule cannot store modules that return void " "(you can return a dummy value)."); return torch::make_unique, ArgumentTypes...>>( std::move(module)); } template ModuleType& AnyModule::get_() const { using M = typename std::remove_reference::type; static_assert( torch::detail::has_forward::value, "Can only call AnyModule::get with a type T that has a forward method"); return get_(&M::forward); } template ModuleType& AnyModule::get_( ReturnType (ModuleType::*)(ArgumentTypes...)) const { if (typeid(ModuleType).hash_code() == type_info().hash_code()) { return *static_cast&>(*content_) .module; } AT_ERROR( "Attempted to cast module of type ", c10::demangle(type_info().name()), " to type ", c10::demangle(typeid(ModuleType).name())); } } // namespace nn } // namespace torch