#pragma once
|
|
#include <torch/csrc/WindowsTorchApiMacro.h>
|
#include <torch/csrc/jit/ir.h>
|
|
#include <ATen/ATen.h>
|
|
#include <memory>
|
#include <vector>
|
|
namespace torch {
|
namespace jit {
|
|
// Shorthand for a vector of `Value` pointers.
using value_list = std::vector<Value*>;
|
// clang-format off
|
// Example showcasing how Gradient is constructed:
|
//
|
// Let's assume we have a function f, `m` and `n` do not require grad
|
// (`n` can depend only on `m`):
|
// y, n = f(x, m)
|
//
|
// Now, let's assume that the reverse of f (called f') needs to use values of `x`, `t` and `y`.
|
// `t` is an intermediate value produced in the body of f, and let's assume that it requires
|
// grad too.
|
//
|
// In this case differentiate(f) will return this:
|
// y, n, t = f(x, m) // `t` is appended to the output list
|
// dx = f'(dy, dt, x, t, y) // No `dm` or `dn` because they do not require gradient
|
// // All needed values from f are appended to the input list (after the vjps)
|
//
|
// f_real_outputs = 2 // Only first two outputs were present in f originally
|
// df_input_vjps = {0, 2} // i.e. connect grad_fn of y and t variables produced by f,
|
// y t // with y's output_nr = 0 and t's output_nr = 1
|
// df_input_captures = {I0, O2, O0} // Order matches the trailing inputs of df; stored in Gradient as df_input_captured_inputs = {0} and df_input_captured_outputs = {2, 0}
|
// x t y
|
// df_output_vjps = {0} // i.e. connect next_edge[0] of grad_fn to x's (grad_fn, output_nr).
|
//
|
// Terminology: vjp = vector-jacobian product
|
// clang-format on
|
|
struct Gradient {
|
explicit operator bool() const {
|
return df != nullptr;
|
}
|
std::shared_ptr<Graph> f;
|
std::shared_ptr<Graph> df;
|
|
// Describes how to construct outputs of f from what its graph will return.
|
// This is necessary because some trailing outputs are intermediates produced
|
// only to be saved for df (and should be ignored).
|
size_t f_real_outputs = 0; // initialized for safety.
|
|
// df inputs are split into two sections: vjps (aka grad_outputs) and
|
// captures. VJPs are "seeds" for the gradient computation given for each
|
// input capture of an Output kind. Captures are values the need to be saved
|
// when f is run. We handle inputs specially, because this allows us to avoid
|
// adding extra vjps as df inputs.
|
|
std::vector<size_t> df_input_vjps; // Offsets into f's outputs.
|
// capture can come from inputs or outputs
|
std::vector<size_t> df_input_captured_inputs; // Offsets into f's inputs
|
std::vector<size_t> df_input_captured_outputs; // Offsets into f's outputs
|
|
// df will produce vjps for a subset of inputs of f that required grad.
|
// df_output_vjps[idx] == inp_idx means that idx-th output of df produces a
|
// vjp for inp_idx-th input of f.
|
std::vector<size_t> df_output_vjps; // Offsets into f's inputs.
|
|
// How to use gradient to implement a differentiable autograd function:
|
// When running f:
|
// - Unwrap input Variables
|
// - Run f's graph
|
// - Create grad_fn
|
// - Wrap outputs in Variables (assume we have a tensor_outputs array):
|
// outputs = map(Variable, tensor_output)
|
// for i, offset in enumerate(df_input_vjps):
|
// outputs[offset].set_grad_fn(grad_fn, output_nr=i)
|
// - Use df_output_vjps to connect next_edges of grad_fn:
|
// for idx in df_output_vjps:
|
// grad_fn.add_next_edge(inputs[idx].gradient_edge())
|
// - Save captures for df (care needs to be taken to use SavedVariables for
|
// inputs and outputs that we will actually return)
|
// - Return outputs[:f_real_outputs]
|
//
|
// When running df:
|
// - Concatenate received vjps and captured Variables
|
// - Interpret df
|
// - Wrap outputs of df into Variables (that don't require grad)
|
};
|
// Differentiates `graph` and returns the result as a Gradient: the forward
// graph `f`, the backward graph `df`, and the index vectors connecting them
// (see the worked example preceding the Gradient struct).
// NOTE(review): `graph` is taken as a non-const reference to the shared_ptr;
// whether the caller's graph is modified or released cannot be determined
// from this declaration -- confirm against the definition.
TORCH_API Gradient differentiate(std::shared_ptr<Graph>& graph);
|
|
// Returns whether the derivative of node `n` can be taken symbolically.
|
TORCH_API bool isDifferentiable(Node* n);
|
// Graph-level overload of isDifferentiable.
// NOTE(review): presumably true when `g` as a whole can be differentiated
// symbolically -- confirm against the definition.
TORCH_API bool isDifferentiable(Graph& g);
|
// NOTE(review): presumably reports whether `v` is known to be a zero value;
// the exact criterion lives in the definition -- confirm there.
TORCH_API bool isZero(Value* v);
|
|
} // namespace jit
|
} // namespace torch
|