#pragma once #include #include #include namespace torch { namespace jit { /// Registration class for new operators. Effectively calls /// `torch::jit::registerOperator` for every supplied operator, but allows doing /// so in the global scope when a `RegisterOperators` object is assigned to a /// static variable. Also handles registration of user-defined, "custom" /// operators. struct TORCH_API RegisterOperators { RegisterOperators() = default; /// Registers a vector of already created `Operator`s. RegisterOperators(std::vector operators) { for (Operator& o : operators) { registerOperator(std::move(o)); } } /// Calls `op(...)` with the given operator name and implementation. template C10_DEPRECATED_MESSAGE("torch::jit::RegisterOperators is deprecated. Please use torch::RegisterOperators instead.") RegisterOperators(const std::string& name, Implementation&& implementation) { op_(name, std::forward(implementation)); } template C10_DEPRECATED_MESSAGE("torch::jit::RegisterOperators is deprecated. Please use torch::RegisterOperators instead.") RegisterOperators& op( const std::string& name, Implementation&& implementation) { op_(name, std::forward(implementation)); return *this; } private: template void op_(const std::string& name, Implementation&& implementation) { registrars_.emplace_back(std::make_shared(name, std::forward(implementation))); } // A c10::RegisterOperators instance is not copyable, so to make // torch::jit::RegisterOperators copyable, we use shared_ptrs. // We need to keep the c10::RegisterOperators instances around // because this is an RAII pattern. In the destructor, the registered // ops get de-registered. std::vector> registrars_; }; } // namespace jit } // namespace torch