#pragma once #include #include #include namespace torch { namespace jit { // insert GraphExecutor nodes that group together // subgraphs that are differentiable by the jit's autodiff passes // threshold - minimum number of nodes that will appear in a block // returns all differentiable blocks that have been found TORCH_API std::vector CreateAutodiffSubgraphs( const std::shared_ptr& graph, size_t threshold = 2); } // namespace jit } // namespace torch