reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-11 bdf3ad71583fb4ef100d3819ecdae8fd9f70083e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#pragma once
 
#include <ATen/core/dispatch/DispatchTable.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
#include <list>
 
namespace c10 {
namespace impl {
  class OperatorEntry;
}
 
namespace impl {
 
// This is a private class used inside the Dispatcher to represent an operator
// and its dispatch table. This is not part of the public API.
class OperatorEntry final {
public:
  explicit OperatorEntry(FunctionSchema&& schema, OperatorOptions&& options);
 
  OperatorEntry(const OperatorEntry&) = delete;
  OperatorEntry(OperatorEntry&&) noexcept = delete;
  OperatorEntry& operator=(const OperatorEntry&) = delete;
  OperatorEntry& operator=(OperatorEntry&&) noexcept = delete;
 
  const FunctionSchema& schema() const {
    return schema_;
  }
 
  template<class Return, class... Args>
  Return callUnboxed(TensorTypeId dispatchKey, Args... args) const {
    // TODO Remove dispatchKey argument and instead infer dispatchKey from args...
    #if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && __GNUC__ < 5
      // GCC 4 has issues with parameter packs inside lambdas, let's instead
      // return the KernelFunction from the lambda. Note: This copies the
      // KernelFunction and is slow, but it's the only way to make it work for
      // GCC 4.
      KernelFunction func = dispatchTable_.read([&] (const DispatchTable& dispatchTable) -> KernelFunction {
          return dispatchTable.lookup(dispatchKey);
      });
      return func.callUnboxed<Return, Args...>(std::forward<Args>(args)...);
    #else
      // For all other compilers and newer GCC, let's do it right
      return dispatchTable_.read([&] (const DispatchTable& dispatchTable) -> Return {
          return dispatchTable.lookup(dispatchKey)
              .callUnboxed<Return, Args...>(std::forward<Args>(args)...);
      });
    #endif
  }
 
  template<class Return, class... Args>
  Return callUnboxedOnly(TensorTypeId dispatchKey, Args... args) const {
    // TODO Remove dispatchKey argument and instead infer dispatchKey from args...
    #if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && __GNUC__ < 5
      // GCC 4 has issues with parameter packs inside lambdas, let's instead
      // return the KernelFunction from the lambda. Note: This copies the
      // KernelFunction and is slow, but it's the only way to make it work for
      // GCC 4.
      KernelFunction func = dispatchTable_.read([&] (const DispatchTable& dispatchTable) -> KernelFunction {
          return dispatchTable.lookup(dispatchKey);
      });
      return func.callUnboxedOnly<Return, Args...>(std::forward<Args>(args)...);
    #else
      // For all other compilers and newer GCC, let's do it right
      return dispatchTable_.read([&] (const DispatchTable& dispatchTable) -> Return {
          return dispatchTable.lookup(dispatchKey)
              .callUnboxedOnly<Return, Args...>(std::forward<Args>(args)...);
      });
    #endif
  }
 
  void callBoxed(Stack* stack) const {
    return dispatchTable_.read([&] (const DispatchTable& dispatchTable) {
        dispatchTable.lookup(stack).callBoxed(stack);
    });
  }
 
  void prepareForDeregistration();
 
  RegistrationHandleRAII registerKernel(TensorTypeId dispatch_key, KernelFunction kernel);
  RegistrationHandleRAII registerCatchallKernel(KernelFunction kernel);
 
  const OperatorOptions& options() {
    return options_;
  }
 
private:
  void deregisterKernel_(TensorTypeId dispatch_key, std::list<KernelFunction>::iterator kernel);
  void deregisterCatchallKernel_(std::list<KernelFunction>::iterator kernel);
 
  FunctionSchema schema_;
 
  // The dispatchTable stores the current kernel for each dispatch key
  LeftRight<DispatchTable> dispatchTable_;
 
  // kernels_ stores all registered kernels for the corresponding dispatch key
  // and catchAllKernels_ stores the catch-all kernels.
  // If an operator library gets loaded that overwrites an already existing kernel,
  // both kernels will be in that list but only the newer one will be in
  // dispatchTable. If any of the kernels go away (say the library gets
  // unloaded), we remove the kernel from this list and update the
  // dispatchTable if necessary.
  // Kernels in the list are ordered by registration time descendingly,
  // newer registrations are before older registrations.
  // We do not combine dispatchTable and kernels into one hash map because
  // kernels is a larger data structure and accessed quite infrequently
  // while dispatchTable is accessed often and should be kept small to fit
  // into CPU caches.
  // Invariants:
  //  - dispatchTable[dispatch_key] == kernels_[dispatch_key].front()
  //  - dispatchTable[dispatch_key] does not exist if and only if
  //    kernels_[dispatch_key] does not exist
  //  - If kernels_[dispatch_key] exists, then it has elements.
  //    It is never an empty list.
  // Analogous invariants for catchAllKernels_.
  //
  // Why do we do that?
  // -----
  // We mostly do this to enable Jupyter notebooks where a cell registering
  // a kernel could be executed multiple times and the later execution
  // should overwrite the earlier one. Note that this still fails when the
  // function schema changed between the executions, but it works as long
  // as the function schema didn't change. A better solution would be to
  // unload the old extension library from the Jupyter cell when the cell is
  // re-executed and then only allow one kernel here, i.e. error if a kernel
  // is already registered, but that's a lot of effort to implement and
  // currently not high-pri.
  ska::flat_hash_map<TensorTypeId, std::list<KernelFunction>> kernels_;
  std::list<KernelFunction> catchAllKernels_;
 
  // Some metadata about the operator
  OperatorOptions options_;
 
  std::mutex kernelsMutex_; // protects kernels_
 
  // This function re-establishes the invariant that dispatchTable
  // contains the front element from the kernels list for a given dispatch key.
  void updateDispatchTable_(TensorTypeId dispatch_key);
  void updateCatchallDispatchTable_();
};
 
}
}