#pragma once #include #include #include #include #include namespace at { class Tensor; } namespace c10 { struct IValue; struct OperatorName; } // namespace c10 namespace torch { namespace jit { // The interpreter run Graphs with Tensor inputs and Tensor outputs // a separate component in the autograd handles unwrapping and wrapping // variable objects for use in the interpreter. struct Node; struct GraphExecutor; struct CodeImpl; struct InterpreterStateImpl; struct Graph; struct Node; struct Instruction; using Stack = std::vector; using c10::ivalue::Future; struct TORCH_API Code { Code() : pImpl(nullptr) {} explicit Code(const std::shared_ptr& graph); ~Code(); const std::vector& grad_executors(); explicit operator bool() const { return pImpl != nullptr; } size_t num_inputs() const; size_t num_outputs() const; const std::vector& constant_table() const; const std::vector& instructions() const; const std::vector& opname_table() const; size_t register_size() const; private: std::shared_ptr pImpl; friend struct InterpreterStateImpl; friend std::ostream& operator<<(std::ostream& out, const Code& code); }; struct InterpreterState { TORCH_API InterpreterState(const Code& code); TORCH_API void run(Stack& stack); c10::intrusive_ptr runAsync(Stack& stack); c10::intrusive_ptr getFuture(); TORCH_API ~InterpreterState(); private: InterpreterState(c10::intrusive_ptr pImpl); // Ideally we should use c10::intrusive_ptr for pImpl; // but intrusive_ptr requires full definition of InterpreterStateImpl, // which we need to hide in the header. c10::intrusive_ptr pImpl; friend struct InterpreterStateImpl; }; // Created by wait() struct Suspend : public std::exception { const char* what() const noexcept override { return "Suspend"; } explicit Suspend(c10::intrusive_ptr future_) : future(std::move(future_)) {} c10::intrusive_ptr future; }; struct InterpreterContinuation { InterpreterContinuation( InterpreterState state_, Stack stack_, bool grad_mode_enabled_) : state(state_), stack(std::move(stack_)), grad_mode_enabled(grad_mode_enabled_) {} void operator()(); private: InterpreterState state; Stack stack; bool grad_mode_enabled; }; } // namespace jit } // namespace torch