#pragma once #include #include #include #include namespace torch { namespace jit { /** * Alias analysis pass. * * This pass produces an AliasDb that contains aliasing and mutation * information about the graph. Users can use this information to determine * whether mutations to the graph are safe, i.e. they don't reorder/change * nodes in a way that affects output. * * Every value with a mutable type (Tensors, Lists, Tuples, etc.) will be * associated with one or more "alias sets". If two values share an alias set, * that means they may alias, implying that a mutation to one value cannot be * reordered past a use of the other. Only reordering two reads of an alias set * is considered safe. * * There is a special alias set called the "wildcard set", which indicates that * we're not sure what this value may alias. To be conservative, we consider * the wildcard alias set as potentially aliasing any value. */ class AliasDb { public: TORCH_API explicit AliasDb(std::shared_ptr graph); TORCH_API ~AliasDb(); // There are limitations to what effects the alias analysis can track. Two // kinds of nodes may have untracked effects: // 1. Nodes that write to a value that may alias the graph inputs (since // the inputs can be used outside the graph). // 2. Nodes that write to something in the wildcard set. // // These nodes are considered not safe to eliminate or mutate under any // circumstances. bool writesToWildcard(Node* n) const; // Does `n` write to an alias of one of the values in `vs`? // if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks TORCH_API bool writesToAlias(Node* n, const ValueSet& vs) const; // Does `a` and `b` potentially share a memory location or do either // hold in memory any element that exists in the other TORCH_API bool mayContainAlias(Value* a, Value* b) const; // Do any values in group `a` share a memory location or hold in memory // any element that exists in group `b` TORCH_API bool mayContainAlias( const at::ArrayRef& a, const at::ArrayRef& b) const; // Do `a` and `b` potentially share a memory location? TORCH_API bool mayAlias(const Value* a, const Value* b) const; // Do any values in group `a` potentially share a memory location with any // value in group `b`? i.e. may they overlap? TORCH_API bool mayAlias(const ValueSet& a, const ValueSet& b) const; // Do any nodes write to an alias set input to `n`? TORCH_API bool hasInputWriters(const Node* n) const; // Do any nodes write to an alias set output by `n`? TORCH_API bool hasOutputWriters(const Node* n) const; // Do any nodes write to an alias set inputed/outputed by `n`? TORCH_API bool hasWriters(const Node* n) const; // Is the operation in-place? i.e. doesn't write anywhere but locations it // reads from. TORCH_API bool isMutable(Node* n) const; // Move 'n' (already in the graph) after 'movePoint' in the topological order. // // Tries to preserve value dependencies, so other nodes might be moved. We // make two gurantees about the postcondition of the node list: // - `n` is directly after `movePoint`. // - only nodes between `n` and `movePoint` have been moved. // // Returns `false` if it's impossible to move `n` after `MovePoint` without // violating dependencies, otherwise executes the move and returns `true` TORCH_API bool moveAfterTopologicallyValid(Node* n, Node* movePoint); TORCH_API bool moveBeforeTopologicallyValid(Node* n, Node* movePoint); bool couldMoveAfterTopologically(Node* n, Node* movePoint); bool couldMoveBeforeTopologically(Node* n, Node* movePoint); // For debugging: print alias db state to stdout TORCH_API void dump() const; TORCH_API std::string toString() const; private: // Helper for topologically-safe node moves. class WorkingSet; enum class MoveSide { BEFORE, AFTER }; bool tryMove(Node* toMove, Node* movePoint, MoveSide moveSide, bool dryRun); void move(Node* toMove, Node* movePoint, MoveSide moveSide); bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const; /** * Write and read internal API */ // Get all the values that `n` writes to. // NOTE: this only returns values directly written to, not aliases thereof // // if `recurseBlocks` is true, gather writes on the nodes in `n`s sub-blocks MemoryLocations getWrites(Node* n) const; void getWritesImpl(Node* n, MemoryLocations& ret) const; // Do any nodes write to `v`s memory location? TORCH_API bool hasWriters(const Value* v) const; // Register the fact that `n` writes to `v`. void registerWrite(const Value* v, Node* n); void registerWrite(const Element* e, Node* n); // Get all the values that `n` reads from. // if `recurseBlocks` is true, gather reads on the nodes in `n`s sub-blocks MemoryLocations getReads(Node* n) const; void getReadsImpl(Node* n, MemoryLocations& ret) const; /** * Wildcard methods */ // Register `v` as a wildcard value. void setWildcard(const Value* v); // Is the element a wildcard or an unhandled container type, // or does the element contain an element for which that's true bool cannotCheckAliasContainment(const Value* elem) const; // Is this a value which will not alias bool nonAliasingValue(const Value* elem) const; /** * Special analysis methods */ void analyze(const std::shared_ptr& graph); void analyze(Block* block); void analyze(Node* node); void analyzeImpl(Node* node); void analyzeIf(Node* node); void analyzeLoop(Node* node); void analyzeSubgraph(Node* node); void analyzeCreator(Node* node); void analyzeExtractor(Node* node); void analyzeChunk(Node* node); void analyzeBroadcastingChunk(Node* node); void analyzeFork(Node* node); void analyzeWait(Node* node); void analyzeGradOf(Node* node); void analyzeSetAttr(Node* node); void analyzeTupleConstruct(Node* node); void analyzeConservative(Node* node); void analyzeContainerConstruct(Node* node); bool tryRegisteredAnalysis(Node* node); /** * Alias manipulation methods */ void makeAllAlias(const std::vector& values); void makePointerTo(const Value* value, const Value* to); TORCH_API void addToContainedElements( const Value* element, const Value* container); void mapAliases(at::ArrayRef to, at::ArrayRef from); void giveFreshAlias(const Value* value); Element* getOrCreateElement(const Value* value); static bool shouldAnnotate(const Value* v); static bool shouldAnnotate(const TypePtr& type); static c10::optional getMutableTypeKind(const TypePtr& type); static bool isContainerType(const TypePtr& type); std::shared_ptr graph_; // The points-to graph that stores aliasing relationships std::unique_ptr memoryDAG_; // Mapping of values to MemoryDAG elements ska::flat_hash_map elementMap_; // All wildcard elements (one for each unique mutable type). std::map wildcardIndex_; Element* getWildcard(const TypePtr& type) const; Element* getOrCreateWildcard(const TypePtr& type); bool mayAliasWildcard(const Value* v) const; /** * State for tracking write info. */ // Map of nodes to the memory locations that they write to ska::flat_hash_map writeIndex_; // Set of all memory locations that may have been written to. mutable MemoryLocations writeCache_; mutable bool isWriteCacheStale_ = true; void rebuildWriteCache() const; std::string getElementName(const Element* e) const; }; // Used to assert that unschematized operators have an analysis method written TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym); } // namespace jit } // namespace torch