reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#ifndef CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_
#define CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_
 
#include <math.h>
 
#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"
 
#include <cstring>
#include <map>
#include <utility>
 
namespace caffe2 {
template <class Context>
class GatherRangesToDenseOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit GatherRangesToDenseOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        lengths_(this->template GetRepeatedArgument<int>("lengths")) {
    CAFFE_ENFORCE_GT(lengths_.size(), 0, "There has to be at least one length");
    for (auto length : lengths_) {
      CAFFE_ENFORCE_GT(length, 0, "Each length should be positive");
    }
  }
 
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, this->template Input<Tensor>(RANGES, CPU));
  }
 
  template <typename Index>
  bool DoRunWithType() {
    auto& data = Input(DATA);
    auto& ranges = Input(RANGES);
    CAFFE_ENFORCE_EQ(data.dim(), 1, "Data has to be 1-D");
    CAFFE_ENFORCE_EQ(ranges.dim(), 3, "Ranges has to be 3-D");
    if (InputSize() == 3) {
      auto& key = Input(KEY);
      CAFFE_ENFORCE_EQ(key.dim(), 1, "Key has to be 1-D");
      CAFFE_ENFORCE(
          key.dtype().template Match<int64_t>(), "Key has to be type int64_t");
    }
    CAFFE_ENFORCE_EQ(
        ranges.size(1),
        lengths_.size(),
        "Nummber of ranges should match number of lengths");
    CAFFE_ENFORCE_EQ(
        ranges.size(1),
        OutputSize(),
        "Nummber of ranges should match number of outputs");
    CAFFE_ENFORCE_EQ(
        ranges.size(2), 2, "Ranges last dimension should be of size 2");
 
    auto* rawData = static_cast<const char*>(data.raw_data());
    auto* rangesData = ranges.template data<Index>();
    int rangesDataOffset = 0;
    auto itemsize = data.dtype().itemsize();
 
    auto batchSize = ranges.size(0);
    vector<int64_t> outputDims{batchSize, 0};
    vector<char*> outputRawData;
    for (int i = 0; i < OutputSize(); ++i) {
      auto* output = Output(i);
      outputDims[1] = lengths_[i];
      output->Resize(outputDims);
      char* ptr = static_cast<char*>(output->raw_mutable_data(data.dtype()));
      memset(ptr, 0, output->nbytes());
      outputRawData.push_back(ptr);
    }
 
    for (int i = 0; i < batchSize; ++i) {
      for (int j = 0; j < OutputSize(); ++j) {
        auto rangeStart = rangesData[rangesDataOffset++];
        auto rangeLength = rangesData[rangesDataOffset++];
        if (rangeLength == 0) {
          // empty range, will be filled with zeros
          continue;
        }
        CAFFE_ENFORCE_EQ(
            rangeLength,
            lengths_[j],
            "Range lengths missmatch for output #",
            j);
 
        if (InputSize() == 2) {
          context_.CopyItemsSameDevice(
              data.dtype(),
              rangeLength,
              rawData + rangeStart * itemsize,
              outputRawData[j] + i * itemsize * lengths_[j]);
        } else {
          auto& key = Input(KEY);
          auto* key_data = key.template data<int64_t>();
          vector<std::pair<int64_t, const char*>> buffer;
          for (int b_i = 0; b_i < rangeLength; ++b_i) {
            int64_t one_key_item = key_data[rangeStart + b_i];
            auto* one_data_item = rawData + (rangeStart + b_i) * itemsize;
            buffer.emplace_back(one_key_item, one_data_item);
          }
          std::sort(
              buffer.begin(),
              buffer.end(),
              [](const std::pair<int64_t, const char*>& left,
                 const std::pair<int64_t, const char*>& right) {
                return left.first < right.first;
              });
          for (int b_i = 0; b_i < rangeLength; ++b_i) {
            // Since this CPU only, directly copy to the destination.
            std::memcpy(
                outputRawData[j] + (i * lengths_[j] + b_i) * itemsize,
                buffer[b_i].second,
                itemsize);
          }
        }
      }
    }
    CAFFE_ENFORCE_EQ(rangesDataOffset, ranges.numel());
 
    return true;
  }
 
  INPUT_TAGS(DATA, RANGES, KEY);
 
 private:
  vector<int> lengths_;
};
 
} // namespace caffe2
 
#endif // CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_