#pragma once

// NOTE(review): this file originally carried four `#include` directives whose
// targets were lost when the file was mangled. The two standard headers below
// are the ones the code in this file actually requires (`int8_t`/`int16_t`
// and the `std::` type traits). If dependents rely on project headers being
// pulled in transitively through this file, restore them here -- TODO confirm
// against version control.
#include <cstdint>
#include <type_traits>

namespace torch {
namespace nn {
class Module;
} // namespace nn
} // namespace torch

namespace torch {
namespace detail {
/// Detects if a type T has a forward() method.
///
/// `has_forward<T>::value` is `true` when the expression `&T::forward` is
/// well-formed (i.e. `T` has a single, non-overloaded, accessible member named
/// `forward`), and `false` otherwise.
template <typename T>
struct has_forward {
  // Declare two types with differing size.
  using yes = int8_t;
  using no = int16_t;

  // Here we declare two functions. The first is only enabled if `&U::forward`
  // is well-formed and returns the `yes` type. In C++, the ellipsis parameter
  // type (`...`) always puts the function at the bottom of overload resolution.
  // This is specified in the standard as:
  //   1) A standard conversion sequence is always better than a user-defined
  //      conversion sequence or an ellipsis conversion sequence.
  //   2) A user-defined conversion sequence is always better than an ellipsis
  //      conversion sequence.
  // This means that if the first overload is viable, it will be preferred over
  // the second as long as we pass any convertible type. The type of
  // `&U::forward` is a pointer type, so we can pass `nullptr`.
  template <typename U>
  static yes test(decltype(&U::forward));
  template <typename U>
  static no test(...);

  // Finally we test statically whether the size of the type returned by the
  // selected overload is the size of the `yes` type.
  static constexpr bool value = (sizeof(test<T>(nullptr)) == sizeof(yes));
};

/// Returns `true` if none of the given types is an lvalue reference to a
/// non-const type, i.e. every type is either not an lvalue reference at all
/// or is an lvalue reference to `const`.
///
/// The types are checked one at a time: `Head` is inspected here and the
/// remainder of the pack is handled by the recursive call. Once the pack is
/// exhausted, `Head` falls back to its default of `void`, which selects the
/// explicit specialization below and ends the recursion with `true`.
template <typename Head = void, typename... Tail>
constexpr bool check_not_lvalue_references() {
  return (!std::is_lvalue_reference<Head>::value ||
          std::is_const<typename std::remove_reference<Head>::type>::value) &&
      check_not_lvalue_references<Tail...>();
}

/// Recursion terminator: an empty type list (or a lone `void`) contains no
/// offending lvalue references.
template <>
inline constexpr bool check_not_lvalue_references<void>() {
  return true;
}

/// A type trait whose `value` member is true if `M` derives from `Module`.
/// `M` is decayed first, so references and cv-qualifiers are ignored.
template <typename M>
using is_module =
    std::is_base_of<torch::nn::Module, typename std::decay<M>::type>;

/// SFINAE helper: names the type `T` (default `void`) when `M` is a module
/// according to `is_module`, and is ill-formed otherwise.
template <typename M, typename T = void>
using enable_if_module_t =
    typename std::enable_if<is_module<M>::value, T>::type;
} // namespace detail
} // namespace torch