reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-10 c3765bd24fe73747688a0ec2a550f219c9acb384
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#pragma once
#include <ATen/Utils.h>
#include <c10/util/ArrayRef.h>
 
#include <vector>
 
namespace at {
  /// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that
  /// we can easily view it as a multidimensional array.
  ///
  /// Like ArrayRef, this class does not own the underlying data, it is expected
  /// to be used in situations where the data resides in some other buffer.
  ///
  /// This is intended to be trivially copyable, so it should be passed by
  /// value.
  ///
  /// For now, 2D only (so the copies are actually cheap, without having
  /// to write a SmallVector class) and contiguous only (so we can
  /// return non-strided ArrayRef on index).
  ///
  /// P.S. dimension 0 indexes rows, dimension 1 indexes columns
  template<typename T>
  class MatrixRef {
  public:
    typedef size_t size_type;
 
  private:
    /// Underlying ArrayRef
    ArrayRef<T> arr;
 
    /// Stride of dim 0 (outer dimension)
    size_type stride0;
 
    // Stride of dim 1 is assumed to be 1
 
  public:
    /// Construct an empty Matrixref.
    /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {}
 
    /// Construct an MatrixRef from an ArrayRef and outer stride.
    /*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
      : arr(arr), stride0(stride0) {
        TORCH_CHECK(arr.size() % stride0 == 0, "MatrixRef: ArrayRef size ", arr.size(), " not divisible by stride ", stride0)
      }
 
    /// @}
    /// @name Simple Operations
    /// @{
 
    /// empty - Check if the matrix is empty.
    bool empty() const { return arr.empty(); }
 
    const T *data() const { return arr.data(); }
 
    /// size - Get size a dimension
    size_t size(size_t dim) const {
      if (dim == 0) {
        return arr.size() / stride0;
      } else if (dim == 1) {
        return stride0;
      } else {
        TORCH_CHECK(0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1");
      }
    }
 
    size_t numel() const {
      return arr.size();
    }
 
    /// equals - Check for element-wise equality.
    bool equals(MatrixRef RHS) const {
      return stride0 == RHS.stride0 && arr.equals(RHS.arr);
    }
 
    /// @}
    /// @name Operator Overloads
    /// @{
    ArrayRef<T> operator[](size_t Index) const {
      return arr.slice(Index*stride0, stride0);
    }
 
    /// Disallow accidental assignment from a temporary.
    ///
    /// The declaration here is extra complicated so that "arrayRef = {}"
    /// continues to select the move assignment operator.
    template <typename U>
    typename std::enable_if<std::is_same<U, T>::value, MatrixRef<T>>::type &
    operator=(U &&Temporary) = delete;
 
    /// Disallow accidental assignment from a temporary.
    ///
    /// The declaration here is extra complicated so that "arrayRef = {}"
    /// continues to select the move assignment operator.
    template <typename U>
    typename std::enable_if<std::is_same<U, T>::value, MatrixRef<T>>::type &
    operator=(std::initializer_list<U>) = delete;
 
  };
 
} // end namespace at