reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#pragma once
 
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
 
#include <algorithm>
#include <array>
#include <cstdint>
#include <initializer_list>
#include <vector>
 
namespace torch {
 
/// A utility class that accepts either a container of `D`-many values, or a
/// single value, which is internally repeated `D` times. This is useful to
/// represent parameters that are multidimensional, but often equally sized in
/// all dimensions. For example, the kernel size of a 2D convolution has an `x`
/// and `y` length, but `x` and `y` are often equal. In such a case you could
/// just pass `3` to an `ExpandingArray<2>` and it would "expand" to `{3, 3}`.
template <size_t D, typename T = int64_t>
class ExpandingArray {
 public:
  /// Constructs an `ExpandingArray` from an `initializer_list`. The extent of
  /// the length is checked against the `ExpandingArray`'s extent parameter `D`
  /// at runtime.
  /*implicit*/ ExpandingArray(std::initializer_list<T> list)
      : ExpandingArray(at::ArrayRef<T>(list)) {}
 
  /// Constructs an `ExpandingArray` from an `initializer_list`. The extent of
  /// the length is checked against the `ExpandingArray`'s extent parameter `D`
  /// at runtime.
  /*implicit*/ ExpandingArray(at::ArrayRef<T> values) {
    // clang-format off
    TORCH_CHECK(
        values.size() == D,
        "Expected ", D, " values, but instead got ", values.size());
    // clang-format on
    std::copy(values.begin(), values.end(), values_.begin());
  }
 
  /// Constructs an `ExpandingArray` from a single value, which is repeated `D`
  /// times (where `D` is the extent parameter of the `ExpandingArray`).
  /*implicit*/ ExpandingArray(T single_size) {
    values_.fill(single_size);
  }
 
  /// Constructs an `ExpandingArray` from a correctly sized `std::array`.
  /*implicit*/ ExpandingArray(const std::array<T, D>& values)
      : values_(values) {}
 
  /// Accesses the underlying `std::array`.
  std::array<T, D>& operator*() {
    return values_;
  }
 
  /// Accesses the underlying `std::array`.
  const std::array<T, D>& operator*() const {
    return values_;
  }
 
  /// Accesses the underlying `std::array`.
  std::array<T, D>* operator->() {
    return &values_;
  }
 
  /// Accesses the underlying `std::array`.
  const std::array<T, D>* operator->() const {
    return &values_;
  }
 
  /// Returns an `ArrayRef` to the underlying `std::array`.
  operator at::ArrayRef<T>() const {
    return values_;
  }
 
  /// Returns the extent of the `ExpandingArray`.
  size_t size() const noexcept {
    return D;
  }
 
 private:
  /// The backing array.
  std::array<T, D> values_;
};
 
template <size_t D, typename T>
std::ostream& operator<<(
    std::ostream& stream,
    const ExpandingArray<D, T>& expanding_array) {
  if (expanding_array.size() == 1) {
    return stream << expanding_array->at(0);
  }
  return stream << static_cast<at::ArrayRef<T>>(expanding_array);
}
} // namespace torch