#pragma once
|
|
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "caffe2/core/logging.h"
|
|
namespace caffe2 {
|
|
/**
 * A C++ thread_local pointer is per thread only. Sometimes, however, we want
 * thread-local state that is both per thread and per instance. For example,
 * given the following class:
 *
 *   class A {
 *     ThreadLocalPtr<int> x;
 *   };
 *
 * we would like a separate copy of x for every (thread, instance of A) pair.
 * This is useful for storing per-instance thread-local state of a class when
 * several instances of that class may be used from the same thread.
 *
 * We implement only the subset of folly::ThreadLocalPtr's functionality that
 * is needed to support BlackBoxPredictor.
 */
|
|
// Forward declarations: the per-object handle and the per-thread storage.
class ThreadLocalPtrImpl;
class ThreadLocalHelper;

/**
 * Per-thread table keyed by the owning ThreadLocalPtrImpl's address.
 * thread_local supplies the "per thread" dimension and this map supplies the
 * "per object" dimension. Values are type-erased behind shared_ptr<void>.
 * Not synchronized on its own.
 */
using UnsafeThreadLocalMap =
    std::unordered_map<ThreadLocalPtrImpl*, std::shared_ptr<void>>;

// Defined out of line; presumably returns the calling thread's helper
// (TODO confirm against the .cc implementation).
ThreadLocalHelper* getThreadLocalHelper();

// Plain, unsynchronized list of helpers; AllThreadLocalHelperVector wraps it
// with a mutex.
using UnsafeAllThreadLocalHelperVector = std::vector<ThreadLocalHelper*>;
|
|
/**
|
* A thread safe vector of all ThreadLocalHelper, this will be used
|
* to encapuslate the locking in the APIs for the changes to the global
|
* AllThreadLocalHelperVector instance.
|
*/
|
class AllThreadLocalHelperVector {
|
public:
|
AllThreadLocalHelperVector() {}
|
|
// Add a new ThreadLocalHelper to the vector
|
void push_back(ThreadLocalHelper* helper);
|
|
// Erase a ThreadLocalHelper to the vector
|
void erase(ThreadLocalHelper* helper);
|
|
// Erase object in all the helpers stored in vector
|
// Called during destructor of a ThreadLocalPtrImpl
|
void erase_tlp(ThreadLocalPtrImpl* ptr);
|
|
private:
|
UnsafeAllThreadLocalHelperVector vector_;
|
std::mutex mutex_;
|
};
|
|
/**
|
* ThreadLocalHelper is per thread
|
*/
|
class ThreadLocalHelper {
|
public:
|
ThreadLocalHelper();
|
|
// When the thread dies, we want to clean up *this*
|
// in AllThreadLocalHelperVector
|
~ThreadLocalHelper();
|
|
// Insert a (object, ptr) pair into the thread local map
|
void insert(ThreadLocalPtrImpl* tl_ptr, std::shared_ptr<void> ptr);
|
// Get the ptr by object
|
void* get(ThreadLocalPtrImpl* key);
|
// Erase the ptr associated with the object in the map
|
void erase(ThreadLocalPtrImpl* key);
|
|
private:
|
// mapping of object -> ptr in each thread
|
UnsafeThreadLocalMap mapping_;
|
std::mutex mutex_;
|
}; // ThreadLocalHelper
|
|
/** ThreadLocalPtrImpl is per object
|
*/
|
class ThreadLocalPtrImpl {
|
public:
|
ThreadLocalPtrImpl() {}
|
// Delete copy and move constructors
|
ThreadLocalPtrImpl(const ThreadLocalPtrImpl&) = delete;
|
ThreadLocalPtrImpl(ThreadLocalPtrImpl&&) = delete;
|
ThreadLocalPtrImpl& operator=(const ThreadLocalPtrImpl&) = delete;
|
ThreadLocalPtrImpl& operator=(const ThreadLocalPtrImpl&&) = delete;
|
|
// In the case when object dies first, we want to
|
// clean up the states in all child threads
|
~ThreadLocalPtrImpl();
|
|
template <typename T>
|
T* get() {
|
return static_cast<T*>(getThreadLocalHelper()->get(this));
|
}
|
|
template <typename T>
|
void reset(T* newPtr = nullptr) {
|
VLOG(2) << "In Reset(" << newPtr << ")";
|
auto* wrapper = getThreadLocalHelper();
|
// Cleaning up the objects(T) stored in the ThreadLocalPtrImpl in the thread
|
wrapper->erase(this);
|
if (newPtr != nullptr) {
|
std::shared_ptr<void> sharedPtr(newPtr);
|
// Deletion of newPtr is handled by shared_ptr
|
// as it implements type erasure
|
wrapper->insert(this, std::move(sharedPtr));
|
}
|
}
|
|
}; // ThreadLocalPtrImpl
|
|
template <typename T>
|
class ThreadLocalPtr {
|
public:
|
auto* operator-> () {
|
return get();
|
}
|
|
auto& operator*() {
|
return *get();
|
}
|
|
auto* get() {
|
return impl_.get<T>();
|
}
|
|
auto* operator-> () const {
|
return get();
|
}
|
|
auto& operator*() const {
|
return *get();
|
}
|
|
auto* get() const {
|
return impl_.get<T>();
|
}
|
|
void reset(unique_ptr<T> ptr = nullptr) {
|
impl_.reset<T>(ptr.release());
|
}
|
|
private:
|
ThreadLocalPtrImpl impl_;
|
};
|
|
} // namespace caffe2
|