#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include // TODO: Rewrite this comment // // This dispatch class serves as a replacement for our previous dispatch // mechanism, in which all functions were members of a Type class. A derived // class existed for each backend (and Variable), and the vtable was used to // dispatch to the correct implementation. This class is to be replaced by // the c10 dispatcher when it supports all argument and return types. // This implementation opts to store implementations in a table of void*. namespace at { namespace impl { // Take a TensorTypeSet for a Tensor, and combine it with the current thread // local valid (implemented) and enabled (not implemented) TensorTypeSets // to determine what the actual dispatch TensorTypeId should be. Unlike // Tensor::type_set(), the value of this on a tensor can change depending // on TLS. // // NB: I didn't make this take a Tensor to avoid header include shenanigans. // // TODO: I'm not sure if this should live in this header or not; the operant // question is whether or not we have access to all the relevant TLS at this // point. static inline TensorTypeId dispatchTypeId(TensorTypeSet ts) { return (ts - c10::impl::tls_excluded_tensor_type_set()).highestPriorityTypeId(); } } namespace detail { struct MultiDispatchTensorTypeSet : IterArgs { TensorTypeSet ts; void operator()(const at::Tensor& x) { ts = ts | x.type_set(); } void operator()(TensorOptions x) { ts = ts | x.type_set(); } void operator()(at::ArrayRef xs) { for (const auto& x : xs) { ts = ts | x.type_set(); } } template void operator()(const T& x) { // do nothing } }; template TensorTypeSet multi_dispatch_tensor_type_set(const Args&... args) { return MultiDispatchTensorTypeSet().apply(args...).ts; } } // ATenOpTable stores the implementations for each backend, in addition to // an implementation for variables. class CAFFE2_API ATenOpTable { public: ATenOpTable(std::string schema) : schema_(std::move(schema)) {} template Result callUnboxed(Args... args) const { using FuncType = Result(Args...); TensorTypeSet ts = detail::multi_dispatch_tensor_type_set(args...); TensorTypeId tid = impl::dispatchTypeId(ts); // You might think we can eliminate the second branch by maintaining a // bitmask of registered operator keys, so we don't select dispatch ids // which don't have implementations here. But the net effect is that if you // get a Variable CPUTensor, if there is no variable registration, you'll // fall back to the CPU implementation. Is this what you want? Unlikely... auto* unboxed_fn = reinterpret_cast(function_table_[static_cast(tid)]); if (C10_LIKELY(unboxed_fn != nullptr)) { return (*unboxed_fn)(args...); } auto* unboxed_fallback_fn = reinterpret_cast(function_table_[static_cast(TensorTypeId::UndefinedTensorId)]); if (C10_LIKELY(unboxed_fallback_fn != nullptr)) { return (*unboxed_fallback_fn)(args...); } reportError(tid); TORCH_INTERNAL_ASSERT(0); } private: void registerOp(TensorTypeId tid, void* fn) { TORCH_CHECK(function_table_[static_cast(tid)] == nullptr, "Attempting to register function for schema ", schema_, " and tensor type ", toString(tid), " but there is already a function registered"); function_table_[static_cast(tid)] = fn; } C10_NORETURN void reportError(TensorTypeId tid) const; friend class ATenDispatch; std::string schema_; void* function_table_[static_cast(TensorTypeId::NumTensorIds)] = {nullptr}; }; class CAFFE2_API ATenDispatch { public: template ATenDispatch& registerOp(TensorTypeId id, const char* schema, FuncType* fn) { std::lock_guard lock(mutex_); if (op_tables_.find(schema) == op_tables_.end()) { op_tables_.insert(std::make_pair(schema, ATenOpTable(schema))); } op_tables_.at(schema).registerOp(id, reinterpret_cast(fn)); return *this; } const ATenOpTable* getOpTable(const char* schema) const { auto iter = op_tables_.find(schema); TORCH_CHECK(iter != op_tables_.end(), "No functions are registered for schema ", schema); return &iter->second; } private: std::unordered_map op_tables_; std::mutex mutex_; }; CAFFE2_API ATenDispatch& globalATenDispatch(); } // namespace at