reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#pragma once
 
#include <ATen/ATen.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <limits>
 
namespace at {
namespace cuda {
namespace detail {
 
CAFFE2_API bool maybeOverlappingIndices(const at::Tensor& t);
CAFFE2_API bool canUse32BitIndexMath(const at::Tensor &t, int64_t max_elem=std::numeric_limits<int32_t>::max());
 
template <typename scalar, typename IndexType>
TensorInfo<scalar, IndexType>
getTensorInfo(const at::Tensor& t) {
  IndexType sz[MAX_TENSORINFO_DIMS];
  IndexType st[MAX_TENSORINFO_DIMS];
 
  int dims = t.dim();
  for (int i = 0; i < dims; ++i) {
    sz[i] = t.size(i);
    st[i] = t.stride(i);
  }
 
  return TensorInfo<scalar, IndexType>(
    t.data_ptr<scalar>(), dims, sz, st);
}
 
} // detail
} // cuda
} // at