reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-11 bdf3ad71583fb4ef100d3819ecdae8fd9f70083e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#pragma once
 
#include <torch/data/transforms/base.h>
 
#include <functional>
#include <utility>
#include <vector>
 
namespace torch {
namespace data {
namespace transforms {
 
/// A `BatchTransform` that applies a user-provided functor to a batch.
template <typename Input, typename Output = Input>
class BatchLambda : public BatchTransform<Input, Output> {
 public:
  using typename BatchTransform<Input, Output>::InputBatchType;
  using typename BatchTransform<Input, Output>::OutputBatchType;
  using FunctionType = std::function<OutputBatchType(InputBatchType)>;
 
  /// Constructs the `BatchLambda` from the given `function` object.
  explicit BatchLambda(FunctionType function)
      : function_(std::move(function)) {}
 
  /// Applies the user-provided function object to the `input_batch`.
  OutputBatchType apply_batch(InputBatchType input_batch) override {
    return function_(std::move(input_batch));
  }
 
 private:
  FunctionType function_;
};
 
// A `Transform` that applies a user-provided functor to individual examples.
template <typename Input, typename Output = Input>
class Lambda : public Transform<Input, Output> {
 public:
  using typename Transform<Input, Output>::InputType;
  using typename Transform<Input, Output>::OutputType;
  using FunctionType = std::function<Output(Input)>;
 
  /// Constructs the `Lambda` from the given `function` object.
  explicit Lambda(FunctionType function) : function_(std::move(function)) {}
 
  /// Applies the user-provided function object to the `input`.
  OutputType apply(InputType input) override {
    return function_(std::move(input));
  }
 
 private:
  FunctionType function_;
};
 
} // namespace transforms
} // namespace data
} // namespace torch