1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
| #pragma once
|
| #include <torch/csrc/jit/ir.h>
|
| namespace torch {
| namespace jit {
|
| // Removes the 'grad_of' nodes in `g`, replacing each one with a conditional
| // of the form:
| //
| //   if any_defined(inputs):
| //     outputs = <original_computation>
| //   else:
| //     outputs = undefineds
| //
| // That is, the wrapped computation runs only when at least one of its
| // inputs is defined; otherwise every output is produced as an undefined
| // value instead.
| //
| // @param g  The graph to transform. It is taken by mutable reference and
| //           nothing is returned, so the rewrite is applied to `g` in place.
| //
| // NOTE(review): "undefineds" presumably refers to undefined tensors (as used
| // for absent gradients) — confirm against the pass implementation.
| TORCH_API void LowerGradOf(Graph& g);
|
| } // namespace jit
| } // namespace torch
|
|