reid from https://github.com/michuanhaohao/reid-strong-baseline
zhangmeng
2020-01-17 f7c4a3cfd07adede3308f8d9d3d7315427d90a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#pragma once
 
#include <c10/util/ArrayRef.h>
#include <c10/util/sparse_bitset.h>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
 
#include <torch/csrc/WindowsTorchApiMacro.h>
 
// Uses a compressed index representation for faster comparisons
typedef c10::SparseBitVector<256> MemoryLocations;
namespace torch {
namespace jit {
 
struct Element;
struct Value;
 
// class MemoryDAG
//
// This class tracks the "A points to B" graph for all values. It is used by
// AliasDb to provide a higher-level API.
//
// We maintain a DAG where:
//   - Vertices (called "elements") represent values and
//     other aliasing entities (e.g. like the stuff inside a list)
//   - Edges represent a "points-to" relationship.
//
// Leaves in this DAG are entities that don't point to anything, and thus
// correspond to unique "memory locations".
//
// So, by traversing the "points-to" graph to the leaves, you can determine
// which memory locations an element may point to.
class TORCH_API MemoryDAG {
 public:
  // explicitly delete copy constructor because otherwise windows build is
  // confused for an exported class see
  // https://stackoverflow.com/a/51033485/105137
  MemoryDAG() {}
  MemoryDAG(const MemoryDAG&) = delete;
  MemoryDAG& operator=(const MemoryDAG&) = delete;
 
  // Make `from` point at `to`.
  void makePointerTo(Element* from, Element* to);
 
  void addToContainedElements(Element* contained, Element* container);
 
  // Make a fresh element (i.e. an element that doesn't point to anything) and
  // return it.
  Element* makeFreshValue(const Value* v);
 
  // Do `a` and `b` potentially share a memory location?
  bool mayAlias(const Element* a, const Element* b) const;
  bool mayAlias(Element* a, Element* b) const;
 
  // Does a hold reference to any memory that is stored in elem, or vice versa?
  bool mayContainAlias(const Element* a, const Element* b) const;
  bool mayContainAlias(Element* a, Element* b) const;
 
  bool mayContainAlias(
      const at::ArrayRef<Element*>& a,
      const at::ArrayRef<Element*>& b) const;
 
  // Converts from the compressed index representation
  const Element* fromIndex(unsigned x) const;
  Element* fromIndex(unsigned x);
 
 private:
  bool mayAliasImpl(const Element* a, const Element* b) const;
  bool mayContainAliasImpl(const Element* contained, const Element* container)
      const;
  void collectAllContainedMemoryLocations(
    const Element* elem, MemoryLocations& cont) const;
 
  std::vector<std::unique_ptr<Element>> indexToElementMap_;
};
 
// `Element` represents the vertex in the points-to graph. It represents
// anything that could have an aliasing relationship, mostly IR `Value`s, but
// also the "inside of a list", or wildcards.
struct Element {
  Element(MemoryDAG& dag_, const Value* value_, unsigned index_);
 
  // Reference to the owning DAG.
  MemoryDAG& dag;
  // Index into the owning DAG's bit vector that represents this element.
  unsigned index;
 
  // All elements that this element *may* point to. It's possible to have
  // multiple elements that you might point to due to control flow/complex ops
  MemoryLocations pointsTo;
  // Backreference for points-to.
  MemoryLocations pointedFrom;
 
  // Elements can contain other elements (e.g. List[Tensor])
  MemoryLocations containedElements;
 
  // Return the unique memory locations that `Element` might represent.
  TORCH_API const MemoryLocations& getMemoryLocations() const;
 
  // The value that this element corresponds to. May be null if this element
  // doesn't represent a first-class value.
  const Value* value = nullptr;
 
 private:
  // We do path compression to make repeated memory location queries faster.
  // An empty cache means it is invalidated (it can never be empty otherwise,
  // since every element must point to at least one memory location).
  mutable MemoryLocations cachedMemoryLocations_;
 
  enum class BfsDirection {
    POINTS_TO,
    POINTED_FROM,
  };
  // Do a breadth-first search over the graph, starting at `this` and
  // traversing in the direction `dir`.`fn` will be run on each element.
  void bfs(BfsDirection dir, MemoryLocations& res) const;
  friend class MemoryDAG;
};
 
} // namespace jit
} // namespace torch