#pragma once #include #include #include #include namespace c10 { /** * class AliasInfo * * Data structure to hold aliasing information for an `Argument`. They can be * nested to represent aliasing information on contained types. * * There is a `beforeSet` which describes the aliasing information before the * operator executes, and an `afterSet` that describes aliasing info * after execution. */ class AliasInfo { public: // Symbol for the set that can alias anything static Symbol wildcardSet() { static const Symbol wc = Symbol::fromQualString("alias::*"); return wc; } void setIsWrite(bool isWrite) { isWrite_ = isWrite; } bool isWrite() const { return isWrite_; } void addBeforeSet(Symbol aliasSet) { beforeSets_.insert(aliasSet); } void addAfterSet(Symbol aliasSet) { afterSets_.insert(aliasSet); } const std::unordered_set& beforeSets() const { return beforeSets_; } const std::unordered_set& afterSets() const { return afterSets_; } Symbol beforeSet() const { AT_ASSERT(beforeSets_.size() == 1); return *beforeSets_.begin(); } bool isWildcardBefore() const { return beforeSets_.count(wildcardSet()) != 0; } bool isWildcardAfter() const { return afterSets_.count(wildcardSet()) != 0; } // the alias info for the contained types of the type // e.g. if this is an annotation on List[T], `sets` refers to // the alias sets that the list may be in // while containedTypes()[0] refers to the sets that members of the list // may be in void addContainedType(AliasInfo aliasInfo) { containedTypes_.push_back(std::move(aliasInfo)); } const std::vector& containedTypes() const { return containedTypes_; } private: std::unordered_set beforeSets_; std::unordered_set afterSets_; std::vector containedTypes_; bool isWrite_ = false; }; inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) { return lhs.isWrite() == rhs.isWrite() && lhs.beforeSets() == rhs.beforeSets() && lhs.afterSets() == rhs.afterSets() && lhs.containedTypes() == rhs.containedTypes(); } // this does match the way things are represented in the schema inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) { out << "("; bool first = true; for (const auto& set : aliasInfo.beforeSets()) { if (first) { first = false; } else { out << "|"; } out << set.toUnqualString(); } if (aliasInfo.isWrite()) { out << "!"; } if (aliasInfo.beforeSets() != aliasInfo.afterSets()) { out << " -> "; first = true; for (const auto& set : aliasInfo.afterSets()) { if (first) { first = false; } else { out << "|"; } out << set.toUnqualString(); } } out << ")"; return out; } } // namespace c10