reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-11 bdf3ad71583fb4ef100d3819ecdae8fd9f70083e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
#pragma once
 
#include <ATen/core/function_schema.h>
#include <c10/util/LeftRight.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/either.h>
#include <c10/core/TensorTypeId.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/ATenDispatch.h>
 
#include <array>
#include <atomic>
#include <iostream>
#include <mutex>
#include <type_traits>
#include <sstream>
#include <unordered_map>
#include <functional>
 
namespace c10 {
 
namespace detail {
 
class KernelTable_ final {
 public:
  void set(TensorTypeId key, const KernelFunction& value, const std::string& operator_name) {
    auto emplaced = map_.emplace(key, value);
    if (!emplaced.second) {
      // Element already existed. Overwrite it.
      emplaced.first->second = value;
      TORCH_WARN("Registered a kernel for operator ", operator_name," with dispatch key ", toString(key), " that overwrote a previously registered kernel with the same dispatch key for the same operator.");
    }
  }
 
  void removeIfExists(TensorTypeId key, const std::string& operator_name) {
    auto num_removed = map_.erase(key);
    TORCH_INTERNAL_ASSERT(num_removed <= 1); // This is not a multi-map
  }
 
  const KernelFunction* lookup(TensorTypeId key) const {
    auto found = map_.find(key);
    if (found != map_.end()) {
      return &found->second;
    } else {
      return nullptr;
    }
  }
 
  size_t size() const {
    return map_.size();
  }
 
  std::string list_all_dispatch_keys() const {
    if (map_.size() == 0) {
      return "[]";
    }
    std::ostringstream str;
    str << "[" << toString(map_.begin()->first);
    for (auto iter = ++map_.begin(); iter != map_.end(); ++iter) {
      str << ", " << toString(iter->first);
    }
    str << "]";
    return str.str();
  }
 
 private:
   ska::flat_hash_map<TensorTypeId, KernelFunction> map_;
};
} // namespace detail
 
/**
 * Per-operator dispatch table.
 *
 * Given an operator specified by a FunctionSchema, this class records a dispatch
 * table for various kernels provided for this operator.  For example, if we
 * consider the operator add(Tensor, Tensor), the dispatch table for this
 * operator may contain implementations for various dynamic tensor types, such
 * as CPUTensorId, CUDATensorId, etc.
 */
class DispatchTable final {
 public:
  DispatchTable(const FunctionSchema& schema)
  : kernels_()
  , catchall_kernel_(c10::nullopt)
  , dispatch_strategy_(get_dispatch_strategy_(schema))
  , operator_name_(schema.name()) {}
 
  /**
   * Register a kernel in the table at some dispatch key.
   * @param dispatch_key Dispatch key to define when this kernel is selected.
   * @param kernel Concrete kernel function implementation to register
   */
  void setKernel(
      TensorTypeId dispatch_key,
      const KernelFunction& kernel) {
    TORCH_INTERNAL_ASSERT(dispatch_key != TensorTypeId::UndefinedTensorId);
    // The following assertion is disabled because we're codegenerating
    // autograd kernels for operators without tensor arguments even though
    // they are never called. These, however, register kernels for
    // VariableTensorId.
    // TODO Stop generating those kernels and re-enable this assertion here.
    //TORCH_CHECK(dispatch_strategy_.is_valid_, "Tried to register a kernel with dispatch key ", toString(dispatch_key), " for operator ", operator_name_, " that doesn't have tensor arguments.");
    kernels_.set(dispatch_key, kernel, operator_name_);
  }
 
  /**
   * Deregister the kernel for some dispatch key.
   *
   * @param dispatch_key Dispatch key to unregister.
   */
  void removeKernelIfExists(TensorTypeId dispatch_key) {
    kernels_.removeIfExists(dispatch_key, operator_name_);
  }
 
  /**
   * Register a catch-all kernel that is called for this operator
   * independent of the inputs. An operator can have either
   * a catch-all kernel or a set of kernels with concrete
   * dispatch keys, not both.
   */
  void setCatchallKernel(const KernelFunction& kernel) {
    if (catchall_kernel_.has_value()) {
      TORCH_WARN("Registered a catch-all kernel for operator ", operator_name_," that overwrote a previously registered catch-all kernel for the same operator.");
    }
    catchall_kernel_ = kernel;
  }
 
  /**
   * Remove the catch-all kernel.
   */
  void removeCatchallKernel() {
    TORCH_INTERNAL_ASSERT(catchall_kernel_.has_value(), "Tried to remove the catch-all kernel for operator ", operator_name_," but there is no catch-all kernel registered.");
    catchall_kernel_ = c10::nullopt;
  }
 
