reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
#pragma once
 
#include <c10/util/ArrayRef.h>
#include <c10/util/Half.h>
#include <c10/util/Optional.h>
#include <c10/util/typeid.h>
 
#include <complex>
#include <cstdint>
#include <iostream>
 
namespace c10 {
 
// For the macros below:
// NB: If you want to macro some code for all non-QInt scalar types (i.e. types
// with complete information, you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND
// macros below, which are designed to behave similarly to the Dispatch macros
// with the same name.
 
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
  _(uint8_t, Byte) /* 0 */                               \
  _(int8_t, Char) /* 1 */                                \
  _(int16_t, Short) /* 2 */                              \
  _(int, Int) /* 3 */                                    \
  _(int64_t, Long) /* 4 */                               \
  _(at::Half, Half) /* 5 */                              \
  _(float, Float) /* 6 */                                \
  _(double, Double) /* 7 */                              \
  _(at::ComplexHalf, ComplexHalf) /* 8 */                \
  _(std::complex<float>, ComplexFloat) /* 9 */           \
  _(std::complex<double>, ComplexDouble) /* 10 */        \
  _(bool, Bool) /* 11 */                                 \
  _(c10::qint8, QInt8) /* 12 */                          \
  _(c10::quint8, QUInt8) /* 13 */                        \
  _(c10::qint32, QInt32) /* 14 */                        \
  _(at::BFloat16, BFloat16) /* 15 */
 
 
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name).  But beware: convert()
// doesn't work for all the conversions you need...
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \
  _(uint8_t, Byte)                                                 \
  _(int8_t, Char)                                                  \
  _(int16_t, Short)                                                \
  _(int, Int)                                                      \
  _(int64_t, Long)                                                 \
  _(at::Half, Half)                                                \
  _(float, Float)                                                  \
  _(double, Double)                                                \
  _(std::complex<float>, ComplexFloat)                             \
  _(std::complex<double>, ComplexDouble)                           \
  _(bool, Bool)                                                    \
  _(at::BFloat16, BFloat16)
 
 
enum class ScalarType : int8_t {
#define DEFINE_ENUM(_1, n) n,
  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ENUM)
#undef DEFINE_ENUM
      Undefined,
  NumOptions
};
 
namespace impl {
 
// These are used to map ScalarTypes to C++ types.  Feel free to add more or even
// macro generate this; the examples here are just those we have found to be
// necessary.
 
template <c10::ScalarType N>
struct ScalarTypeToCPPType;
 
template<>
struct ScalarTypeToCPPType<c10::ScalarType::Half> {
  using type = c10::Half;
 
  // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType<T>::type being used directly
  // due to ambiguous reference which can't to be resolved. For some reason it cant pick between at::detail and at::cuda::detail.
  // For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba
  // TODO: remove once the bug is fixed.
  static type t;
};
 
template<>
struct ScalarTypeToCPPType<c10::ScalarType::BFloat16> {
  using type = c10::BFloat16;
 
  // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType<T>::type being used directly
  // due to ambiguous reference which can't to be resolved. For some reason it cant pick between at::detail and at::cuda::detail.
  // For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba
  // TODO: remove once the bug is fixed.
  static type t;
};
 
template<>
struct ScalarTypeToCPPType<c10::ScalarType::Bool> {
  using type = bool;
 
  // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType<T>::type being used directly
  // due to ambiguous reference which can't to be resolved. For some reason it cant pick between at::detail and at::cuda::detail.
  // For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba
  // TODO: remove once the bug is fixed.
  static type t;
};
 
template<>
struct ScalarTypeToCPPType<c10::ScalarType::Long> {
  using type = int64_t;
 
  // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType<T>::type being used directly
  // due to ambiguous reference which can't to be resolved. For some reason it cant pick between at::detail and at::cuda::detail.
  // For repro example, please see: https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba
  // TODO: remove once the bug is fixed.
  static type t;
};
}
 
