reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-10 c3765bd24fe73747688a0ec2a550f219c9acb384
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#pragma once
 
#include <torch/data/example.h>
#include <torch/data/transforms/collate.h>
#include <torch/types.h>
 
#include <utility>
#include <vector>
 
namespace torch {
namespace data {
namespace transforms {
 
template <typename T = Example<>>
struct Stack;
 
/// A `Collation` for `Example<Tensor, Tensor>` types that stacks all data
/// tensors into one tensor, and all target (label) tensors into one tensor.
template <>
struct Stack<Example<>> : public Collation<Example<>> {
  Example<> apply_batch(std::vector<Example<>> examples) override {
    std::vector<torch::Tensor> data, targets;
    data.reserve(examples.size());
    targets.reserve(examples.size());
    for (auto& example : examples) {
      data.push_back(std::move(example.data));
      targets.push_back(std::move(example.target));
    }
    return {torch::stack(data), torch::stack(targets)};
  }
};
 
/// A `Collation` for `Example<Tensor, NoTarget>` types that stacks all data
/// tensors into one tensor.
template <>
struct Stack<TensorExample>
    : public Collation<Example<Tensor, example::NoTarget>> {
  TensorExample apply_batch(std::vector<TensorExample> examples) override {
    std::vector<torch::Tensor> data;
    data.reserve(examples.size());
    for (auto& example : examples) {
      data.push_back(std::move(example.data));
    }
    return torch::stack(data);
  }
};
} // namespace transforms
} // namespace data
} // namespace torch