reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#ifndef CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
#define CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
 
#include <set>
#include <vector>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
 
namespace caffe2 {
 
template <class Context>
class MergeIdListsOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(MergeIdListsOp);
 
  template <typename T>
  bool DoRunWithType() {
    auto& first_lengths = Input(0);
    CAFFE_ENFORCE_EQ(first_lengths.dim(), 1, "LENGTHS should be 1-D");
    const auto batch_size = first_lengths.numel();
 
    auto* out_lengths = Output(0, first_lengths.sizes(), at::dtype<int32_t>());
 
    auto* out_lengths_data = out_lengths->template mutable_data<int32_t>();
 
    /**
     * Loop to figure out how much space to reserve for output
     * and perform checks.
     */
    auto M = 0;
    for (size_t i = 0; i < InputSize(); i += 2) {
      auto& lengths = Input(i);
      CAFFE_ENFORCE_EQ(lengths.dim(), 1, "LENGTHS should be 1-D");
      CAFFE_ENFORCE_EQ(lengths.numel(), batch_size, "LENGTHS should be equal");
      auto& values = Input(i + 1);
      CAFFE_ENFORCE_EQ(values.dim(), 1, "VALUES should be 1-D");
      M += values.numel();
    }
 
    auto* out_values = Output(1, {M}, at::dtype<T>());
 
    T* out_values_data = out_values->template mutable_data<T>();
    auto pos = 0;
 
    // TODO(badri): Use unordered_set if performance is an issue
    std::set<T> deduped;
    std::vector<int> offsets(InputSize(), 0);
    for (auto sample = 0; sample < batch_size; sample++) {
      for (size_t i = 0; i < InputSize(); i += 2) {
        auto& lengths = Input(i);
        const auto* lengths_data = lengths.template data<int32_t>();
 
        auto& values = Input(i + 1);
        const T* values_data = values.template data<T>();
        const auto length = lengths_data[sample];
 
        for (auto j = offsets[i]; j < offsets[i] + length; j++) {
          deduped.insert(values_data[j]);
        }
        offsets[i] += length;
      }
      for (auto val : deduped) {
        out_values_data[pos++] = val;
      }
      out_lengths_data[sample] = deduped.size();
      deduped.clear();
    }
    out_values->Resize(pos);
    return true;
  }
 
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(1));
  }
};
 
} // namespace caffe2
 
#endif // CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_