#pragma once #include #include #include #include #include namespace torch { namespace nn { /// Applies fold over a 3-D input. /// See https://pytorch.org/docs/master/nn.html#torch.nn.Fold to learn about /// the exact behavior of this module. class TORCH_API FoldImpl : public torch::nn::Cloneable { public: FoldImpl(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size) : FoldImpl(FoldOptions(output_size, kernel_size)) {} explicit FoldImpl(const FoldOptions& options_); void reset() override {} /// Pretty prints the `Fold` module into the given `stream`. void pretty_print(std::ostream& stream) const override { stream << "torch::nn::Fold"; } Tensor forward(const Tensor& input); /// The options with which this `Module` was constructed. FoldOptions options; }; /// A `ModuleHolder` subclass for `FoldImpl`. /// See the documentation for `FoldImpl` class to learn what methods it /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's /// module storage semantics. TORCH_MODULE(Fold); } // namespace nn } // namespace torch