  /**
   * Perform a dynamic dispatch on this table and find the kernel to call
   * for the given arguments.
   *
   * @param args Arguments to invoke the function with
   * @return Kernel function pointing to the right kernel for the given arguments.
   */
   const KernelFunction& lookup(const Stack* stack) const {
     return lookup_([=] () -> c10::optional<TensorTypeId> {
       if (!dispatch_strategy_.is_valid_) {
         return c10::nullopt;
       }
       return dispatch_strategy_.get_dispatch_key(stack, operator_name_);
     });
   }
 
   const KernelFunction& lookup(TensorTypeId dispatchKey) const {
     return lookup_([=] () -> c10::optional<TensorTypeId> { return dispatchKey;});
   }
 
   bool isEmpty() const {
     return !catchall_kernel_.has_value() && kernels_.size() == 0;
   }
 
   std::string listAllDispatchKeys() const {
     std::string result = kernels_.list_all_dispatch_keys();
     if (catchall_kernel_.has_value()) {
       result += ", CATCH-ALL";
     }
     return result;
   }
 
private:
  struct DispatchStrategy final {
    // this is caching the index so we don't have to parse the schema inputs
    // again and again for each dispatcher lookup.
    // num_args_ is allowed to be zero; that just means you must do the
    // fallthrough
    // TODO: a potential optimization is to store a bitfield of arg locations,
    size_t num_args_;
 
    // An invalid dispatch strategy means we can't dispatch any kernels.
    // You're able to create a dispatch table with an invalid dispatch strategy,
    // but adding kernels to it will fail.
    // This is used to allow creating operators with empty argument lists
    // as long as they only have fallback kernels and no dispatched kernels.
    bool is_valid_;
 
    TensorTypeId get_dispatch_key(const Stack* stack, const std::string& operator_name) const {
 
      TensorTypeSet ts;
      for (const auto& ivalue : torch::jit::last(*stack, num_args_)) {
        if (C10_LIKELY(ivalue.isTensor())) {
          // NB: Take care not to introduce a refcount bump (there's
          // no safe toTensorRef method, alas)
          ts = ts | ivalue.unsafeToTensorImpl()->type_set();
        } else if (C10_UNLIKELY(ivalue.isTensorList())) {
          for (const auto& tensor : ivalue.toTensorListRef()) {
            ts = ts | tensor.type_set();
          }
        }
      }
      // TODO: Don't use legacy extractor; blocked on c10 understanding
      // variable
      return c10::legacyExtractTypeId(ts);
    }
  };
 
  static DispatchStrategy get_dispatch_strategy_(const FunctionSchema& schema) {
    bool is_valid = false;
    for (size_t i = 0; i < schema.arguments().size(); ++i) {
      const auto& type = schema.arguments()[i].type();
      if (type->isSubtypeOf(TensorType::get())) {
        is_valid = true;
        break;
      }
      if (type->isSubtypeOf(ListType::ofTensors())) {
        is_valid = true;
        break;
      }
    }
 
    return {schema.arguments().size(), is_valid};
  }
 
  template<class GetDispatchKeyFunc>
  const KernelFunction& lookup_(const GetDispatchKeyFunc& getDispatchKey) const {
      c10::optional<TensorTypeId> dispatch_key = getDispatchKey();
      if (dispatch_key.has_value()) {
        const auto* found = kernels_.lookup(*dispatch_key);
 
        if (nullptr != found) {
          return *found;
        }
      }
 
      if (catchall_kernel_.has_value()) {
        return *catchall_kernel_;
      }
 
      if (!dispatch_key.has_value() || *dispatch_key == TensorTypeId::UndefinedTensorId) {
        TORCH_CHECK(false,
              "There were no tensor arguments to this function (e.g., you passed an "
              "empty list of Tensors), but no fallback function is registered for schema ", operator_name_,
              ".  This usually means that this function requires a non-empty list of Tensors.  "
              "Available functions are ", listAllDispatchKeys())
      }
 
      const std::string dispatch_key_str = dispatch_key.has_value() ? toString(*dispatch_key) : "None";
      TORCH_CHECK(false, "Didn't find kernel to dispatch to for operator '", operator_name_,
               "'. Tried to look up kernel for dispatch key '", dispatch_key_str,
               "'. Registered dispatch keys are: ", listAllDispatchKeys());
  }
 
  detail::KernelTable_ kernels_;
  c10::optional<KernelFunction> catchall_kernel_;
  DispatchStrategy dispatch_strategy_;
  std::string operator_name_;
};
 
} // namespace c10