#define AT_FORALL_SCALAR_TYPES(_) \
  _(uint8_t, Byte)                \
  _(int8_t, Char)                 \
  _(int16_t, Short)               \
  _(int, Int)                     \
  _(int64_t, Long)                \
  _(float, Float)                 \
  _(double, Double)
 
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _)                          \
  _(uint8_t, Byte)                                                         \
  _(int8_t, Char)                                                          \
  _(int16_t, Short)                                                        \
  _(int, Int)                                                              \
  _(int64_t, Long)                                                         \
  _(float, Float)                                                          \
  _(double, Double)                                                        \
  _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE>::t), SCALARTYPE)
 
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _)                                \
  _(uint8_t, Byte)                                                                              \
  _(int8_t, Char)                                                                               \
  _(int16_t, Short)                                                                             \
  _(int, Int)                                                                                   \
  _(int64_t, Long)                                                                              \
  _(float, Float)                                                                               \
  _(double, Double)                                                                             \
  _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1) \
  _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2)
 
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _)                   \
  _(uint8_t, Byte)                                                                              \
  _(int8_t, Char)                                                                               \
  _(int16_t, Short)                                                                             \
  _(int, Int)                                                                                   \
  _(int64_t, Long)                                                                              \
  _(float, Float)                                                                               \
  _(double, Double)                                                                             \
  _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1) \
  _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2) \
  _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE3>::t), SCALARTYPE3)
 
#define AT_FORALL_QINT_TYPES(_)  \
  _(c10::qint8, QInt8)           \
  _(c10::quint8, QUInt8)         \
  _(c10::qint32, QInt32)
 
static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
#define DEFINE_CASE(ctype, name) \
  case ScalarType::name:         \
    return caffe2::TypeMeta::Make<ctype>();
 
  switch (scalar_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
    case ScalarType::Undefined:
      return caffe2::TypeMeta();
    default:
      AT_ERROR(
          "Unrecognized Scalartype ",
          scalar_type,
          " (please report this error)");
  }
#undef DEFINE_CASE
}
 
static inline c10::optional<ScalarType> tryTypeMetaToScalarType(
    caffe2::TypeMeta dtype) {
#define DEFINE_IF(ctype, name)                    \
  if (dtype == caffe2::TypeMeta::Make<ctype>()) { \
    return {ScalarType::name};                    \
  }
  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_IF)
#undef DEFINE_IF
  if (dtype == caffe2::TypeMeta()) {
    return {ScalarType::Undefined};
  }
  return c10::nullopt;
}
 
static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
  if (auto scalar_type = tryTypeMetaToScalarType(dtype)) {
    return *scalar_type;
  }
  AT_ERROR(
      "Unsupported TypeMeta in ATen: ", dtype, " (please report this error)");
}
 
static inline bool operator==(ScalarType t, caffe2::TypeMeta m) {
  if (auto mt = tryTypeMetaToScalarType(m)) {
    return (*mt) == t;
  }
  return false;
}
 
static inline bool operator==(caffe2::TypeMeta m, ScalarType t) {
  return t == m;
}
 
#define DEFINE_CONSTANT(_, name) \
  constexpr ScalarType k##name = ScalarType::name;
 
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
#undef DEFINE_CONSTANT
 
static inline const char* toString(ScalarType t) {
#define DEFINE_CASE(_, name) \
  case ScalarType::name:     \
    return #name;
 
  switch (t) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
    default:
      return "UNKNOWN_SCALAR";
  }
#undef DEFINE_CASE
}
 
static inline size_t elementSize(ScalarType t) {
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
  case ScalarType::name:                   \
    return sizeof(ctype);
 
  switch (t) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
    default:
      AT_ERROR("Unknown ScalarType");
  }
#undef CASE_ELEMENTSIZE_CASE
}
 
C10_DEPRECATED_MESSAGE("isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.")
static inline bool isIntegralType(ScalarType t) {
  return (
      t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
      t == ScalarType::Long || t == ScalarType::Short);
}
 
static inline bool isIntegralType(ScalarType t, bool includeBool) {
  bool isIntegral = (
      t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
      t == ScalarType::Long || t == ScalarType::Short);
 
  return includeBool ? isIntegral || (t == ScalarType::Bool) : isIntegral;
}
 
static inline bool isFloatingType(ScalarType t) {
  return (
      t == ScalarType::Double || t == ScalarType::Float ||
      t == ScalarType::Half || t == ScalarType::BFloat16);
}
 
static inline bool isComplexType(ScalarType t) {
  return (
      t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat ||
      t == ScalarType::ComplexDouble);
}
 
static inline bool isQIntType(ScalarType t) {
  // Don't forget to extend this when adding new QInt types
  return t == ScalarType:: QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32;
}
 
