#pragma once #include #include #include #include #include #include namespace torch { namespace serialize { class OutputArchive; class InputArchive; } // namespace serialize } // namespace torch namespace torch { namespace optim { struct TORCH_API AdamOptions { /* implicit */ AdamOptions(double learning_rate); TORCH_ARG(double, learning_rate); TORCH_ARG(double, beta1) = 0.9; TORCH_ARG(double, beta2) = 0.999; TORCH_ARG(double, weight_decay) = 0; TORCH_ARG(double, eps) = 1e-8; TORCH_ARG(bool, amsgrad) = false; }; class TORCH_API Adam : public Optimizer { public: template explicit Adam(ParameterContainer&& parameters, const AdamOptions& options_) : Optimizer(std::forward(parameters)), options(options_) {} void step() override; void save(serialize::OutputArchive& archive) const override; void load(serialize::InputArchive& archive) override; AdamOptions options; std::vector step_buffers; std::vector exp_average_buffers; std::vector exp_average_sq_buffers; std::vector max_exp_average_sq_buffers; private: Adam() : options(0) {} template static void serialize(Self& self, Archive& archive) { _TORCH_OPTIM_SERIALIZE(step_buffers); _TORCH_OPTIM_SERIALIZE(exp_average_buffers); _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers); _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers); } }; } // namespace optim } // namespace torch