#pragma once

#include <c10/core/TensorTypeId.h>
#include <c10/core/TensorTypeSet.h>
#include <c10/macros/Macros.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/Optional.h>
#include <c10/util/Exception.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/stack.h>
#include <sstream>
#include <string>

namespace c10 {

namespace detail {

class KernelTable_ final {
 public:
  void set(TensorTypeId key, const KernelFunction& value, const std::string& operator_name) {
    auto emplaced = map_.emplace(key, value);
    if (!emplaced.second) {
      // Element already existed. Overwrite it.
      emplaced.first->second = value;
      TORCH_WARN("Registered a kernel for operator ", operator_name,
          " with dispatch key ", toString(key),
          " that overwrote a previously registered kernel with the same dispatch key for the same operator.");
    }
  }

  void removeIfExists(TensorTypeId key, const std::string& operator_name) {
    auto num_removed = map_.erase(key);
    TORCH_INTERNAL_ASSERT(num_removed <= 1); // This is not a multi-map
  }

  const KernelFunction* lookup(TensorTypeId key) const {
    auto found = map_.find(key);
    if (found != map_.end()) {
      return &found->second;
    } else {
      return nullptr;
    }
  }

  size_t size() const {
    return map_.size();
  }

  std::string list_all_dispatch_keys() const {
    if (map_.size() == 0) {
      return "[]";
    }
    std::ostringstream str;
    str << "[" << toString(map_.begin()->first);
    for (auto iter = ++map_.begin(); iter != map_.end(); ++iter) {
      str << ", " << toString(iter->first);
    }
    str << "]";
    return str.str();
  }

 private:
  ska::flat_hash_map<TensorTypeId, KernelFunction> map_;
};

} // namespace detail

/**
 * Per-operator dispatch table.
 *
 * Given an operator specified by a FunctionSchema, this class records a
 * dispatch table for the various kernels provided for this operator. For
 * example, if we consider the operator add(Tensor, Tensor), the dispatch
 * table for this operator may contain implementations for various dynamic
 * tensor types, such as CPUTensorId, CUDATensorId, etc.
 */
class DispatchTable final {
 public:
  DispatchTable(const FunctionSchema& schema)
  : kernels_()
  , catchall_kernel_(c10::nullopt)
  , dispatch_strategy_(get_dispatch_strategy_(schema))
  , operator_name_(schema.name()) {}

  /**
   * Register a kernel in the table at some dispatch key.
   *
   * @param dispatch_key Dispatch key to define when this kernel is selected.
   * @param kernel Concrete kernel function implementation to register.
   */
  void setKernel(TensorTypeId dispatch_key, const KernelFunction& kernel) {
    TORCH_INTERNAL_ASSERT(dispatch_key != TensorTypeId::UndefinedTensorId);
    // The following assertion is disabled because we're codegenerating
    // autograd kernels for operators without tensor arguments even though
    // they are never called. These, however, register kernels for
    // VariableTensorId.
    // TODO Stop generating those kernels and re-enable this assertion here.
    //TORCH_CHECK(dispatch_strategy_.is_valid_, "Tried to register a kernel with dispatch key ", toString(dispatch_key), " for operator ", operator_name_, " that doesn't have tensor arguments.");
    kernels_.set(dispatch_key, kernel, operator_name_);
  }

  /**
   * Deregister the kernel for some dispatch key.
   *
   * @param dispatch_key Dispatch key to unregister.
   */
  void removeKernelIfExists(TensorTypeId dispatch_key) {
    kernels_.removeIfExists(dispatch_key, operator_name_);
  }
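  // A minimal usage sketch for the keyed-kernel API (illustrative only: the
  // schema string and `my_cpu_kernel` are hypothetical and not part of this
  // header):
  //
  //   FunctionSchema schema = torch::jit::parseSchema(
  //       "my_ns::my_add(Tensor a, Tensor b) -> Tensor");
  //   DispatchTable table(schema);
  //
  //   // Register a kernel under a concrete dispatch key, then dispatch to it.
  //   table.setKernel(TensorTypeId::CPUTensorId, my_cpu_kernel);
  //   const KernelFunction& k = table.lookup(TensorTypeId::CPUTensorId);
  //
  //   // Deregistration is idempotent: removing a key that was never
  //   // registered is a no-op.
  //   table.removeKernelIfExists(TensorTypeId::CPUTensorId);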
  /**
   * Register a catch-all kernel that is called for this operator
   * independent of the inputs. An operator can have either
   * a catch-all kernel or a set of kernels with concrete
   * dispatch keys, not both.
   */
  void setCatchallKernel(const KernelFunction& kernel) {
    if (catchall_kernel_.has_value()) {
      TORCH_WARN("Registered a catch-all kernel for operator ", operator_name_,
          " that overwrote a previously registered catch-all kernel for the same operator.");
    }
    catchall_kernel_ = kernel;
  }

  /**
   * Remove the catch-all kernel.
   */
  void removeCatchallKernel() {
    TORCH_INTERNAL_ASSERT(catchall_kernel_.has_value(),
        "Tried to remove the catch-all kernel for operator ", operator_name_,
        " but there is no catch-all kernel registered.");
    catchall_kernel_ = c10::nullopt;
  }

  /**
   * Perform a dynamic dispatch on this table and find the kernel to call
   * for the given arguments.
   *
   * @param stack Stack of arguments to invoke the function with.
   * @return Kernel function pointing to the right kernel for the given arguments.
   */
  const KernelFunction& lookup(const Stack* stack) const {
    return lookup_([=] () -> c10::optional<TensorTypeId> {
      if (!dispatch_strategy_.is_valid_) {
        return c10::nullopt;
      }
      return dispatch_strategy_.get_dispatch_key(stack, operator_name_);
    });
  }

  const KernelFunction& lookup(TensorTypeId dispatchKey) const {
    return lookup_([=] () -> c10::optional<TensorTypeId> { return dispatchKey; });
  }

  bool isEmpty() const {
    return !catchall_kernel_.has_value() && kernels_.size() == 0;
  }

  std::string listAllDispatchKeys() const {
    std::string result = kernels_.list_all_dispatch_keys();
    if (catchall_kernel_.has_value()) {
      result += ", CATCH-ALL";
    }
    return result;
  }
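  // Catch-all semantics sketch (illustrative; `my_catchall` is a hypothetical
  // KernelFunction and `table`/`stack` are assumed to exist):
  //
  //   table.setCatchallKernel(my_catchall);
  //   // With no keyed kernels registered, every lookup falls through to the
  //   // catch-all, whether it is done by explicit key or by inspecting the
  //   // argument stack:
  //   const KernelFunction& k1 = table.lookup(TensorTypeId::CPUTensorId);
  //   const KernelFunction& k2 = table.lookup(&stack); // same kernel
  //   table.removeCatchallKernel();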
 private:
  struct DispatchStrategy final {
    // This caches the number of arguments so that we don't have to parse the
    // schema inputs again and again for each dispatcher lookup.
    // num_args_ is allowed to be zero; that just means you must do the
    // fallthrough.
    // TODO: a potential optimization is to store a bitfield of arg locations.
    size_t num_args_;

    // An invalid dispatch strategy means we can't dispatch any kernels.
    // You're able to create a dispatch table with an invalid dispatch strategy,
    // but adding kernels to it will fail.
    // This is used to allow creating operators with empty argument lists
    // as long as they only have fallback kernels and no dispatched kernels.
    bool is_valid_;

    TensorTypeId get_dispatch_key(const Stack* stack, const std::string& operator_name) const {
      TensorTypeSet ts;
      for (const auto& ivalue : torch::jit::last(*stack, num_args_)) {
        if (C10_LIKELY(ivalue.isTensor())) {
          // NB: Take care not to introduce a refcount bump (there's
          // no safe toTensorRef method, alas)
          ts = ts | ivalue.unsafeToTensorImpl()->type_set();
        } else if (C10_UNLIKELY(ivalue.isTensorList())) {
          for (const auto& tensor : ivalue.toTensorListRef()) {
            ts = ts | tensor.type_set();
          }
        }
      }
      // TODO: Don't use legacy extractor; blocked on c10 understanding
      // variable
      return c10::legacyExtractTypeId(ts);
    }
  };

  static DispatchStrategy get_dispatch_strategy_(const FunctionSchema& schema) {
    bool is_valid = false;
    for (size_t i = 0; i < schema.arguments().size(); ++i) {
      const auto& type = schema.arguments()[i].type();
      if (type->isSubtypeOf(TensorType::get())) {
        is_valid = true;
        break;
      }
      if (type->isSubtypeOf(ListType::ofTensors())) {
        is_valid = true;
        break;
      }
    }
    return {schema.arguments().size(), is_valid};
  }

  template<class GetDispatchKeyFunc>
  const KernelFunction& lookup_(const GetDispatchKeyFunc& getDispatchKey) const {
    c10::optional<TensorTypeId> dispatch_key = getDispatchKey();
    if (dispatch_key.has_value()) {
      const auto* found = kernels_.lookup(*dispatch_key);
      if (nullptr != found) {
        return *found;
      }
    }

    if (catchall_kernel_.has_value()) {
      return *catchall_kernel_;
    }

    if (!dispatch_key.has_value() || *dispatch_key == TensorTypeId::UndefinedTensorId) {
      TORCH_CHECK(false,
          "There were no tensor arguments to this function (e.g., you passed an "
          "empty list of Tensors), but no fallback function is registered for schema ",
          operator_name_,
          ". This usually means that this function requires a non-empty list of Tensors. "
          "Available functions are ", listAllDispatchKeys());
    }

    const std::string dispatch_key_str = dispatch_key.has_value() ? toString(*dispatch_key) : "None";
    TORCH_CHECK(false,
        "Didn't find kernel to dispatch to for operator '", operator_name_,
        "'. Tried to look up kernel for dispatch key '", dispatch_key_str,
        "'. Registered dispatch keys are: ", listAllDispatchKeys());
  }

  detail::KernelTable_ kernels_;
  c10::optional<KernelFunction> catchall_kernel_;
  DispatchStrategy dispatch_strategy_;
  std::string operator_name_;
};

} // namespace c10
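// Dispatch key extraction sketch (illustrative; the tensors below are
// hypothetical): given a stack whose trailing arguments are a plain CPU
// tensor and an autograd-wrapped (variable) tensor,
// DispatchStrategy::get_dispatch_key unions the type sets of those IValues,
//
//   TensorTypeSet ts;                      // {}
//   ts = ts | cpu_tensor.type_set();       // {CPUTensorId}
//   ts = ts | variable_tensor.type_set();  // {CPUTensorId, VariableTensorId}
//
// and c10::legacyExtractTypeId(ts) then reduces the set to a single key (the
// highest-priority one, VariableTensorId here), which lookup_ uses to index
// the kernel table.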