#pragma once /** * This file contains functionality to take a C++ function and infer its * c10::FunctionSchema. */ #include #include #include namespace c10 { namespace detail { namespace infer_schema { /// The templated inference code creates `ArgumentDef` instead of `Argument`, /// because that can be constructed at compile time and has a much smaller /// binary size than having calls to `Argument` constructors in the template. /// Creating `Argument` objects from `ArgumentDef` can then be done at /// runtime in a non-templated way. struct ArgumentDef final { using GetTypeFn = TypePtr(); GetTypeFn* getTypeFn; }; template struct bool_t {}; template<> struct bool_t : std::true_type {}; template<> struct bool_t : std::false_type {}; /// Checks the static C++ types `Types` for correctness to catch common error cases. template constexpr int checkStaticTypes() { // Give nice error messages for some of the common error cases. // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT static_assert(guts::conjunction< bool_t::value || std::is_same::value || std::is_same::value>... >::value, "INVALID TYPE: Only int64_t and bool are supported as an integral argument type"); static_assert(guts::conjunction< bool_t::value>... >::value, "INVALID TYPE: float is not supported as an argument type, use double instead"); return 0; } template constexpr std::array createArgumentVectorFromTypes(guts::index_sequence) { return ( // Check types for common errors checkStaticTypes(), // Create the return value std::array{{ArgumentDef{&getTypePtr_>::call}...}} ); } /// Creates a vector of `ArgumentDef` from a list of C++ types that are specified /// as template arguments. template struct createArguments final {}; template struct createArguments> final { static constexpr std::array call() { return createArgumentVectorFromTypes( guts::make_index_sequence() ); } }; /// Creates a vector of `ArgumentDef` from a list of C++ types that are specified /// as a tuple (i.e. in the way c10 kernels return values). /// It can be a tuple if there's three output arguments with types A, B, C. /// It can be an empty tuple<>, or void for kernels that don't return anything. /// It can be a single type A (i.e. no tuple) for the case where a kernel just /// returns one value. template struct createReturns final {}; template struct createReturns, void> final { static constexpr std::array call() { return createArgumentVectorFromTypes( guts::make_index_sequence() ); } }; template struct createReturns::value && !guts::is_instantiation_of::value>> final { static constexpr std::array call() { return createReturns>::call(); } }; template<> struct createReturns final { static constexpr std::array call() { return createReturns>::call(); } }; template std::vector createArgumentVector(const std::array& args) { std::vector result; result.reserve(NumArgs); for (size_t i = 0; i < args.size(); ++i) { // Arguments are named "_" result.push_back(Argument("_" + c10::guts::to_string(i), (*args[i].getTypeFn)())); } return result; } // This is intentionally a separate function // because then the template is smaller and that benefits binary size inline FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, std::vector&& arguments, std::vector&& returns) { return FunctionSchema(std::move(name), std::move(overload_name), std::move(arguments), std::move(returns)); } template inline FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, const std::array& arguments, const std::array& returns) { return make_function_schema(std::move(name), std::move(overload_name), createArgumentVector(arguments), createArgumentVector(returns)); } /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a /// function. template FunctionSchema createFunctionSchemaFromTraits(std::string&& name, std::string&& overload_name) { using ReturnType = typename FunctionTraits::return_type; using ParameterTypes = typename FunctionTraits::parameter_types; constexpr auto arguments = createArguments::call(); constexpr auto returns = createReturns::call(); return make_function_schema(std::move(name), std::move(overload_name), arguments, returns); } } } template FunctionSchema inferFunctionSchema(std::string&& name, std::string&& overload_name) { return detail::infer_schema::createFunctionSchemaFromTraits>(std::move(name), std::move(overload_name)); } CAFFE2_API c10::optional findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified); }