reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#pragma once
 
#include <torch/csrc/jit/operator.h>
#include <ATen/core/stack.h>
#include <ATen/core/op_registration/op_registration.h>
 
namespace torch {
namespace jit {
 
/// Registration class for new operators. Effectively calls
/// `torch::jit::registerOperator` for every supplied operator, but allows doing
/// so in the global scope when a `RegisterOperators` object is assigned to a
/// static variable. Also handles registration of user-defined, "custom"
/// operators.
struct TORCH_API RegisterOperators {
  RegisterOperators() = default;
 
  /// Registers a vector of already created `Operator`s.
  RegisterOperators(std::vector<Operator> operators) {
    for (Operator& o : operators) {
      registerOperator(std::move(o));
    }
  }
 
  /// Calls `op(...)` with the given operator name and implementation.
  template <typename Implementation>
  C10_DEPRECATED_MESSAGE("torch::jit::RegisterOperators is deprecated. Please use torch::RegisterOperators instead.")
  RegisterOperators(const std::string& name, Implementation&& implementation) {
    op_(name, std::forward<Implementation>(implementation));
  }
 
  template <typename Implementation>
  C10_DEPRECATED_MESSAGE("torch::jit::RegisterOperators is deprecated. Please use torch::RegisterOperators instead.")
  RegisterOperators& op(
      const std::string& name,
      Implementation&& implementation) {
    op_(name, std::forward<Implementation>(implementation));
 
    return *this;
  }
 
private:
 
  template <typename Implementation>
  void op_(const std::string& name, Implementation&& implementation) {
    registrars_.emplace_back(std::make_shared<c10::RegisterOperators>(name, std::forward<Implementation>(implementation)));
  }
 
  // A c10::RegisterOperators instance is not copyable, so to make
  // torch::jit::RegisterOperators copyable, we use shared_ptrs.
  // We need to keep the c10::RegisterOperators instances around
  // because this is an RAII pattern. In the destructor, the registered
  // ops get de-registered.
  std::vector<std::shared_ptr<c10::RegisterOperators>> registrars_;
};
 
} // namespace jit
 
} // namespace torch