reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-11 bdf3ad71583fb4ef100d3819ecdae8fd9f70083e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#pragma once
 
#include <torch/data/example.h>
#include <torch/data/transforms/base.h>
#include <torch/types.h>
 
#include <functional>
#include <utility>
 
namespace torch {
namespace data {
namespace transforms {
 
/// A `Transform` that is specialized for the typical `Example<Tensor, Tensor>`
/// combination. It exposes a single `operator()` interface hook (for
/// subclasses), and calls this function on input `Example` objects.
template <typename Target = Tensor>
class TensorTransform
    : public Transform<Example<Tensor, Target>, Example<Tensor, Target>> {
 public:
  using E = Example<Tensor, Target>;
  using typename Transform<E, E>::InputType;
  using typename Transform<E, E>::OutputType;
 
  /// Transforms a single input tensor to an output tensor.
  virtual Tensor operator()(Tensor input) = 0;
 
  /// Implementation of `Transform::apply` that calls `operator()`.
  OutputType apply(InputType input) override {
    input.data = (*this)(std::move(input.data));
    return input;
  }
};
 
/// A `Lambda` specialized for the typical `Example<Tensor, Tensor>` input type.
template <typename Target = Tensor>
class TensorLambda : public TensorTransform<Target> {
 public:
  using FunctionType = std::function<Tensor(Tensor)>;
 
  /// Creates a `TensorLambda` from the given `function`.
  explicit TensorLambda(FunctionType function)
      : function_(std::move(function)) {}
 
  /// Applies the user-provided functor to the input tensor.
  Tensor operator()(Tensor input) override {
    return function_(std::move(input));
  }
 
 private:
  FunctionType function_;
};
 
/// Normalizes input tensors by subtracting the supplied mean and dividing by
/// the given standard deviation.
template <typename Target = Tensor>
struct Normalize : public TensorTransform<Target> {
  /// Constructs a `Normalize` transform. The mean and standard deviation can be
  /// anything that is broadcastable over the input tensors (like single
  /// scalars).
  Normalize(ArrayRef<double> mean, ArrayRef<double> stddev)
      : mean(torch::tensor(mean, torch::kFloat32)
                 .unsqueeze(/*dim=*/1)
                 .unsqueeze(/*dim=*/2)),
        stddev(torch::tensor(stddev, torch::kFloat32)
                   .unsqueeze(/*dim=*/1)
                   .unsqueeze(/*dim=*/2)) {}
 
  torch::Tensor operator()(Tensor input) {
    return input.sub(mean).div(stddev);
  }
 
  torch::Tensor mean, stddev;
};
} // namespace transforms
} // namespace data
} // namespace torch