#pragma once

// Engine implements backpropagation from output variables and their gradients
// to "root" variables (variables created by the user with requires_grad=True).

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/anomaly_mode.h>

#include <condition_variable>
#include <deque>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <unordered_map>
#include <utility>
#include <vector>
#include <thread>

namespace torch { namespace autograd {
struct ReadyQueue;
struct NodeTask;
struct GraphTask;
}} // namespace torch::autograd

namespace torch { namespace autograd {
// A single instance of this struct should be created for the whole lifetime
// of the process. The worker thread creation logic and Engine's destructor
// rely on this.
struct TORCH_API Engine {
  /// Returns a reference to a static `Engine` instance.
  static Engine& get_default_engine();

  Engine();
  virtual ~Engine();

  using ready_queue_type = std::deque<std::pair<std::shared_ptr<Node>, InputBuffer>>;
  using dependencies_type = std::unordered_map<Node*, int>;

  // Given a list of (Node, input number) pairs, computes the value of the
  // graph by following next_edge references.
  virtual variable_list execute(
      const edge_list& roots,
      const variable_list& inputs,
      bool keep_graph,
      bool create_graph,
      const edge_list& outputs = {});
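
  // Example: driving a backward pass through the default engine. This is a
  // hedged illustrative sketch, not a prescribed call sequence; `loss` stands
  // for some output Variable, and `gradient_edge()` / `at::ones_like` come
  // from the surrounding autograd/ATen headers:
  //
  //   auto& engine = Engine::get_default_engine();
  //   edge_list roots = {loss.gradient_edge()};
  //   variable_list grads = {at::ones_like(loss)};  // seed gradient of ones
  //   engine.execute(roots, grads,
  //                  /*keep_graph=*/false,
  //                  /*create_graph=*/false);
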
  virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
    return nullptr;
  }

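  // The default make_anomaly_metadata() above returns nullptr; a derived
  // engine may override it to supply real anomaly-detection metadata.
  // Hypothetical sketch (`MyAnomalyMetadata` is not declared in this header):
  //
  //   struct MyEngine : Engine {
  //     std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() override {
  //       return std::make_unique<MyAnomalyMetadata>();
  //     }
  //   };
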
  void queue_callback(std::function<void()> callback);

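  // Hedged note: judging from final_callbacks_ below, callbacks queued here
  // appear to run once the in-flight backward pass completes, e.g.
  // (illustrative only):
  //
  //   Engine::get_default_engine().queue_callback([] {
  //     // runs after the current backward pass finishes
  //   });
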
  bool is_checkpoint_valid();

protected:
  void compute_dependencies(Node* root, GraphTask& task);
  void evaluate_function(NodeTask& task);
  ReadyQueue& ready_queue(at::Device device);
  ReadyQueue& ready_queue_by_index(int device_index);
  void start_threads();
  virtual void thread_init(int device);
  virtual void thread_main(GraphTask* graph_task);
  virtual void thread_on_exception(NodeTask& task, std::exception& e);
  void reentrant_thread_init();
  void add_thread_pool_task(GraphTask* graph_task);
  void set_device(int device);

  // Ensures ready_queues_ are initialized only once
  std::once_flag start_threads_flag_;
  // Safe to read ready_queues_ without synchronization after initialization
  std::vector<std::shared_ptr<ReadyQueue>> ready_queues_;
  std::vector<std::function<void()>> final_callbacks_;
  // To protect reads and writes to final_callbacks_
  std::mutex post_callbacks_lock_;
  // How many nested reentrant calls are allowed before a new thread is used
  int max_recursion_depth_;

  struct ThreadPoolShared {
    // Data structures used by the threads for executing reentrant backwards
    // tasks. See Note [Reentrant backwards]

    // Number of available threads for processing new GraphTasks.
    unsigned int num_workers_;
    // The threads will wait on work_ to be notified of GraphTasks
    std::condition_variable work_;
    // Protects reads and writes to graphtasks_queue_ and num_workers_,
    // and synchronizes creating new threads when needed
    std::mutex mutex_;
    // Workers will process the GraphTasks added to this queue. A GraphTask is
    // allocated inside Engine::execute and lives for the duration of execute
    std::queue<GraphTask*> graphtasks_queue_;

    ThreadPoolShared() : num_workers_(0) {}
  };

  // Temporary workaround until proper thread shutdown is implemented.
  // We need shared ownership of all these objects because the threads are
  // leaked when the Engine shuts down, so there may still be threads waiting
  // on work_ for graphtasks_queue_ to become nonempty.
  std::shared_ptr<ThreadPoolShared> thread_pool_shared_;
};

// Allows python_engine to override the default engine when it loads.
using EngineStub = Engine& (*)();
TORCH_API void set_default_engine_stub(EngineStub stub);
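
// Hedged sketch of how an override might be registered (for example by a
// Python-aware engine); `PythonEngine` and `get_python_engine` here are
// illustrative, not declared in this header:
//
//   static Engine& get_python_engine() {
//     static PythonEngine engine;
//     return engine;
//   }
//   // called once at load time:
//   set_default_engine_stub(get_python_engine);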

}} // namespace torch::autograd