// This class exists only to do SFINAE on abstract types `T` that are really // `ModuleHolder`, because there's no good way to say that `T` is a // `ModuleHolder` over some unknown type `ModuleType`. With this, you can do // `enable_if_t>`. struct ModuleHolderIndicator {}; // A type trait that is true for types that are `ModuleHolder`s. template using is_module_holder = std::is_base_of>; template using disable_if_module_holder_t = disable_if_t::value>; // A collection of templates that answer the question whether a type `T` is a // `ModuleHolder`, and if so whether its contained type is of type `C`. This is // tricky because it is hard to short circuit in template metaprogramming. A // naive and incorrect solution to this problem would be something like // `disable_if::value && typename T::ContainedType == C>`. // This would disable all types that are not `ModuleHolder`s, because even // though the `is_module_holder::value` may be `false` for such types the // `T::ContainedType` access would be ill-formed and thus fail the whole // expression by the rules of SFINAE. Instead we have to use template // specialization to statically branch on the first condition // (`is_module_holder`) and are only then allowed to query // `T::ContainedType` in the branch for which the condition was true. // Base template. template struct is_module_holder_of_impl; // False branch. `T` is not a `ModuleHolder` and thus not a `ModuleHolder` with // contained type `C`. template struct is_module_holder_of_impl : std::false_type {}; // True branch. `T` is a `ModuleHolder` and thus we can legit access its // `ContainedType` and compare it against `C`. template struct is_module_holder_of_impl : std::is_same {}; // Helper template. template struct is_module_holder_of : is_module_holder_of_impl< is_module_holder::value, decay_t, decay_t> {}; // A collection of templates that allow deducing the return type of the // `forward()` method, but only if a module actually has a `forward()` method, // and otherwise deduces to the type `void`. template struct return_type_of_forward_impl; template struct return_type_of_forward_impl { using type = decltype(::std::declval().forward(::std::declval()...)); }; template struct return_type_of_forward_impl { using type = void; }; template using return_type_of_forward = return_type_of_forward_impl< torch::detail::has_forward::value, C, Args...>; template using return_type_of_forward_t = typename return_type_of_forward::type;