#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace py = pybind11; namespace torch { namespace jit { static std::vector registeredOps; namespace detail { template struct types { constexpr static bool hasRet = true; using type = types; }; template struct types { constexpr static bool hasRet = false; using type = types; }; template struct args; template struct args : types {}; template using args_t = typename args::type; } // namespace detail template detail::types init() { return detail::types{}; } // To bind custom classes into Torchscript, use an API very similar to Pybind's. // Currently exposes one class `torch::jit::class_` and 2 methods. // - Constructing `torch::jit::class_` registers `Foo` in Python and // Torchscript, and puts it under `torch.classes.Foo` in Python. // - torch::jit::class_.def("method1", &Foo::method1) does some template // metaprogramming to introspect the function types and register the operator // for use in Torchscript. // - torch::jit::class_.def(torch::jit::init()) registers // the Foo(int, int) constructor. // see test/custom_operator/classes.cpp and // test/custom_operator/test_custom_classes.py for example usages template class class_ { std::string className; std::string qualClassName; c10::optional> pyClass = c10::nullopt; std::shared_ptr classCu = nullptr; ClassTypePtr classTypePtr; const std::string parentModule = "classes"; const std::string topModule = "__torch__.torch"; public: class_(string className_) : className(std::move(className_)) { // Currently we register everything as a python class just for convenience. // We'll want to remove this at some point to get rid of the python // dependency. It would require significant changes to class registration, // (I think)? qualClassName = topModule + "." + parentModule + "." + className; auto obj = py::module::import("torch").attr(parentModule.c_str()); pyClass = py::class_(obj, className.c_str()); pyClass->attr("qualified_name") = py::str(qualClassName); auto newClass = py::module::import("torch.jit") .attr("_add_script_class")(*pyClass, qualClassName.c_str()); auto castToPython = [](void* objPtr) -> PyObject* { CurClass x = *static_cast(objPtr); auto py_object = py::cast(x); PyObject* rawPyObj = py_object.release().ptr(); return rawPyObj; }; getClassConverter()[qualClassName] = castToPython; // We currently represent custom classes as torchscript classes with a // capsule attribute classCu = torch::jit::get_python_cu(); classTypePtr = ClassType::create(c10::QualifiedName(qualClassName), classCu); classTypePtr->addAttribute("capsule", CapsuleType::get()); c10::getCustomClassTypeMap().insert({typeid(c10::intrusive_ptr).name(), StrongTypePtr(classCu, classTypePtr)}); c10::getCustomClassTypeMap().insert({typeid(c10::tagged_capsule).name(), StrongTypePtr(classCu, classTypePtr)}); classCu->register_type(classTypePtr); } template class_& def(detail::types) { // Used in combination with // torch::jit::init<...>() pyClass->def(py::init()); auto func = [](c10::tagged_capsule self, Types... args) { auto classObj = c10::make_intrusive(args...); auto genericPtr = c10::static_intrusive_pointer_cast(classObj); auto capsule = IValue(genericPtr); auto object = self.ivalue.toObject(); object->setSlot(0, capsule); }; defineMethod("__init__", std::move(func), false); return *this; } template class_& def(string name, Func f) { auto res = def_(name, f, detail::args_t{}); return *this; } private: template struct addInput { static Value* call(std::shared_ptr graph) { return graph->addInput()->setType(getTypePtr()); } }; template std::vector addInputs_( Func f, std::shared_ptr graph, guts::index_sequence) { using argTypes = typename guts::infer_function_traits_t::parameter_types; std::vector res = { addInput>::call( graph)...}; return res; } template std::vector addInputs(Func f, std::shared_ptr graph) { constexpr auto numArgs = guts::infer_function_traits_t::number_of_parameters; return addInputs_(f, graph, guts::make_index_sequence()); } template std::string type_name() { return std::string(typeid(Last).name()); } template std::string type_name() { return type_name() + "_" + type_name(); } template void addType(Value* v) { v->setType(getTypePtr()); } template void defineMethod(std::string name, Func func, bool hasRet) { auto graph = std::make_shared(); auto qualFuncName = className + "::" + name; registeredOps.push_back( torch::RegisterOperators().op(qualFuncName, std::move(func))); std::vector inputs = addInputs(func, graph); auto methodCall = graph->insertNode(graph->create( Symbol::fromQualString(qualFuncName), inputs, hasRet)); Value* res; if (hasRet) { res = methodCall->output(); addType(res); } else { res = graph->insertConstant(IValue())->setType(NoneType::get()); } graph->registerOutput(res); auto method = classCu->create_function(qualClassName + "." + name, graph); classTypePtr->addMethod(method); } template class_& def_(string name, Func f, detail::types funcInfo) { pyClass->def(name.c_str(), f); auto func = [f](c10::intrusive_ptr cur, Types... args) { return guts::invoke(f, *cur, args...); }; defineMethod(name, std::move(func), funcInfo.hasRet); return *this; } }; } // namespace jit } // namespace torch