#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include // Forward declare, the real meat is in python_ir.cpp template class THPPointer; using THPObjectPtr = THPPointer; using pyobj_list = std::vector; namespace torch { namespace jit { using ::c10::Argument; using ::c10::FunctionSchema; using ::c10::Symbol; using ::c10::ivalue::Shared; using ::c10::IValue; using ::c10::ivalue::Future; using ::c10::ivalue::ConstantString; #define C10_USING(T) using ::c10::T; C10_FORALL_TYPES(C10_USING) #undef C10_USING #define C10_USING(T) using ::c10::T##Ptr; C10_FORALL_TYPES(C10_USING) #undef C10_USING using ::c10::Type; using ::c10::TypeEnv; using ::c10::TypePtr; using ::c10::getTypePtr; using ::c10::MatchTypeReturn; using ::c10::TypeKind; using ::c10::fmap; namespace prim { using namespace ::c10::prim; } namespace attr { using namespace ::c10::attr; } namespace aten { using namespace ::c10::aten; } struct Function; namespace script { struct MatchedSchema; } // namespace script // Graph represents one "function" of computation. // It uses a simple ownership model where the graph owns all the nodes inside // it. All references inside the graph are raw pointers. Destroying the Graph // will invalidate any pointers to nodes in the graph. struct Graph; // Node is the base class of the IR graph. It represents one computation // and dependencies on a list of Values. The "prim-ops", so to speak. struct Node; // A Value represents an input or output to node that is either a // Tensor or an opaque Handle object, as determined by type(). struct Value; TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g); TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n); // A list of nodes, with inputs and outputs struct Block; // Each use is represented by this type, see Node::uses() // 'user' is the consumer of the value, offset is the index into // 'user's input this where the produces will be found. struct Use { Use(Node* user, size_t offset) : user(user), offset(offset) {} Node* user; size_t offset; bool operator==(const Use& b) { return user == b.user && offset == b.offset; } }; // Note [User node does not uniquely identify use] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // A while back, we wrote some code manipulating uses that looked like this: // // for (auto& use : used_val->uses_) { // if (use.user == this_node) { // use.offset += 1; // break; // } // } // // This code is trying to find a particular use (our node's use) to update it. // However, it's wrong: there may be *multiple* uses of a value %x in a node, // as might be the case in this IR: // // %y = Add %x %x // // In this case, there are two uses of %x whose user is the node 'Add %x %x'. // So, "use induced by this node" is not a well-formed concept. // // If you are looking for "use induced by an input", it's best to use // findUseForInput() to get it. // the list types are intentionally simple, but we type-def // them here so if we need to change them, refactoring will be easier using node_list = std::vector; using value_list = std::vector; using use_list = std::vector; template using ArrayRef = at::ArrayRef; using NodeKind = Symbol; using topo_position_t = int64_t; using ValueSet = std::unordered_set; struct Value { TH_DISALLOW_COPY_AND_ASSIGN(Value); Value(Node* node_, size_t offset_); private: friend struct Node; friend struct Graph; Node* node_; size_t offset_; size_t unique_ = 0; // unique id use_list uses_; std::string unique_name_; TypePtr type_; public: Value* setType(TypePtr type); TORCH_API void inferTypeFrom(const at::Tensor& output); const TypePtr& type() const { AT_ASSERT(type_ != nullptr); return type_; } bool requires_grad() const { return type()->requires_grad(); } bool isCompleteTensor() const { if (auto pt = type()->cast()) { return pt->isComplete(); } return false; } TORCH_API bool mustBeNone() const; TORCH_API bool mustNotBeNone() const; size_t unique() const { return unique_; } bool hasDebugName() const { return !unique_name_.empty(); } static bool isValidName(const std::string& name); TORCH_API Value* setDebugName(const std::string& name); std::string debugName() const { if (hasDebugName()) { return unique_name_; } return std::to_string(unique()); } TORCH_API std::string debugNameBase() const; Node* node() { return node_; } size_t offset() const { return offset_; } void setOffset(size_t offset) { offset_ = offset; } const Node* node() const { return node_; } Graph* owningGraph(); const Graph* owningGraph() const; // TODO: make this more const correct const use_list& uses() const { return uses_; } bool hasUses() const { return !uses().empty(); } TORCH_API void replaceFirstUseWith(Value* newValue); // Replaces all uses of this value with 'newValue'. // // Given: %3 = f(%1, %2) // %4 = g(%3) // %5 = h(%3, %3) // Execute: %3.replaceAllUsesWith(%6) // Result: %3 = f(%1, %2) // %4 = g(%6) // %5 = h(%6, %6) TORCH_API void replaceAllUsesWith(Value* newValue); TORCH_API Value* copyMetadata(Value* from); }; struct TORCH_API Node { TH_DISALLOW_COPY_AND_ASSIGN(Node); friend struct Graph; friend struct Block; friend struct Value; friend graph_node_list; friend const_graph_node_list; friend graph_node_list_iterator; friend const_graph_node_list_iterator; private: const NodeKind kind_; std::vector inputs_; std::vector outputs_; // subblocks std::vector blocks_; Graph* graph_; Block* owning_block_; c10::optional source_range_; ScopePtr scope_; // Assumes FunctionSchemas are persistent, so we don't manage their lifetime. // This field is effective a cache that's populated on attribute lookups and // invalidated every time we perform an operation that could potentially // change the schema. note: mutable because schema_ is effectively a cache mutable const Operator* op_; topo_position_t topo_position_ = 0; protected: Node(Graph* graph_, NodeKind kind_); // defined after graph public: // each node but Return/Param // is associated with exactly one place in the node list... // of the graph_ // this circular is a doubly-linked list, the Return node is used as the // sentinel for the beginning and end of the list such that the list never has // null pointers next_in_graph[0] is next pointer next_in_graph[1] is prev // pointer using an array to allow the same iterator class for forward and // reverse node lists This list represents a topological sort Node* next_in_graph[2] = {nullptr, nullptr}; Node*& next() { return next_in_graph[kNextDirection]; } Node*& prev() { return next_in_graph[kPrevDirection]; } Node* const& next() const { return next_in_graph[kNextDirection]; } Node* const& prev() const { return next_in_graph[kPrevDirection]; } NodeKind kind() const { return kind_; } Node* setSourceRange(SourceRange r) { source_range_ = std::move(r); return this; } SourceRange sourceRange() const; Graph* owningGraph() { return graph_; } const Graph* owningGraph() const { return graph_; } Block* owningBlock() { return owning_block_; } const Block* owningBlock() const { return owning_block_; } ScopePtr scope() { return scope_; } void setScope(ScopePtr scope) { scope_ = std::move(scope); } std::string scopeName() const { if (!scope_) { return ""; } return scope_->namesFromRoot(); } // NB: This returns an ArrayRef; that means that it will // get invalidated if you resize inputs (e.g., using addInput) // We can't return a std::vector& because there's no // way to soundly cast to std::vector (an insane // implementation of std::vector could make this representationally // different.) at::ArrayRef inputs() { return inputs_; } at::ArrayRef inputs() const { // Vectors are not convertible in const-ness of elements, but // raw pointers are. return {inputs_.data(), inputs_.size()}; } // NB: This returns an ArrayRef; that means that it will // get invalidated if you resize inputs (e.g., using addInput) // We can't return a std::vector& because there's no // way to soundly cast to std::vector (an insane // implementation of std::vector could make this representationally // different.) at::ArrayRef outputs() { return outputs_; } at::ArrayRef outputs() const { // Vectors are not convertible in const-ness of elements, but // raw pointers are. return {outputs_.data(), outputs_.size()}; } Value* output(size_t i) const { return outputs_.at(i); } bool hasUses() const { for (auto o : outputs()) { if (!o->uses().empty()) { return true; } } return false; } void replaceAllUsesWith(Node* n); // lots of things like chunk have a single input or single output, so we have // a helper to make accessing it easier Value* input() { AT_ASSERT(inputs_.size() == 1); return inputs_.at(0); } Value* output() { AT_ASSERT(outputs_.size() == 1); return outputs_.at(0); } const Value* output() const { AT_ASSERT(outputs_.size() == 1); return outputs_.at(0); } const Value* input() const { AT_ASSERT(inputs_.size() == 1); return inputs_.at(0); } // Access a particular input. This is a checked index. Value* input(size_t i) const { return inputs_.at(i); } Value* namedInput(Symbol name) const; c10::optional get(Symbol name) const; template c10::optional get(Symbol name) const { if (auto v = get(name)) { return v->template to(); } return c10::nullopt; } // Returns true if the value of input name is statically known bool is_constant(Symbol name) const { return static_cast(get(name)); } bool mustBeNone() const; bool isNondeterministic() const; bool hasSideEffects() const; // Graphs // Note [Topological invariant] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // We always maintain an up-to-date topological ordering of all nodes via // the next()/prev() links. All transformations to graphs must preserve // this topological ordering: for example, it is only valid to 'addInput' // with an input which is topologically before the current node. // // Usually, it is obvious whether or not topological order is maintained; // for example, if you are adding nodes to the end of the topsort, it's // impossible for them to refer to inputs that are not in the topsort. // If it is not obvious, please comment accordingly. // Add 'node' as an input to 'this' at the end of existing // arguments. Returns the added node for ease of chaining. // // Given: %3 = f(%1, %2) // Execute: %3.addInput(%4) // Result: %3 = f(%1, %2, %4) Value* addInput(Value* value); // Add 'value' as an input to 'this' at the specified position in the // arguments. Returns the added value for ease of chaining. Value* insertInput(size_t i, Value* value); // Replace the input of 'this' at position 'i' with // 'newValue', returning the old node. // // Given: %3 = f(%1, %2) // Execute: %3.replaceInput(1, %4) // Result: %3 = f(%1, %4) Value* replaceInput(size_t i, Value* newValue); // Replace all occurrences of 'from' in the inputs of this // node with 'to'. Corresponds to llvm's replaceUsesOfWith. // // Given: %3 = f(%1, %2, %1) // Execute: %3.replaceInputWith(%1, %4) // Result: %3 = f(%4, %2, %4) void replaceInputWith(Value* from, Value* to); Value* addOutput(); Value* insertOutput(size_t i); void eraseOutput(size_t i); Block* addBlock(); void eraseBlock(size_t i); // Each Node can have a list of subblocks. These are used to define structured // nested control flow operators such as If and Loop. // The meaning of a block is specific to the kind of node it is in, but // all blocks share these semantics: // * Nested lexical scoping: If a node 'Parent' has a subblock which contains // a node 'Child', Child can use any value that was in scope for the Parent // node in addition to any values defined before 'Child' in the subblock. // * The list of inputs to the block are in scope for the duration of the // block // * the outputs of the Parent node are not in scope for the subblocks // Typically the inputs to a block that represents control flow act as // as the equivalents phi-nodes in standard SSA form, // defining a new Value to represent any term that has multiple // definitions depending on how control flowed. Outputs of the node containing // control flow serve a similiar purpose defining new values for variables // that would have different definitions depending on which way control // flowed. at::ArrayRef blocks() { return blocks_; } at::ArrayRef blocks() const { // Vectors are not convertible in const-ness of elements, but // raw pointers are. return {blocks_.data(), blocks_.size()}; } // Is 'this' before 'n' in the topological order? bool isBefore(const Node* n) const; // Is 'this' after 'n' in the topological order? bool isAfter(const Node* n) const; // Insert unattached 'this' node before 'n' in the topological order. // Returns this (for chaining). // // Given: %3 = f(%1, %2) // %4 = g(%3) // and unattached: %5 = h(%1) // Execute: %5.insertBefore(%4) // Result: %3 = f(%1, %2) // %5 = h(%1) // %4 = g(%3) Node* insertBefore(Node* n); // Insert unattached 'this' node after 'n' in the topological order. // Returns this (for chaining). // // Given: %3 = f(%1, %2) // %4 = g(%3) // and unattached: %5 = h(%1) // Execute: %5.insertAfter(%4) // Result: %3 = f(%1, %2) // %4 = g(%3) // %5 = h(%1) Node* insertAfter(Node* n); // Move 'this' (already in the graph) after 'n' in the topological order. // // NOTE: Does not check that value dependencies are preserved, see // AliasDb::moveAfterTopologicallyValid // // Given: %2 = f(%1) // %3 = g(%1) // Execute: %2.moveAfter(%3) // Result: %3 = g(%1) // %2 = f(%1) // void moveAfter(Node* n); // Move a node 'n' (already in the graph) before 'this' in the topological // order. // // NOTE: Does not check that value dependencies are preserved, see // AliasDb::moveBeforeTopologicallyValid // // Given: %2 = f(%1) // %3 = g(%1) // Execute: %3.moveBefore(%2) // Result: %3 = g(%1) // %2 = f(%1) void moveBefore(Node* n); // Remove the input at 'i' from this node. // // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling // removeInput. // // Given: %3 = f(%1, %2) // Execute: %3.removeInput(1) // Result: %3 = f(%1) void removeInput(size_t i); // Remove all inputs from a node. // // Given: %3 = f(%1, %2) // Execute: %3.removeAllInputs() // Result: %3 = f() void removeAllInputs(); // Rearrange the ordering of inputs or outputs of a node // Given: %3 = f(%1, %2) // Execute: %3.permuteInputs({1, 0}) // Result: %3 = f(%2, %1) // Each index must appear exactly once void permuteInputs(const std::vector& new_inputs); void permuteOutputs(const std::vector& new_inputs); // iterators of the node list starting at this node // useful for resuming a search starting at this node inline graph_node_list_iterator iterator() { return {this, 0}; } inline graph_node_list_iterator reverseIterator() { return iterator().reverse(); } inline const_graph_node_list_iterator iterator() const { return {this, 0}; } inline const_graph_node_list_iterator reverseIterator() const { return iterator().reverse(); } // Remove 'this' from the instruction list and deallocate it. // // Invariant: no outputs of 'this' may have any uses. // // Given: %2 = f(%1) // %3 = g(%1) // Execute: %2.destroy() // Result: %3 = g(%1) void destroy(); // Dynamically cast this node to the subclass indicated by the // template variable, returning nullptr if the cast is invalid.. // // Example usage: if(auto s = n.cast