#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Forward declarations confuse Doxygen
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace at {
class Tensor;
} // namespace at

namespace torch {
using at::Tensor;
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
#endif // DOXYGEN_SHOULD_SKIP_THIS

namespace torch {
namespace optim {
namespace detail {

/// Base class for all optimizers that does not yet define a `step()`
/// mechanism. All it specifies is that optimizers must be supplied with a
/// vector of parameters. It also defines certain methods that all optimizers
/// shall have, such as `zero_grad`.
class TORCH_API OptimizerBase {
 public:
  /// Constructs the `OptimizerBase` from a vector of parameters.
  explicit OptimizerBase(std::vector<Tensor> parameters);

  virtual ~OptimizerBase() = default;

  /// Adds the given vector of parameters to the optimizer's parameter list.
  void add_parameters(const std::vector<Tensor>& parameters);

  /// Zeros out the gradients of all parameters.
  virtual void zero_grad();

  /// Provides a const reference to the parameters this optimizer holds.
  const std::vector<Tensor>& parameters() const noexcept;

  /// Provides a reference to the parameters this optimizer holds.
  std::vector<Tensor>& parameters() noexcept;

  /// Returns the number of parameters referenced by the optimizer.
  size_t size() const noexcept;

  /// Serializes the optimizer state into the given `archive`.
  virtual void save(serialize::OutputArchive& archive) const;

  /// Deserializes the optimizer state from the given `archive`.
  virtual void load(serialize::InputArchive& archive);

 protected:
  OptimizerBase() = default;

  /// Accesses the buffer at the given index. If no buffer exists at that
  /// index yet, the vector is grown to `index + 1` and all newly created
  /// buffers are zero-initialized.
  template <typename T>
  T& buffer_at(std::vector<T>& buffers, size_t index) {
    if (buffers.size() <= index) {
      const auto old_size = buffers.size();
      buffers.resize(index + 1);
      std::fill(buffers.begin() + old_size, buffers.end(), T{0});
    }
    return buffers[index];
  }

  /// Accesses the buffer at the given index and converts it to the type of
  /// the parameter at the corresponding index (a no-op if they already
  /// match). Like the template overload, it grows the buffer vector and
  /// zero-initializes newly created buffers as needed.
  Tensor& buffer_at(std::vector<Tensor>& buffers, size_t index);

  /// The parameters this optimizer optimizes.
  std::vector<Tensor> parameters_;
};

/// Serializes an `OptimizerBase` into an `OutputArchive`.
TORCH_API serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const OptimizerBase& optimizer);

/// Deserializes an `OptimizerBase` from an `InputArchive`.
TORCH_API serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    OptimizerBase& optimizer);

} // namespace detail

/// Optimizer that defines a required `step()` method that takes no arguments
/// and produces no values. Its only side effect is that the parameters are
/// updated according to the concrete optimization algorithm.
class Optimizer : public detail::OptimizerBase {
 public:
  using detail::OptimizerBase::OptimizerBase;
  virtual void step() = 0;
};
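
// Example (illustrative sketch, not part of this header): a minimal
// SGD-style optimizer built on `Optimizer`. The name `MySGD` and the `lr_`
// field are assumptions made for the example; options, momentum, and weight
// decay are omitted.
//
//   struct MySGD : torch::optim::Optimizer {
//     MySGD(std::vector<torch::Tensor> parameters, double lr)
//         : Optimizer(std::move(parameters)), lr_(lr) {}
//
//     void step() override {
//       torch::NoGradGuard no_grad;  // keep updates out of the autograd graph
//       for (auto& p : parameters_) {
//         if (p.grad().defined()) {
//           p.add_(p.grad(), -lr_);  // p -= lr * p.grad()
//         }
//       }
//     }
//
//     double lr_;
//   };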

/// Optimizer that requires the loss function to be supplied to the `step()`
/// function, as it may evaluate the loss function multiple times per step.
/// Examples of such algorithms are conjugate gradient and LBFGS. The `step()`
/// function also returns the loss value.
class LossClosureOptimizer : public detail::OptimizerBase {
 public:
  /// A loss function closure, which is expected to return the loss value.
  using LossClosure = std::function<Tensor()>;

  using detail::OptimizerBase::OptimizerBase;

  virtual Tensor step(LossClosure closure) = 0;
};

} // namespace optim
} // namespace torch
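
// Example (illustrative sketch, not part of this header): driving a
// `LossClosureOptimizer` such as `torch::optim::LBFGS`, which may evaluate
// the closure several times per step. The `model`, `input`, and `target`
// names are assumptions made for the example.
//
//   torch::optim::LBFGS optimizer(
//       model->parameters(), torch::optim::LBFGSOptions(/*learning_rate=*/0.1));
//   torch::Tensor loss = optimizer.step([&] {
//     optimizer.zero_grad();
//     torch::Tensor value = torch::mse_loss(model->forward(input), target);
//     value.backward();
//     return value;  // the optimizer reports the final loss back to the caller
//   });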