#pragma once #include #include #include #include #include #include #include namespace torch { namespace nn { namespace detail { template class DropoutImplBase : public torch::nn::Cloneable { public: explicit DropoutImplBase(const DropoutOptions& options_ = DropoutOptions()); void reset() override; /// The options used to configure this `Dropout` module. DropoutOptions options; }; } // namespace detail /// Applies [Dropout](https://arxiv.org/abs/1207.0580) during training. /// /// See https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout to learn more /// about the exact semantics of this module. class TORCH_API DropoutImpl : public detail::DropoutImplBase { public: explicit DropoutImpl(const DropoutOptions& options_ = DropoutOptions()); /// During training, applies a noise mask to the input tensor. /// During evaluation, applies an identity function. Tensor forward(const Tensor& input); /// Pretty prints the `Dropout` module into the given `stream`. void pretty_print(std::ostream& stream) const override; }; /// Applies spatial [Dropout](https://arxiv.org/abs/1207.0580) to inputs with /// 2-D or 3-D features. /// /// The equivalent in Python is /// [Dropout2d](https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout2d) for /// 2-D features and /// [Dropout3d](https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout3d) for /// 3-D features. This `FeatureDropout` module can instead deal with both 2-D /// and 3-D features. class TORCH_API FeatureDropoutImpl : public detail::DropoutImplBase { public: explicit FeatureDropoutImpl(const DropoutOptions& options_ = DropoutOptions()); /// During training, applies a noise mask to the input tensor. /// During evaluation, applies an identity function. Tensor forward(const Tensor& input); /// Pretty prints the `FeatureDropout` module into the given `stream`. void pretty_print(std::ostream& stream) const override; }; /// A `ModuleHolder` subclass for `DropoutImpl`. /// See the documentation for `DropoutImpl` class to learn what methods it /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's /// module storage semantics. TORCH_MODULE(Dropout); /// A `ModuleHolder` subclass for `FeatureDropoutImpl`. /// See the documentation for `FeatureDropoutImpl` class to learn what methods /// it provides, or the documentation for `ModuleHolder` to learn about /// PyTorch's module storage semantics. TORCH_MODULE(FeatureDropout); } // namespace nn } // namespace torch