#pragma once #include #include #include #include "caffe2/core/logging.h" namespace caffe2 { /** * thread_local pointer in C++ is a per thread pointer. However, sometimes * we want to have a thread local state that is per thread and also per * instance. e.g. we have the following class: * class A { * ThreadLocalPtr x; * } * We would like to have a copy of x per thread and also per instance of class A * This can be applied to storing per instance thread local state of some class, * when we could have multiple instances of the class in the same thread. * We implemented a subset of functions in folly::ThreadLocalPtr that's enough * to support BlackBoxPredictor. */ class ThreadLocalPtrImpl; class ThreadLocalHelper; /** * Map of object pointer to instance in each thread * to achieve per thread(using thread_local) per object(using the map) * thread local pointer */ typedef std::unordered_map> UnsafeThreadLocalMap; ThreadLocalHelper* getThreadLocalHelper(); typedef std::vector UnsafeAllThreadLocalHelperVector; /** * A thread safe vector of all ThreadLocalHelper, this will be used * to encapuslate the locking in the APIs for the changes to the global * AllThreadLocalHelperVector instance. */ class AllThreadLocalHelperVector { public: AllThreadLocalHelperVector() {} // Add a new ThreadLocalHelper to the vector void push_back(ThreadLocalHelper* helper); // Erase a ThreadLocalHelper to the vector void erase(ThreadLocalHelper* helper); // Erase object in all the helpers stored in vector // Called during destructor of a ThreadLocalPtrImpl void erase_tlp(ThreadLocalPtrImpl* ptr); private: UnsafeAllThreadLocalHelperVector vector_; std::mutex mutex_; }; /** * ThreadLocalHelper is per thread */ class ThreadLocalHelper { public: ThreadLocalHelper(); // When the thread dies, we want to clean up *this* // in AllThreadLocalHelperVector ~ThreadLocalHelper(); // Insert a (object, ptr) pair into the thread local map void insert(ThreadLocalPtrImpl* tl_ptr, std::shared_ptr ptr); // Get the ptr by object void* get(ThreadLocalPtrImpl* key); // Erase the ptr associated with the object in the map void erase(ThreadLocalPtrImpl* key); private: // mapping of object -> ptr in each thread UnsafeThreadLocalMap mapping_; std::mutex mutex_; }; // ThreadLocalHelper /** ThreadLocalPtrImpl is per object */ class ThreadLocalPtrImpl { public: ThreadLocalPtrImpl() {} // Delete copy and move constructors ThreadLocalPtrImpl(const ThreadLocalPtrImpl&) = delete; ThreadLocalPtrImpl(ThreadLocalPtrImpl&&) = delete; ThreadLocalPtrImpl& operator=(const ThreadLocalPtrImpl&) = delete; ThreadLocalPtrImpl& operator=(const ThreadLocalPtrImpl&&) = delete; // In the case when object dies first, we want to // clean up the states in all child threads ~ThreadLocalPtrImpl(); template T* get() { return static_cast(getThreadLocalHelper()->get(this)); } template void reset(T* newPtr = nullptr) { VLOG(2) << "In Reset(" << newPtr << ")"; auto* wrapper = getThreadLocalHelper(); // Cleaning up the objects(T) stored in the ThreadLocalPtrImpl in the thread wrapper->erase(this); if (newPtr != nullptr) { std::shared_ptr sharedPtr(newPtr); // Deletion of newPtr is handled by shared_ptr // as it implements type erasure wrapper->insert(this, std::move(sharedPtr)); } } }; // ThreadLocalPtrImpl template class ThreadLocalPtr { public: auto* operator-> () { return get(); } auto& operator*() { return *get(); } auto* get() { return impl_.get(); } auto* operator-> () const { return get(); } auto& operator*() const { return *get(); } auto* get() const { return impl_.get(); } void reset(unique_ptr ptr = nullptr) { impl_.reset(ptr.release()); } private: ThreadLocalPtrImpl impl_; }; } // namespace caffe2