#pragma once

#include <torch/arg.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/types.h>

namespace torch {
namespace nn {
namespace detail {

/// Common options for LSTM and GRU modules.
struct TORCH_API RNNOptionsBase {
  RNNOptionsBase(int64_t input_size, int64_t hidden_size);
  virtual ~RNNOptionsBase() = default;
  /// The number of features of a single sample in the input sequence `x`.
  TORCH_ARG(int64_t, input_size);
  /// The number of features in the hidden state `h`.
  TORCH_ARG(int64_t, hidden_size);
  /// The number of recurrent layers (cells) to use.
  TORCH_ARG(int64_t, layers) = 1;
  /// Whether a bias term should be added to all linear operations.
  TORCH_ARG(bool, with_bias) = true;
  /// If non-zero, adds dropout with the given probability to the output of
  /// each RNN layer, except the final layer.
  TORCH_ARG(double, dropout) = 0.0;
  /// Whether to make the RNN bidirectional.
  TORCH_ARG(bool, bidirectional) = false;
  /// If true, the input sequence should be provided as `(batch, sequence,
  /// features)`. If false (default), the expected layout is `(sequence,
  /// batch, features)`.
  TORCH_ARG(bool, batch_first) = false;
};

} // namespace detail

enum class RNNActivation : uint32_t {ReLU, Tanh};

/// Options for RNN modules.
struct TORCH_API RNNOptions {
  RNNOptions(int64_t input_size, int64_t hidden_size);

  /// Sets the activation after linear operations to `tanh`.
  RNNOptions& tanh();
  /// Sets the activation after linear operations to `relu`.
  RNNOptions& relu();

  /// The number of features of a single sample in the input sequence `x`.
  TORCH_ARG(int64_t, input_size);
  /// The number of features in the hidden state `h`.
  TORCH_ARG(int64_t, hidden_size);
  /// The number of recurrent layers (cells) to use.
  TORCH_ARG(int64_t, layers) = 1;
  /// Whether a bias term should be added to all linear operations.
  TORCH_ARG(bool, with_bias) = true;
  /// If non-zero, adds dropout with the given probability to the output of
  /// each RNN layer, except the final layer.
  TORCH_ARG(double, dropout) = 0.0;
  /// Whether to make the RNN bidirectional.
  TORCH_ARG(bool, bidirectional) = false;
  /// If true, the input sequence should be provided as `(batch, sequence,
  /// features)`. If false (default), the expected layout is `(sequence,
  /// batch, features)`.
  TORCH_ARG(bool, batch_first) = false;
  /// The activation to use after linear operations.
  TORCH_ARG(RNNActivation, activation) = RNNActivation::ReLU;
};

using LSTMOptions = detail::RNNOptionsBase;
using GRUOptions = detail::RNNOptionsBase;

} // namespace nn
} // namespace torch
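
// ---------------------------------------------------------------------------
// Usage sketch (illustrative, not part of this header). The `RNN` and `LSTM`
// module classes and the umbrella header `<torch/torch.h>` are assumed from
// the wider C++ frontend; this file only defines their option structs.
// TORCH_ARG generates fluent setters returning the options object by
// reference, so configuration chains as shown:
//
//   #include <torch/torch.h>
//
//   // An RNN over 10-feature inputs with a 20-dim hidden state, two stacked
//   // layers, dropout between layers, and tanh as the activation.
//   torch::nn::RNN rnn(
//       torch::nn::RNNOptions(/*input_size=*/10, /*hidden_size=*/20)
//           .layers(2)
//           .dropout(0.3)
//           .tanh());
//
//   // LSTMOptions and GRUOptions are aliases for detail::RNNOptionsBase, so
//   // they accept the same common arguments (but no activation, which is
//   // fixed by the cell type).
//   torch::nn::LSTM lstm(
//       torch::nn::LSTMOptions(10, 20).layers(2).bidirectional(true));
// ---------------------------------------------------------------------------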