#pragma once #include namespace torch { namespace jit { struct IfView { explicit IfView(Node* node) : node_(node) { AT_ASSERT(node->kind() == ::c10::prim::If); } Value* cond() const { return node_->input(0); } Block* thenBlock() const { return node_->blocks().at(0); } Block* elseBlock() const { return node_->blocks().at(1); } ArrayRef thenOutputs() const { return thenBlock()->outputs(); } ArrayRef elseOutputs() const { return elseBlock()->outputs(); } ArrayRef outputs() const { return node_->outputs(); } Node* node() const { return node_; } operator Node*() const { return node_; } void permuteOutputs(const std::vector& new_output_order) { node_->permuteOutputs(new_output_order); thenBlock()->permuteOutputs(new_output_order); elseBlock()->permuteOutputs(new_output_order); } private: Node* node_; }; struct LoopView { explicit LoopView(Node* node) : node_(node) { AT_ASSERT( node->kind() == ::c10::prim::Loop || node->kind() == ::c10::onnx::Loop); } Block* bodyBlock() const { return node_->blocks().at(0); } Value* cond() const { return node_->input(0); } Value* maxTripCount() const { return node_->input(0); } Value* inputCond() const { return node_->input(1); } Value* nextCond() const { return bodyBlock()->outputs().at(0); } Value* currentTripCount() const { return bodyBlock()->inputs().at(0); } ArrayRef carriedInputs() const { // skip trip count and cond return node_->inputs().slice(2); } ArrayRef carriedInputsWithCond() const { // skip trip count and cond return node_->inputs().slice(1); } ArrayRef carriedOutputs() const { return node_->outputs(); } ArrayRef bodyCarriedInputs() const { // skip trip count and cond return bodyBlock()->inputs().slice(1); } ArrayRef bodyCarriedOutputs() const { return bodyBlock()->outputs().slice(1); } Node* node() const { return node_; } operator Node*() const { return node_; } void permuteLoopCarried(const std::vector& new_output_order) { node_->permuteOutputs(new_output_order); // skip trip count and cond node_->permuteInputs(adjustIndices(2, new_output_order)); auto adjusted_block_order = adjustIndices(1, new_output_order); bodyBlock()->permuteOutputs(adjusted_block_order); bodyBlock()->permuteInputs(adjusted_block_order); } void replaceMaxTripCount(Value* new_max_trip_count) { node_->replaceInput(0, new_max_trip_count); } void replaceInputCondition(Value* new_input_condition) { node_->replaceInput(1, new_input_condition); } // our way of encoding loops makes them difficult to turn back into python // syntax. we have to check properties of the condition and trip count inputs // to figure out which one it initially was. ModifiedLoops are not directly // mappable to either For or While enum LoopType { While, For, ModifiedLoop }; LoopType loopType() { auto trip_count = toIValue(maxTripCount()); auto cond_input = toIValue(inputCond()); auto cond_next = toIValue(nextCond()); bool condition_is_always_true = cond_input && cond_input->toBool() && cond_next && cond_next->toBool(); bool trip_count_is_specified = !trip_count || // trip is not a constant trip_count->toInt() != std::numeric_limits::max() || // it is a constant but not // the default one currentTripCount()->uses().size() > 0; // it is actually being used in the body. if (condition_is_always_true) { // if the trip count was not specified this was a user-written while True: return trip_count_is_specified ? For : While; } else { if (trip_count_is_specified) { return ModifiedLoop; } return While; } } private: Node* node_; // adjust index_ordering by adding indices 0 - thorugh adjust, and // incrementing all existing inputs by adjust static std::vector adjustIndices( size_t adjust, const std::vector& index_ordering) { std::vector adjusted; adjusted.reserve(adjust + index_ordering.size()); for (size_t i = 0; i < adjust; ++i) { adjusted.push_back(i); } for (auto index : index_ordering) { adjusted.push_back(index + adjust); } return adjusted; } }; } // namespace jit } // namespace torch