reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 1bb4d137919cae4f57f95a2572ee612dcabb3b3d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#pragma once
 
#include <ATen/core/ivalue.h>
#include <c10/util/SmallVector.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
 
namespace torch { namespace autograd {
 
struct Node;
 
namespace profiler {
 
struct TORCH_API StringView {
  StringView() : StringView(nullptr) {}
  explicit StringView(const char* str_ptr)
    : owned_str_ptr_(nullptr), str_ptr_(str_ptr) {}
  explicit StringView(std::string str)
    : owned_str_ptr_(std::make_shared<std::string>(std::move(str))),
      str_ptr_(owned_str_ptr_->c_str()) {}
 
  inline const char* str() const {
    return str_ptr_;
  }
 private:
  std::shared_ptr<std::string> owned_str_ptr_;
  const char* str_ptr_;
};
 
struct TORCH_API RecordFunction {
  // Default constructor is used with before function called afterwards
  RecordFunction() {}
 
  // before function initializes RecordFunction members and calls
  // start callbacks
  void before(const char* name, int64_t sequence_nr = -1);
  void before(std::string name, int64_t sequence_nr = -1);
  void before(Node* fn, int64_t sequence_nr = -1);
 
  template<typename F>
  void before(
      F fn,
      c10::ArrayRef<c10::IValue> args,
      int64_t current_sequence_nr = -1) {
    inputs_ = args.vec();
    before(fn, current_sequence_nr);
  }
 
  template<typename F>
  void before(
      F fn,
      std::vector<c10::IValue>&& args,
      int64_t current_sequence_nr = -1) {
    inputs_ = std::move(args);
    before(fn, current_sequence_nr);
  }
 
  // Destructor calls end callbacks
  virtual ~RecordFunction();
 
  inline Node* func() const {
    return fn_;
  }
 
  inline const StringView& name() const {
    return name_;
  }
 
  inline int64_t seqNr() const {
    return sequence_nr_;
  }
 
  const std::vector<c10::IValue>& inputs() const {
    return inputs_;
  }
 
  inline const RecordFunction* parent() const {
    return parent_;
  }
 
  void setRunSampled(bool run_sampled) {
    run_sampled_ = run_sampled;
  }
 
 private:
  void processCallbacks();
 
  Node* fn_ = nullptr;
  StringView name_;
  int64_t sequence_nr_ = -1;
  std::vector<c10::IValue> inputs_;
  RecordFunction* parent_ = nullptr;
 
  bool initialized_ = false;
  bool run_sampled_ = false;
};
 
TORCH_API bool hasCallbacks();
TORCH_API bool needsInputs();
TORCH_API bool hasNonSampledCallbacks();
 
TORCH_API void setSamplingProbability(double);
TORCH_API double getSamplingProbability();
 
TORCH_API bool shouldRunSampledCallbacks();
 
// optional argument - function's seq_no
#define RECORD_FUNCTION(fn, inputs, ...) \
  torch::autograd::profiler::RecordFunction guard; \
  if (torch::autograd::profiler::hasCallbacks()) { \
    auto run_sampled = torch::autograd::profiler::shouldRunSampledCallbacks(); \
    if (run_sampled || torch::autograd::profiler::hasNonSampledCallbacks()) { \
      guard.setRunSampled(run_sampled); \
      if (torch::autograd::profiler::needsInputs()) { \
        guard.before(fn, inputs, ##__VA_ARGS__); \
      } else { \
        guard.before(fn, ##__VA_ARGS__); \
      } \
    } \
  }
 
// WARNING: all calls to pushCallback/popCallback are not thread safe and
// must not overlap with other code execution
using RecordFunctionCallback = std::function<void(const RecordFunction&)>;
TORCH_API void pushCallback(
    RecordFunctionCallback start,
    RecordFunctionCallback end = [](const RecordFunction&){},
    bool needs_inputs = false,
    bool sampled = false);
TORCH_API void popCallback();
 
} // namespace profiler
}} // namespace torch::autograd