#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

// Forward declarations confuse Doxygen
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace at {
class Tensor;
} // namespace at

namespace torch {
using at::Tensor;
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
#endif // DOXYGEN_SHOULD_SKIP_THIS

namespace torch {
namespace optim {
namespace detail {
/// Base class for all optimizers that does not yet define a `step()`
/// mechanism. All it specifies is that optimizers must be supplied with a
/// vector of parameters. It also defines certain methods that all optimizers
/// shall have, such as `zero_grad`.
class TORCH_API OptimizerBase {
 public:
  /// Constructs the `OptimizerBase` from a vector of parameters.
  explicit OptimizerBase(std::vector<Tensor> parameters);

  virtual ~OptimizerBase() = default;

  /// Adds the given vector of parameters to the optimizer's parameter list.
  void add_parameters(const std::vector<Tensor>& parameters);

  /// Zeros out the gradients of all parameters.
  virtual void zero_grad();

  /// Provides a const reference to the parameters this optimizer holds.
  const std::vector<Tensor>& parameters() const noexcept;

  /// Provides a reference to the parameters this optimizer holds.
  std::vector<Tensor>& parameters() noexcept;

  /// Returns the number of parameters referenced by the optimizer.
  size_t size() const noexcept;

  /// Serializes the optimizer state into the given `archive`.
  virtual void save(serialize::OutputArchive& archive) const;

  /// Deserializes the optimizer state from the given `archive`.
  virtual void load(serialize::InputArchive& archive);

 protected:
  OptimizerBase() = default;

  /// Accesses a buffer at the given index.
  /// If the buffer vector is too small, it is grown to `index + 1`, and the
  /// newly created buffers are zero-initialized.
  template <typename T>
  T& buffer_at(std::vector<T>& buffers, size_t index) {
    if (buffers.size() <= index) {
      const auto old_size = buffers.size();
      buffers.resize(index + 1);
      // Zero-initialize only the newly created elements.
      std::fill(buffers.begin() + old_size, buffers.end(), T{0});
    }
    return buffers[index];
  }
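
  // A minimal sketch (hypothetical, not part of this header) of how a
  // subclass might use `buffer_at` to lazily create per-parameter state;
  // `momentum_` and the update rule are assumptions for illustration:
  //
  //   std::vector<Tensor> momentum_;
  //   for (size_t i = 0; i < parameters_.size(); ++i) {
  //     // Grown and zeroed on first access for each index.
  //     auto& buffer = buffer_at(momentum_, i);
  //     buffer = 0.9 * buffer + parameters_[i].grad();
  //   }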

  /// Accesses a buffer at the given index, converting it to the type of the
  /// parameter at the corresponding index (a no-op if the types already
  /// match). Like the generic overload above, grows the buffer vector to
  /// `index + 1` and zero-initializes newly created buffers when necessary.
  Tensor& buffer_at(std::vector<Tensor>& buffers, size_t index);

  /// The parameters this optimizer optimizes.
  std::vector<Tensor> parameters_;
};
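
// A minimal usage sketch of this interface, assuming a concrete optimizer
// such as `torch::optim::SGD` (declared elsewhere) and an existing `model`:
//
//   torch::optim::SGD optimizer(model->parameters(), /*lr=*/0.1);
//   optimizer.zero_grad(); // clear all gradients before a new backward pass
//   assert(optimizer.size() == model->parameters().size());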

/// Serializes an `OptimizerBase` into an `OutputArchive`.
TORCH_API serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const OptimizerBase& optimizer);

/// Deserializes an `OptimizerBase` from an `InputArchive`.
TORCH_API serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    OptimizerBase& optimizer);
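
// A minimal serialization sketch, assuming an `optimizer` instance and that
// `torch/serialize/archive.h` is included where the archive is used:
//
//   torch::serialize::OutputArchive archive;
//   archive << optimizer; // invokes optimizer.save(archive)
//   archive.save_to("optimizer.pt");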
} // namespace detail

/// Optimizer that defines a required `step()` method that takes no arguments
/// and produces no values. The only side effect is that parameters are updated
/// according to the concrete optimization algorithm.
class Optimizer : public detail::OptimizerBase {
 public:
  using detail::OptimizerBase::OptimizerBase;
  virtual void step() = 0;
};
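
// A minimal sketch (hypothetical, not part of this header) of a concrete
// `Optimizer` implementing plain gradient descent in `step()`:
//
//   class MyGradientDescent : public Optimizer {
//    public:
//     MyGradientDescent(std::vector<Tensor> parameters, double lr)
//         : Optimizer(std::move(parameters)), lr_(lr) {}
//     void step() override {
//       for (auto& parameter : parameters_) {
//         // parameter -= lr * gradient
//         parameter.data().add_(parameter.grad(), -lr_);
//       }
//     }
//    private:
//     double lr_;
//   };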

/// Optimizer that requires the loss function to be supplied to the `step()`
/// function, as it may evaluate the loss function multiple times per step.
/// Examples of such algorithms are conjugate gradient and LBFGS. The `step()`
/// function also returns the loss value.
class LossClosureOptimizer : public detail::OptimizerBase {
 public:
  /// A loss function closure, which is expected to return the loss value.
  using LossClosure = std::function<Tensor()>;
  using detail::OptimizerBase::OptimizerBase;
  virtual Tensor step(LossClosure closure) = 0;
};
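
// A minimal usage sketch, assuming a concrete `torch::optim::LBFGS`-style
// optimizer (declared elsewhere) and existing `model`, `input`, `target`:
//
//   auto loss = optimizer.step([&] {
//     optimizer.zero_grad();
//     auto loss = torch::nll_loss(model->forward(input), target);
//     loss.backward();
//     return loss;
//   });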

} // namespace optim
} // namespace torch