static inline ScalarType toQIntType(ScalarType t) {
  switch (t) {
    case ScalarType::Byte:
      return ScalarType::QUInt8;
    case ScalarType::Char:
      return ScalarType::QInt8;
    case ScalarType::Int:
      return ScalarType::QInt32;
    default:
      return t;
  }
}
 
static inline ScalarType toUnderlying(ScalarType t) {
  switch (t) {
    case ScalarType::QUInt8:
      return ScalarType::Byte;
    case ScalarType::QInt8:
      return ScalarType::Char;
    case ScalarType::QInt32:
      return ScalarType::Int;
    default:
      return t;
  }
}
 
static inline bool isSignedType(ScalarType t) {
  #define CASE_SIGNED(ctype, name) \
    case ScalarType::name:                       \
      return std::numeric_limits<ctype>::is_signed;
 
    switch (t) {
      AT_FORALL_SCALAR_TYPES_AND(Half, CASE_SIGNED)
      default:
        AT_ERROR("Unknown ScalarType");
    }
  #undef CASE_SIGNED
}
 
static inline bool isUnderlying(ScalarType type, ScalarType qtype) {
  return type == toUnderlying(qtype);
}
 
// see tensor_attributes.rst for detailed explanation and examples
// of casting rules.
static inline bool canCast(const ScalarType from, const ScalarType to) {
  // We disallow float -> integral, e.g., int_tensor *= float is disallowed.
  if (isFloatingType(from) && isIntegralType(to, false)) {
    return false;
  }
 
  // Treat bool as a distinct "category," to be consistent with type promotion
  // rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same category
  // as `bool_tensor`, we would not promote.
  // Differing categories implies `bool_tensor += 5` is disallowed.
  //
  // NB: numpy distinguishes "unsigned" as a category to get the desired
  // `bool_tensor + 5 -> int64_tensor` behavior. We don't, because:
  // * We don't want the performance hit of checking the runtime sign of Scalars.
  // * `uint8_tensor + 5 -> int64_tensor` would be undesirable.
  if (from != ScalarType::Bool && to == ScalarType::Bool) {
    return false;
  }
  return true;
}
 
static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
  // This is generated according to NumPy's promote_types
  constexpr auto u1 = ScalarType::Byte;
  constexpr auto i1 = ScalarType::Char;
  constexpr auto i2 = ScalarType::Short;
  constexpr auto i4 = ScalarType::Int;
  constexpr auto i8 = ScalarType::Long;
  constexpr auto f2 = ScalarType::Half;
  constexpr auto f4 = ScalarType::Float;
  constexpr auto f8 = ScalarType::Double;
  constexpr auto c4 = ScalarType::ComplexFloat;
  constexpr auto c8 = ScalarType::ComplexDouble;
  constexpr auto b1 = ScalarType::Bool;
  constexpr auto bf = ScalarType::BFloat16;
  constexpr auto ud = ScalarType::Undefined;
  if (a == ud || b == ud) {
    return ScalarType::Undefined;
  }
  if (isComplexType(a) || isComplexType(b)) {
    AT_ERROR(
        "promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be for ", toString(a), " and ", toString(b));
  }
 
  // For QInt types, we only allow exact match
  if (isQIntType(a) && a == b) {
    return a;
  }
 
  if (isQIntType(a) || isQIntType(b)) {
    AT_ERROR(
        "promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: ",
        toString(a),
        " ",
        toString(b));
  }
 
  // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX
  // so that's why we have to add undefined as we are not sure what is the
  // corrent values for the type promotions in complex type cases.
  static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
      ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
        /*        u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf*/
        /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, u1, ud, ud, ud, ud},
        /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, i1, ud, ud, ud, ud},
        /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, i2, ud, ud, ud, ud},
        /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, ud, c4, c8, i4, ud, ud, ud, ud},
        /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, ud, c4, c8, i8, ud, ud, ud, ud},
        /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, ud, c4, c8, f2, ud, ud, ud, ud},
        /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, ud, c4, c8, f4, ud, ud, ud, ud},
        /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, ud, c8, c8, f8, ud, ud, ud, ud},
        /* c2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud},
        /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, ud, ud, ud, ud, ud},
        /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, ud, ud},
        /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, b1, ud, ud, ud, ud},
        /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
        /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
        /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
        /* bf */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf},
  };
  return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
}
 
inline std::ostream& operator<<(
    std::ostream& stream,
    at::ScalarType scalar_type) {
  return stream << toString(scalar_type);
}
 
} // namespace c10