#pragma once #include #include #include #include #include #include #include #include #include namespace torch { namespace serialize { class OutputArchive; class InputArchive; } // namespace serialize } // namespace torch namespace torch { namespace optim { struct TORCH_API RMSpropOptions { RMSpropOptions(double learning_rate); TORCH_ARG(double, learning_rate); TORCH_ARG(double, alpha) = 0.99; TORCH_ARG(double, eps) = 1e-8; TORCH_ARG(double, weight_decay) = 0; TORCH_ARG(double, momentum) = 0; TORCH_ARG(bool, centered) = false; }; class TORCH_API RMSprop : public Optimizer { public: template explicit RMSprop( ParameterContainer&& parameters, const RMSpropOptions& options_) : Optimizer(std::forward(parameters)), options(options_) {} void step() override; RMSpropOptions options; void save(serialize::OutputArchive& archive) const override; void load(serialize::InputArchive& archive) override; std::vector square_average_buffers; std::vector momentum_buffers; std::vector grad_average_buffers; private: RMSprop() : options(0) {} template static void serialize(Self& self, Archive& archive) { _TORCH_OPTIM_SERIALIZE(square_average_buffers); _TORCH_OPTIM_SERIALIZE(momentum_buffers); _TORCH_OPTIM_SERIALIZE(grad_average_buffers); } }; } // namespace optim } // namespace torch