reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-11 bdf3ad71583fb4ef100d3819ecdae8fd9f70083e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#pragma once
 
#include <torch/types.h>
 
#include <utility>
#include <vector>
 
namespace torch {
namespace data {
namespace transforms {
 
/// A transformation of a batch to a new batch.
template <typename InputBatch, typename OutputBatch>
class BatchTransform {
 public:
  using InputBatchType = InputBatch;
  using OutputBatchType = OutputBatch;
 
  virtual ~BatchTransform() = default;
 
  /// Applies the transformation to the given `input_batch`.
  virtual OutputBatch apply_batch(InputBatch input_batch) = 0;
};
 
/// A transformation of individual input examples to individual output examples.
///
/// Just like a `Dataset` is a `BatchDataset`, a `Transform` is a
/// `BatchTransform` that can operate on the level of individual examples rather
/// than entire batches. The batch-level transform is implemented (by default)
/// in terms of the example-level transform, though this can be customized.
template <typename Input, typename Output>
class Transform
    : public BatchTransform<std::vector<Input>, std::vector<Output>> {
 public:
  using InputType = Input;
  using OutputType = Output;
 
  /// Applies the transformation to the given `input`.
  virtual OutputType apply(InputType input) = 0;
 
  /// Applies the `transformation` over the entire `input_batch`.
  std::vector<Output> apply_batch(std::vector<Input> input_batch) override {
    std::vector<Output> output_batch;
    output_batch.reserve(input_batch.size());
    for (auto&& input : input_batch) {
      output_batch.push_back(apply(std::move(input)));
    }
    return output_batch;
  }
};
} // namespace transforms
} // namespace data
} // namespace torch