reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#pragma once
 
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include "caffe2/core/logging.h"
 
namespace caffe2 {
 
/**
 * thread_local pointer in C++ is a per thread pointer. However, sometimes
 * we want to have a thread local state that is per thread and also per
 * instance. e.g. we have the following class:
 * class A {
 *   ThreadLocalPtr<int> x;
 * }
 * We would like to have a copy of x per thread and also per instance of class A
 * This can be applied to storing per instance thread local state of some class,
 * when we could have multiple instances of the class in the same thread.
 * We implemented a subset of functions in folly::ThreadLocalPtr that's enough
 * to support BlackBoxPredictor.
 */
 
class ThreadLocalPtrImpl;
class ThreadLocalHelper;
 
/**
 * Map of object pointer to instance in each thread
 * to achieve per thread(using thread_local) per object(using the map)
 * thread local pointer
 */
typedef std::unordered_map<ThreadLocalPtrImpl*, std::shared_ptr<void>>
    UnsafeThreadLocalMap;
 
ThreadLocalHelper* getThreadLocalHelper();
 
typedef std::vector<ThreadLocalHelper*> UnsafeAllThreadLocalHelperVector;
 
/**
 * A thread safe vector of all ThreadLocalHelper, this will be used
 * to encapuslate the locking in the APIs for the changes to the global
 * AllThreadLocalHelperVector instance.
 */
class AllThreadLocalHelperVector {
 public:
  AllThreadLocalHelperVector() {}
 
  // Add a new ThreadLocalHelper to the vector
  void push_back(ThreadLocalHelper* helper);
 
  // Erase a ThreadLocalHelper to the vector
  void erase(ThreadLocalHelper* helper);
 
  // Erase object in all the helpers stored in vector
  // Called during destructor of a ThreadLocalPtrImpl
  void erase_tlp(ThreadLocalPtrImpl* ptr);
 
 private:
  UnsafeAllThreadLocalHelperVector vector_;
  std::mutex mutex_;
};
 
/**
 * ThreadLocalHelper is per thread
 */
class ThreadLocalHelper {
 public:
  ThreadLocalHelper();
 
  // When the thread dies, we want to clean up *this*
  // in AllThreadLocalHelperVector
  ~ThreadLocalHelper();
 
  // Insert a (object, ptr) pair into the thread local map
  void insert(ThreadLocalPtrImpl* tl_ptr, std::shared_ptr<void> ptr);
  // Get the ptr by object
  void* get(ThreadLocalPtrImpl* key);
  // Erase the ptr associated with the object in the map
  void erase(ThreadLocalPtrImpl* key);
 
 private:
  // mapping of object -> ptr in each thread
  UnsafeThreadLocalMap mapping_;
  std::mutex mutex_;
}; // ThreadLocalHelper
 
/** ThreadLocalPtrImpl is per object
 */
class ThreadLocalPtrImpl {
 public:
  ThreadLocalPtrImpl() {}
  // Delete copy and move constructors
  ThreadLocalPtrImpl(const ThreadLocalPtrImpl&) = delete;
  ThreadLocalPtrImpl(ThreadLocalPtrImpl&&) = delete;
  ThreadLocalPtrImpl& operator=(const ThreadLocalPtrImpl&) = delete;
  ThreadLocalPtrImpl& operator=(const ThreadLocalPtrImpl&&) = delete;
 
  // In the case when object dies first, we want to
  // clean up the states in all child threads
  ~ThreadLocalPtrImpl();
 
  template <typename T>
  T* get() {
    return static_cast<T*>(getThreadLocalHelper()->get(this));
  }
 
  template <typename T>
  void reset(T* newPtr = nullptr) {
    VLOG(2) << "In Reset(" << newPtr << ")";
    auto* wrapper = getThreadLocalHelper();
    // Cleaning up the objects(T) stored in the ThreadLocalPtrImpl in the thread
    wrapper->erase(this);
    if (newPtr != nullptr) {
      std::shared_ptr<void> sharedPtr(newPtr);
      // Deletion of newPtr is handled by shared_ptr
      // as it implements type erasure
      wrapper->insert(this, std::move(sharedPtr));
    }
  }
 
}; // ThreadLocalPtrImpl
 
template <typename T>
class ThreadLocalPtr {
 public:
  auto* operator-> () {
    return get();
  }
 
  auto& operator*() {
    return *get();
  }
 
  auto* get() {
    return impl_.get<T>();
  }
 
  auto* operator-> () const {
    return get();
  }
 
  auto& operator*() const {
    return *get();
  }
 
  auto* get() const {
    return impl_.get<T>();
  }
 
  void reset(unique_ptr<T> ptr = nullptr) {
    impl_.reset<T>(ptr.release());
  }
 
 private:
  ThreadLocalPtrImpl impl_;
};
 
} // namespace caffe2