#pragma once #include #include #include #include #include #include #include // Uses a compressed index representation for faster comparisons typedef c10::SparseBitVector<256> MemoryLocations; namespace torch { namespace jit { struct Element; struct Value; // class MemoryDAG // // This class tracks the "A points to B" graph for all values. It is used by // AliasDb to provide a higher-level API. // // We maintain a DAG where: // - Vertices (called "elements") represent values and // other aliasing entities (e.g. like the stuff inside a list) // - Edges represent a "points-to" relationship. // // Leaves in this DAG are entities that don't point to anything, and thus // correspond to unique "memory locations". // // So, by traversing the "points-to" graph to the leaves, you can determine // which memory locations an element may point to. class TORCH_API MemoryDAG { public: // explicitly delete copy constructor because otherwise windows build is // confused for an exported class see // https://stackoverflow.com/a/51033485/105137 MemoryDAG() {} MemoryDAG(const MemoryDAG&) = delete; MemoryDAG& operator=(const MemoryDAG&) = delete; // Make `from` point at `to`. void makePointerTo(Element* from, Element* to); void addToContainedElements(Element* contained, Element* container); // Make a fresh element (i.e. an element that doesn't point to anything) and // return it. Element* makeFreshValue(const Value* v); // Do `a` and `b` potentially share a memory location? bool mayAlias(const Element* a, const Element* b) const; bool mayAlias(Element* a, Element* b) const; // Does a hold reference to any memory that is stored in elem, or vice versa? bool mayContainAlias(const Element* a, const Element* b) const; bool mayContainAlias(Element* a, Element* b) const; bool mayContainAlias( const at::ArrayRef& a, const at::ArrayRef& b) const; // Converts from the compressed index representation const Element* fromIndex(unsigned x) const; Element* fromIndex(unsigned x); private: bool mayAliasImpl(const Element* a, const Element* b) const; bool mayContainAliasImpl(const Element* contained, const Element* container) const; void collectAllContainedMemoryLocations( const Element* elem, MemoryLocations& cont) const; std::vector> indexToElementMap_; }; // `Element` represents the vertex in the points-to graph. It represents // anything that could have an aliasing relationship, mostly IR `Value`s, but // also the "inside of a list", or wildcards. struct Element { Element(MemoryDAG& dag_, const Value* value_, unsigned index_); // Reference to the owning DAG. MemoryDAG& dag; // Index into the owning DAG's bit vector that represents this element. unsigned index; // All elements that this element *may* point to. It's possible to have // multiple elements that you might point to due to control flow/complex ops MemoryLocations pointsTo; // Backreference for points-to. MemoryLocations pointedFrom; // Elements can contain other elements (e.g. List[Tensor]) MemoryLocations containedElements; // Return the unique memory locations that `Element` might represent. TORCH_API const MemoryLocations& getMemoryLocations() const; // The value that this element corresponds to. May be null if this element // doesn't represent a first-class value. const Value* value = nullptr; private: // We do path compression to make repeated memory location queries faster. // An empty cache means it is invalidated (it can never be empty otherwise, // since every element must point to at least one memory location). mutable MemoryLocations cachedMemoryLocations_; enum class BfsDirection { POINTS_TO, POINTED_FROM, }; // Do a breadth-first search over the graph, starting at `this` and // traversing in the direction `dir`.`fn` will be run on each element. void bfs(BfsDirection dir, MemoryLocations& res) const; friend class MemoryDAG; }; } // namespace jit } // namespace torch