#pragma once

#include <ATen/core/dispatch/DispatchTable.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <list>
#include <mutex>

namespace c10 {

namespace impl {
class OperatorEntry;
}

namespace impl {

// This is a private class used inside the Dispatcher to represent an operator
// and its dispatch table. This is not part of the public API.
class OperatorEntry final {
public:
  explicit OperatorEntry(FunctionSchema&& schema, OperatorOptions&& options);

  OperatorEntry(const OperatorEntry&) = delete;
  OperatorEntry(OperatorEntry&&) noexcept = delete;
  OperatorEntry& operator=(const OperatorEntry&) = delete;
  OperatorEntry& operator=(OperatorEntry&&) noexcept = delete;

  const FunctionSchema& schema() const {
    return schema_;
  }

  template<class Return, class... Args>
  Return callUnboxed(TensorTypeId dispatchKey, Args... args) const {
    // TODO Remove dispatchKey argument and instead infer dispatchKey from args...
#if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && __GNUC__ < 5
    // GCC 4 has issues with parameter packs inside lambdas, so let's instead
    // return the KernelFunction from the lambda. Note: This copies the
    // KernelFunction and is slow, but it's the only way to make it work for
    // GCC 4.
    KernelFunction func = dispatchTable_.read([&] (const DispatchTable& dispatchTable) -> KernelFunction {
      return dispatchTable.lookup(dispatchKey);
    });
    return func.callUnboxed<Return, Args...>(std::forward<Args>(args)...);
#else
    // For all other compilers and newer GCC, let's do it right: call the
    // kernel from inside the lambda without copying the KernelFunction.
    return dispatchTable_.read([&] (const DispatchTable& dispatchTable) -> Return {
      return dispatchTable.lookup(dispatchKey)
          .callUnboxed<Return, Args...>(std::forward<Args>(args)...);
    });
#endif
  }

  template<class Return, class... Args>
  Return callUnboxedOnly(TensorTypeId dispatchKey, Args... args) const {
    // TODO Remove dispatchKey argument and instead infer dispatchKey from args...
#if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && __GNUC__ < 5
    // GCC 4 has issues with parameter packs inside lambdas, so let's instead
    // return the KernelFunction from the lambda. Note: This copies the
    // KernelFunction and is slow, but it's the only way to make it work for
    // GCC 4.
    KernelFunction func = dispatchTable_.read([&] (const DispatchTable& dispatchTable) -> KernelFunction {
      return dispatchTable.lookup(dispatchKey);
    });
    return func.callUnboxedOnly<Return, Args...>(std::forward<Args>(args)...);
#else
    // For all other compilers and newer GCC, let's do it right: call the
    // kernel from inside the lambda without copying the KernelFunction.
    return dispatchTable_.read([&] (const DispatchTable& dispatchTable) -> Return {
      return dispatchTable.lookup(dispatchKey)
          .callUnboxedOnly<Return, Args...>(std::forward<Args>(args)...);
    });
#endif
  }

  void callBoxed(Stack* stack) const {
    return dispatchTable_.read([&] (const DispatchTable& dispatchTable) {
      dispatchTable.lookup(stack).callBoxed(stack);
    });
  }

  void prepareForDeregistration();

  RegistrationHandleRAII registerKernel(TensorTypeId dispatch_key, KernelFunction kernel);
  RegistrationHandleRAII registerCatchallKernel(KernelFunction kernel);

  const OperatorOptions& options() {
    return options_;
  }

private:
  void deregisterKernel_(TensorTypeId dispatch_key, std::list<KernelFunction>::iterator kernel);
  void deregisterCatchallKernel_(std::list<KernelFunction>::iterator kernel);

  FunctionSchema schema_;

  // The dispatchTable stores the current kernel for each dispatch key
  LeftRight<DispatchTable> dispatchTable_;

  // kernels_ stores all registered kernels for the corresponding dispatch key
  // and catchAllKernels_ stores the catch-all kernels.
  // If an operator library gets loaded that overwrites an already existing
  // kernel, both kernels will be in that list but only the newer one will be
  // in dispatchTable.
  // If any of the kernels go away (say the library gets unloaded), we remove
  // the kernel from this list and update the dispatchTable if necessary.
  // Kernels in the list are ordered by registration time in descending order:
  // newer registrations come before older ones.
  // We do not combine dispatchTable and kernels into one hash map because
  // kernels is a larger data structure and accessed quite infrequently,
  // while dispatchTable is accessed often and should be kept small to fit
  // into CPU caches.
  // Invariants:
  //  - dispatchTable[dispatch_key] == kernels_[dispatch_key].front()
  //  - dispatchTable[dispatch_key] does not exist if and only if
  //    kernels_[dispatch_key] does not exist
  //  - If kernels_[dispatch_key] exists, then it has elements.
  //    It is never an empty list.
  // Analogous invariants hold for catchAllKernels_.
  //
  // Why do we do that?
  // -----
  // We mostly do this to enable Jupyter notebooks where a cell registering
  // a kernel could be executed multiple times and the later execution
  // should overwrite the earlier one. Note that this still fails if the
  // function schema changed between the executions; it only works as long
  // as the schema stays the same. A better solution would be to unload the
  // old extension library from the Jupyter cell when the cell is re-executed
  // and then only allow one kernel here, i.e. error if a kernel is already
  // registered, but that's a lot of effort to implement and currently not
  // high-pri.
  ska::flat_hash_map<TensorTypeId, std::list<KernelFunction>> kernels_;
  std::list<KernelFunction> catchAllKernels_;

  // Some metadata about the operator
  OperatorOptions options_;

  std::mutex kernelsMutex_; // protects kernels_

  // This function re-establishes the invariant that dispatchTable
  // contains the front element from the kernels list for a given dispatch key.
  void updateDispatchTable_(TensorTypeId dispatch_key);
  void updateCatchallDispatchTable_();
};

} // namespace impl
} // namespace c10
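
// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the real dispatcher: a minimal model of the
// invariant documented above. `kernels` keeps all registered kernels per key,
// newest first, and `dispatchTable` caches the front element. The names below
// (MiniOperatorEntry, MiniKernel) and the use of std::unordered_map and
// std::function are hypothetical stand-ins for ska::flat_hash_map,
// KernelFunction and DispatchTable; the real class also wraps the table in
// LeftRight for concurrent reads, which this sketch omits.
// ---------------------------------------------------------------------------
#include <functional>
#include <list>
#include <unordered_map>

struct MiniOperatorEntry {
  using MiniKernel = std::function<void()>;

  // All registered kernels per key, newest first (mirrors kernels_).
  std::unordered_map<int, std::list<MiniKernel>> kernels;
  // Cache of the newest kernel per key (mirrors dispatchTable_).
  std::unordered_map<int, MiniKernel> dispatchTable;

  // Registering puts the kernel at the front and refreshes the cache, so a
  // re-registration transparently overwrites the previously active kernel.
  std::list<MiniKernel>::iterator registerKernel(int key, MiniKernel k) {
    auto& list = kernels[key];
    list.push_front(std::move(k));
    updateDispatchTable(key);
    return list.begin();
  }

  // Deregistering removes one kernel; the next-newest one (if any) takes
  // over, otherwise both map entries disappear (lists are never left empty).
  void deregisterKernel(int key, std::list<MiniKernel>::iterator it) {
    auto found = kernels.find(key);
    if (found == kernels.end()) return;
    found->second.erase(it);
    if (found->second.empty()) {
      kernels.erase(found);
    }
    updateDispatchTable(key);
  }

  // Re-establishes: dispatchTable[key] == kernels[key].front(), and
  // dispatchTable[key] exists iff kernels[key] exists.
  void updateDispatchTable(int key) {
    auto found = kernels.find(key);
    if (found == kernels.end()) {
      dispatchTable.erase(key);
    } else {
      dispatchTable[key] = found->second.front();
    }
  }
};

// Example of the Jupyter-style double registration described above: the
// second registerKernel call for the same key simply shadows the first one,
// and deregistering the newer kernel falls back to the older one.
inline void miniOperatorEntryExample() {
  MiniOperatorEntry entry;
  auto first = entry.registerKernel(0, [] { /* old kernel */ });
  auto second = entry.registerKernel(0, [] { /* new kernel, now cached */ });
  entry.deregisterKernel(0, second);  // falls back to the first kernel
  entry.deregisterKernel(0, first);   // table entry for key 0 disappears
}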