#pragma once #include #include namespace c10 { using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace. /** * Inherit from OperatorKernel to implement a c10 kernel. * * Example: * > namespace { * > class my_kernel_cpu final : public c10::OperatorKernel { * > public: * > Tensor operator()(Tensor a, Tensor b) {...} * > }; * > } * * The kernel class is allowed to have members but these are equivalent * to global variables. The kernel implementation is responsible for * preventing race conditions on them. * * See below for how to register this kernel with PyTorch. */ struct CAFFE2_API OperatorKernel { virtual ~OperatorKernel() = default; }; namespace detail { // supported_primitive_arg_types defines which primitive types we allow in // kernel functions as arguments or returns. // Additionally, we support lists, dicts and optionals containing these types. using supported_primitive_arg_types = guts::typelist::typelist< int64_t, double, bool, std::string, at::Tensor, at::Scalar >; template struct assert_is_valid_input_type { assert_is_valid_input_type() { auto tmap = c10::getCustomClassTypeMap(); TORCH_CHECK(c10::isCustomClassRegistered(), "Tried to use undefined class as input argument"); } }; template struct assert_is_valid_input_type::value>> { // everything is ok, this is a primitive type }; template struct assert_is_valid_input_type, AllowDeprecatedTypes> : assert_is_valid_input_type {}; template struct assert_is_valid_input_type, AllowDeprecatedTypes> : assert_is_valid_input_type { static_assert(guts::typelist::contains::value, "You tried to register a kernel with an unsupported input type: Dict where Key is invalid. We only support int64_t, double, bool, and string."); }; template struct assert_is_valid_input_type, AllowDeprecatedTypes> : assert_is_valid_input_type { static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported input type: std::unordered_map. Please use Dict instead."); static_assert(guts::typelist::contains::value, "You tried to register a kernel with an unsupported input type: std::unordered_map where Key is invalid. We only support int64_t, double, bool, and string."); }; template struct assert_is_valid_input_type, AllowDeprecatedTypes> : assert_is_valid_input_type { static_assert(!std::is_same::value, "You tried to register a kernel with an unsupported input type: List. Please use List, List or Tensor instead."); }; template struct assert_is_valid_input_type, AllowDeprecatedTypes> : assert_is_valid_input_type { static_assert(!std::is_same::value, "You tried to register a kernel with an unsupported input type: std::vector. Please use List, List or Tensor instead."); // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported input type: std::vector. Please use List instead."); }; // The following specialisations of assert_is_valid_input_type are technically not // necessary since we would hit the base case and show an error message // there if they didn't exist, but we can show a better error message // in some common error scenarios. template struct assert_is_valid_input_type::value>> { // There is no reason to support float when we have double. Keep the API lean. static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported input type: float. Please use double instead."); }; template struct assert_is_valid_input_type::value>> { static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported input type: const char*. Please use std::string instead."); }; template struct assert_is_valid_input_type, T>::value>> { static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported input type: vector. Please use List instead."); }; template struct assert_is_valid_input_type::value && !guts::typelist::contains::value>> { static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported integral input type. Please use int64_t instead."); }; template struct assert_is_valid_output_type { assert_is_valid_output_type() { auto tmap = getCustomClassTypeMap(); TORCH_CHECK(c10::isCustomClassRegistered(), "Tried to use undefined class as output"); } }; template struct assert_is_valid_output_type::value>> { // everything is ok, this is a primitive type }; template struct assert_is_valid_output_type, AllowDeprecatedTypes> : assert_is_valid_output_type {}; template struct assert_is_valid_output_type, AllowDeprecatedTypes> : assert_is_valid_output_type { static_assert(guts::typelist::contains::value, "You tried to register a kernel with an unsupported output type: Dict where Key is invalid. We only support int64_t, double, bool, and string."); static_assert(!std::is_same::value, "You tried to register a kernel with an unsupported output type: Dict. Please use Dict or Dict."); }; template struct assert_is_valid_output_type, AllowDeprecatedTypes> : assert_is_valid_output_type { static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported output type: std::unordered_map. Please use Dict instead."); static_assert(guts::typelist::contains::value, "You tried to register a kernel with an unsupported output type: std::unordered_map where Key is invalid. We only support int64_t, double, bool, and string."); static_assert(!std::is_same::value, "You tried to register a kernel with an unsupported output type: std::unordered_map. Please use Dict or Dict."); }; template struct assert_is_valid_output_type, AllowDeprecatedTypes> : assert_is_valid_output_type { static_assert(!std::is_same::value, "You tried to register a kernel with an unsupported output type: List. Please use List, List or Tensor instead."); }; template struct assert_is_valid_output_type, AllowDeprecatedTypes> : assert_is_valid_output_type { static_assert(!std::is_same::value, "You tried to register a kernel with an unsupported output type: std::vector. Please use List, List or Tensor instead."); // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported output type: std::vector. Please use List instead."); }; // The following specialisations of assert_is_valid_output_type are technically not // necessary since we would hit the base case and show an error message // there if they didn't exist, but we can show a better error message // in some common error scenarios. template struct assert_is_valid_output_type::value>> { // There is no reason to support float when we have double. Keep the API lean. static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported output type: float. Please use double instead."); }; template struct assert_is_valid_output_type::value>> { static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported output type: const char*. Please use std::string instead."); }; template struct assert_is_valid_output_type, T>::value>> { static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported output type: vector. Please use List instead."); }; template struct assert_is_valid_output_type::value && !guts::typelist::contains::value>> { static_assert(guts::false_t::value, "You tried to register a kernel with an unsupported integral output type. Please use int64_t instead."); }; template T ivalue_to_arg(IValue&& v) { assert_is_valid_input_type(); return std::move(v).to(); } template IValue return_to_ivalue(T&& v) { assert_is_valid_output_type(); return c10::ivalue::from(v); } template typename guts::infer_function_traits_t::return_type call_functor_with_args_from_stack_(Functor* functor, Stack* stack, guts::index_sequence) { (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would be unused and we have to silence the compiler warning. constexpr size_t num_ivalue_args = sizeof...(ivalue_arg_indices); using IValueArgTypes = typename guts::infer_function_traits_t::parameter_types; return (*functor)(ivalue_to_arg>>, AllowDeprecatedTypes>( std::move(torch::jit::peek(*stack, ivalue_arg_indices, num_ivalue_args)) )...); } template typename guts::infer_function_traits_t::return_type call_functor_with_args_from_stack(Functor* functor, Stack* stack) { constexpr size_t num_ivalue_args = guts::infer_function_traits_t::number_of_parameters; return call_functor_with_args_from_stack_(functor, stack, guts::make_index_sequence()); } template struct push_outputs final { static void call(OutputType&& output, Stack* stack) { torch::jit::push(*stack, return_to_ivalue(std::move(output))); } }; template struct push_outputs, AllowDeprecatedTypes> final { static void call(std::tuple&& output, Stack* stack) { call_(std::move(output), stack, guts::make_index_sequence()); } private: template static void call_(std::tuple&& output, Stack* stack, guts::index_sequence) { torch::jit::push(*stack, return_to_ivalue(std::move(std::get(output)))...); } }; template struct wrap_kernel_functor_boxed final {}; // SFINAE version for kernels that return an output template struct wrap_kernel_functor_boxed::return_type>::value>> final { static_assert(std::is_base_of::value, "Tried to register a kernel functor using the kernel() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); static void call(OperatorKernel* functor, Stack* stack) { constexpr size_t num_inputs = guts::infer_function_traits_t::number_of_parameters; KernelFunctor* functor_ = static_cast(functor); auto output = call_functor_with_args_from_stack(functor_, stack); torch::jit::drop(*stack, num_inputs); push_outputs::return_type, AllowDeprecatedTypes>::call(std::move(output), stack); } }; // SFINAE version for kernels that don't return an output template struct wrap_kernel_functor_boxed::return_type>::value>> final { static_assert(std::is_base_of::value, "Tried to register a kernel functor using the kernel() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); static void call(OperatorKernel* functor, Stack* stack) { constexpr size_t num_inputs = guts::infer_function_traits_t::number_of_parameters; KernelFunctor* functor_ = static_cast(functor); call_functor_with_args_from_stack(functor_, stack); torch::jit::pop(*stack, num_inputs); } }; template struct wrap_kernel_functor_unboxed_ final {}; template struct wrap_kernel_functor_unboxed_ final { static_assert(std::is_same::return_type>::value, "Return type mismatch"); static_assert(std::is_same, typename guts::infer_function_traits_t::parameter_types>::value, "Parameter types mismatch"); static ReturnType call(OperatorKernel* functor, ParameterTypes... args) { KernelFunctor* functor_ = static_cast(functor); return (*functor_)(std::forward(args)...); } }; template using wrap_kernel_functor_unboxed = wrap_kernel_functor_unboxed_::func_type>; template class KernelFactory final { static_assert(std::is_constructible::value, "Wrong argument types for constructor of kernel functor."); public: explicit constexpr KernelFactory(Args... args) : constructor_parameters_(std::move(args)...) {} std::unique_ptr operator()() const { return guts::apply( [] (const Args&... params) -> std::unique_ptr {return guts::make_unique_base(params...); }, constructor_parameters_); } private: std::tuple constructor_parameters_; }; template std::unique_ptr inferFunctionSchema_() { return guts::make_unique(inferFunctionSchema("", "")); } template class FunctionSchemaInferer final { public: using func_type = typename c10::guts::infer_function_traits_t::func_type; std::unique_ptr operator()() const { return inferFunctionSchema_(); } }; } } namespace torch { using OperatorKernel = c10::OperatorKernel; }