reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/** This file defines API for pattern-based subgraph rewrites.
 *
 * The API can be used for finding concrete patterns in the model and replacing
 * the corresponding subgraphs with another subgraph. A special case of such
 * rewrites is fusion, where the new subgraph consists of just a single node.
 *
 * There is a default set of most-common patterns that everyone could use, or
 * alternatively an arbitrary pattern can be registered.
 */
#pragma once
 
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/script/module.h>
 
#include <functional>
#include <unordered_set>
#include <vector>
 
namespace torch {
namespace jit {
 
// Forward declarations.
struct RewritePatternDescr;
struct Match;
 
/** Run pattern-based subgraph rewrites on all methods in the module.
 *
 * This pass will go through all methods in the module and try to replace all
 * recognized patterns (see SubgraphRewriter::RegisterDefaultPatterns for the
 * list of these patterns).
 */
TORCH_API script::Module PatternBasedRewrite(const script::Module& module);
 
/** A class implementing API for pattern-based subgraph rewrites.
 *
 * To perform pattern-based subgraph rewrites on a module using this API, one
 * needs to crete an object of such class, register rewrite patterns and run the
 * transformation pass (`runOnModule`).
 *
 * To use standard patterns, one could use `RegisterDefaultPatterns`.
 *
 * To enable rewrites of custom patterns, they must be registered with
 * `RegisterRewritePattern`.
 */
class TORCH_API SubgraphRewriter {
 public:
  // Run pattern-based subgraph rewrite pass on the module.
  script::Module runOnModule(const script::Module& module);
 
  // Run pattern-based subgraph rewrite pass on the graph (used in testing).
  // filter is a function that does extra filtering on the match, if it returns
  // false for a given Match, we'll skip the match
  // filter function takes a `Match` and a value map from parsing the pattern graph
  // since we need to do extra filtering on the matched result but we need to refer
  // to the values in the matched result through the values in pattern graph.
  void runOnGraph(
      std::shared_ptr<Graph>& graph,
      const std::function<
          bool(const Match&, const std::unordered_map<std::string, Value*>&)>&
          filter =
              [](const Match&, const std::unordered_map<std::string, Value*>&) {
                return true;
              });
 
  // Register standard rewrite patterns.
  void RegisterDefaultPatterns();
 
  /** Register a custom rewrite pattern.
   *
   * The method takes two parameters specifying the pattern:
   * \p PATTERN - IR string representing the pattern subgraph.
   * \p REPLACEMENT - IR stringn representing the replacement subgraph.
   *
   * See examples of pattern registering in `RegisterDefaultPatterns`.
   */
  void RegisterRewritePattern(
      const std::string& pattern,
      const std::string& replacement);
 
 private:
  std::vector<RewritePatternDescr> patterns_;
  std::unordered_set<Node*> nodes_to_delete_;
 
  void rewriteSinglePatternOnGraph(
      std::shared_ptr<Graph>& graph,
      const RewritePatternDescr& pattern,
      const std::function<
          bool(const Match&, const std::unordered_map<std::string, Value*>&)>&
          filter =
              [](const Match&, const std::unordered_map<std::string, Value*>&) {
                return true;
              });
  bool overlapsWithPreviousMatches(const Match* match);
};
 
/** Rewrite pattern descriptor.
 *
 * This structure is used in implementation of `SubgraphRewriter` and not
 * supposed to be used externally.
 */
struct RewritePatternDescr {
  std::string pattern;
  std::string replacement;
};
 
} // namespace jit
} // namespace torch