#pragma once

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/variadic.h>
#include <ATen/ATen.h>

#include <vector>

namespace torch { namespace autograd {

TORCH_API variable_list _wrap_outputs(
  const variable_list &input_vars,
  const std::unordered_set<at::TensorImpl*> &non_differentiable,
  const std::unordered_set<at::TensorImpl*> &dirty_inputs,
  const at::ArrayRef<Variable> raw_outputs,
  const std::shared_ptr<Node> &cdata);

TORCH_API void check_variable_result(const Variable& original, const Variable& result, std::string hook_name);

// Get the return type of the forward function of the custom Function class X
template <typename X, typename... Args>
using forward_t = decltype(X::forward(nullptr, std::declval<Args>()...));

// To use custom autograd operations, implement a Function subclass with
// static forward and backward functions.
//
// forward() can take as many arguments as you want and should return either a
// variable_list or a Variable. Only direct Variable arguments are registered
// in the graph; vectors/sets or any other data structures are not traversed.
// It should take an AutogradContext* as the first argument. Variables can be
// saved in the ctx using ctx->save_for_backward() and other data can be saved
// in the map ctx->saved_data in the form of <std::string, at::IValue> pairs.
//
// backward() should take an AutogradContext* and a variable_list containing as
// many Variables as there were outputs from forward(). It should return as
// many Variables as there were inputs to forward(), each containing the
// gradient w.r.t. its corresponding input (an undefined Variable for inputs
// that were not Variables). Variables saved in forward() can be accessed with
// ctx->get_saved_variables() and other saved data can be accessed from
// ctx->saved_data.
//
// For example:
// class MyFunction : public Function<MyFunction> {
//   public:
//   static variable_list forward(AutogradContext *ctx, int n, Variable var) {
//      // Save data for backward in context
//      ctx->saved_data["n"] = n;
//      var.mul_(2);
//      // Mark var as modified by inplace operation
//      ctx->mark_dirty({var});
//      return {var};
//   }
//
//   static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
//      // Use data saved in forward
//      auto n = ctx->saved_data["n"].toInt();
//      // One gradient per forward input: an undefined Variable for the
//      // non-Variable input n, then the gradient w.r.t. var
//      return {Variable(), grad_output[0]*n};
//   }
// };
//
// To use MyFunction:
// Variable x;
// auto y = MyFunction::apply(6, x);
// Example backward call:
// y[0].sum().backward();

template <class T>
struct TORCH_API Function {
  // We need to use a different template parameter than T here because T will
  // inherit from Function, and when Function<T> is instantiated, T::forward
  // is not declared yet.
  // The enable_if check is to ensure that the user doesn't explicitly provide
  // the parameter X.
  template<typename X=T, typename... Args>
  static auto apply(Args&&... args) -> c10::guts::enable_if_t<std::is_same<X,T>::value, forward_t<X,Args...>>;
};
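// An additional illustrative sketch (not part of the original header; the
// MyMul name is hypothetical): a custom multiply that saves its Variable
// inputs with ctx->save_for_backward() and retrieves them in backward()
// via ctx->get_saved_variables().
//
// struct MyMul : public Function<MyMul> {
//   static Variable forward(AutogradContext *ctx, Variable a, Variable b) {
//      ctx->save_for_backward({a, b});
//      return a * b;
//   }
//
//   static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
//      auto saved = ctx->get_saved_variables();
//      auto a = saved[0], b = saved[1];
//      // d(a*b)/da = b, d(a*b)/db = a
//      return {grad_output[0] * b, grad_output[0] * a};
//   }
// };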
// Context to save information during forward() that can be accessed in backward()
struct TORCH_API AutogradContext {
  AutogradContext() = default;
  AutogradContext(const AutogradContext &other) = delete;
  AutogradContext& operator=(const AutogradContext& other) = delete;

  // Can be used to save non-variable data for backward()
  ska::flat_hash_map<std::string, at::IValue> saved_data;

  // Saves the list of variables for a future call to backward(). This
  // should be called at most once from inside of forward().
  void save_for_backward(variable_list to_save);
  // Marks variables in the list as modified in an in-place operation. This
  // should be called at most once from inside of forward() and all arguments
  // should be inputs.
  void mark_dirty(const variable_list &inputs);
  // Marks outputs in the list as not requiring gradients. This should be called
  // at most once from inside of forward() and all arguments should be outputs.
  void mark_non_differentiable(const variable_list &outputs);

  // Get the list of variables that were saved in forward() using
  // save_for_backward(). Before returning them to the user, a check is made to
  // ensure that they were not modified by any in-place operations.
  variable_list get_saved_variables() const;
  const std::unordered_set<at::TensorImpl*>& get_dirty() const;
  const std::unordered_set<at::TensorImpl*>& get_non_differentiable() const;

private:
  std::unordered_set<at::TensorImpl*> non_differentiable_;
  std::unordered_set<at::TensorImpl*> dirty_inputs_;
  std::vector<torch::autograd::SavedVariable> saved_variables_;
  variable_list to_save_;

  // The CppNode in the autograd graph that owns this AutogradContext. We need a
  // weak_ptr to avoid a reference cycle. Since grad_fn_ owns this AutogradContext,
  // it will always be alive when we want to use it.
  std::weak_ptr<Node> grad_fn_;
  bool has_freed_buffers_;

  void save_variables();

  template <class T> friend struct CppNode;
};

struct TORCH_API VariableInfo {
  explicit VariableInfo(const Variable& var);

  Variable zeros(at::OptionalDeviceGuard& device_guard) const;

  at::Layout layout = at::Layout::Strided;
  at::Device device = at::kCPU;
  at::ScalarType scalar_type = at::kFloat;
  std::vector<int64_t> size;
  bool requires_grad;
};

// CppNode<T> is the Node in the autograd graph that represents the user-defined
// backward function for Function<T>. Calls to CppNode::apply are forwarded to
// T::backward().
template <class T>
struct CppNode : public Node {
  variable_list apply(variable_list&& inputs) override;
  AutogradContext ctx_;
  std::vector<bool> is_variable_input_;
  std::vector<VariableInfo> input_info_;
  std::vector<VariableInfo> output_info_;

  void release_variables() override;

  void set_ctx_grad_fn(const std::shared_ptr<Node> &node);
  void save_variables_to_ctx();
};

template <typename T>
using enable_if_var_t = typename std::enable_if<std::is_same<T, Variable>::value>::type;

template <typename T>
using enable_if_not_var_t = typename std::enable_if<!std::is_same<T, Variable>::value>::type;

template <typename T, typename... Args>
enable_if_not_var_t<T> extract_vars(std::vector<bool> &is_var, variable_list& list, T&& cur, Args&& ... args) {
  is_var.push_back(false);
  extract_vars(is_var, list, std::forward<Args>(args)...);
}

template <typename T, typename... Args>
enable_if_var_t<T> extract_vars(std::vector<bool> &is_var, variable_list& list, T&& cur, Args&& ... args) {
  is_var.push_back(true);
  list.emplace_back(cur);
  extract_vars(is_var, list, std::forward<Args>(args)...);
}

template <typename... Args>
void extract_vars(std::vector<bool> &is_var, variable_list& list, Args&& ... args) {
}

template <typename T>
typename std::enable_if<std::is_same<T, variable_list>::value, T&>::type to_output_type(variable_list& output_list) {
  return output_list;
}

template <typename T>
typename std::enable_if<std::is_same<T, Variable>::value, T>::type to_output_type(variable_list& output_list) {
  return output_list[0];
}
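// Illustration (not from the original header): for the MyFunction example
// above, a call like MyFunction::apply(6, x) makes extract_vars() record
// is_variable_input_ = {false, true} and input_vars = {x}; only x is wired
// into the autograd graph, while 6 is simply forwarded to forward().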
template<class T>
template<typename X, typename... Args>
auto Function<T>::apply(Args&&... args) -> c10::guts::enable_if_t<std::is_same<X,T>::value, forward_t<X,Args...>> {
  std::shared_ptr<CppNode<T>> node(new CppNode<T>(), deleteNode);
  variable_list input_vars;

  const size_t num_inputs = sizeof...(Args);
  input_vars.reserve(num_inputs);
  node->is_variable_input_.reserve(num_inputs);
  // TODO Add tracing here
  extract_vars(node->is_variable_input_, input_vars, args...);

  bool is_executable = GradMode::is_enabled() && any_variable_requires_grad(input_vars);
  auto next_edges = collect_next_edges(input_vars);
  node->set_ctx_grad_fn(node);
  node->set_next_edges(std::move(next_edges));
  node->clear_input_metadata();

  node->input_info_.reserve(input_vars.size());
  for (auto& var : input_vars) {
    node->input_info_.emplace_back(var);
  }

  using forward_return_t = forward_t<X, Args...>;
  forward_return_t outputs;
  {
    AutoGradMode grad_mode(false);
    outputs = T::forward(&node->ctx_, std::forward<Args>(args)...);
  }

  auto wrapped_outputs = _wrap_outputs(input_vars, node->ctx_.get_non_differentiable(), node->ctx_.get_dirty(), outputs, is_executable ? node : nullptr);

  node->output_info_.reserve(wrapped_outputs.size());
  for (auto& output : wrapped_outputs) {
    if (is_executable) {
      node->output_info_.emplace_back(output);
    }
  }

  if (is_executable) {
    node->save_variables_to_ctx();
  }

  // wrapped_outputs will be a variable_list, so convert it to the correct
  // return type. Only Variable and variable_list are accepted as return types.
  return to_output_type<forward_return_t>(wrapped_outputs);
}

// The logic here is the same as PyNode::apply, so changes to it should be done
// in both places.
template<class T>
variable_list CppNode<T>::apply(variable_list&& inputs) {
  at::OptionalDeviceGuard _device_guard;

  int num_inputs = inputs.size();
  variable_list backward_inputs;
  backward_inputs.reserve(num_inputs);
  for (int i = 0 ; i < num_inputs; ++i) {
    if (inputs[i].defined()) {
      backward_inputs.emplace_back(inputs[i]);
    } else {
      backward_inputs.emplace_back(output_info_[i].zeros(_device_guard));
    }
  }

  auto outputs = T::backward(&ctx_, backward_inputs);

  int num_forward_inputs = is_variable_input_.size();
  int num_outputs = outputs.size();
  // Returning too many results is ok, but only as long as they're all undefined.
  // Truncate the result vector in that case.
  if (num_outputs > num_forward_inputs) {
    bool all_undef = true;
    for (int i = num_forward_inputs; i < num_outputs; ++i) {
      all_undef &= (!outputs[i].defined());
    }
    if (all_undef) {
      outputs.resize(num_forward_inputs);
      num_outputs = num_forward_inputs;
    }
  }

  if (num_outputs != num_forward_inputs) {
    std::string msg("function ");
    msg += name() + " returned an incorrect number of gradients (expected ";
    msg += std::to_string(num_forward_inputs) + ", got " ;
    msg += std::to_string(num_outputs) + ")";
    throw std::runtime_error(msg);
  }

  variable_list results;
  results.reserve(num_outputs);
  for (int i = 0; i < num_outputs; ++i) {
    if (!is_variable_input_[i]) {
      if (outputs[i].defined()) {
        std::string msg("function ");
        msg += name() + " returned a defined gradient at position ";
        msg += std::to_string(i + 1) + ", but the corresponding forward input was not a Variable";
        throw std::runtime_error(msg);
      }
      continue;
    }
    if (!outputs[i].defined()) {
      auto& info = input_info_[results.size()];
      if (info.requires_grad) {
        results.emplace_back(info.zeros(_device_guard));
      } else {
        results.emplace_back();
      }
    } else {
      results.emplace_back(outputs[i]);
    }
  }
  return results;
}

template<class T>
void CppNode<T>::release_variables() {
  ctx_.saved_variables_.clear();
  ctx_.has_freed_buffers_ = true;
}

template<class T>
void CppNode<T>::save_variables_to_ctx() {
  ctx_.save_variables();
}

template<class T>
void CppNode<T>::set_ctx_grad_fn(const std::shared_ptr<Node> &node) {
  ctx_.grad_fn_ = node;
}

}} // namespace torch::autograd
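// Note on the gradient contract enforced by CppNode<T>::apply (illustrative
// summary, not from the original header): backward() must return exactly one
// Variable per forward() input, in the same order; positions that correspond
// to non-Variable inputs must hold an undefined Variable, and extra trailing
// outputs are tolerated only if they are all undefined. For the MyFunction
// example above, forward(ctx, 6, x) has two inputs, so backward() must return
// two Variables: an undefined one for the int argument and the gradient
// w.r.t. x.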