#pragma once #include #include #include #include #include #include #include #include #include namespace torch { namespace data { /// Creates a `DataLoader` instance for a stateless `dataset`, a `sampler` and /// some `options`. template torch::disable_if_t< Dataset::is_stateful, std::unique_ptr>> make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) { return torch::make_unique>( std::move(dataset), std::move(sampler), std::move(options)); } /// Creates a `DataLoader` instance for a stateless `dataset` and some /// `options`. A sampler (by default a `RandomSampler`) will be constructed from /// the size of the dataset. template torch::disable_if_t< Dataset::is_stateful || !std::is_constructible::value, std::unique_ptr>> make_data_loader( Dataset dataset, DataLoaderOptions options = DataLoaderOptions()) { const optional size = dataset.size(); TORCH_CHECK( size.has_value(), "Expected the dataset to be sized in " "order to construct the Sampler"); return make_data_loader( std::move(dataset), Sampler(*size), std::move(options)); } /// Creates a `DataLoader` for a stateful `dataset` and some `options`. template > std::unique_ptr> make_data_loader( Dataset dataset, DataLoaderOptions options = DataLoaderOptions()) { return torch::make_unique>( std::move(dataset), std::move(options)); } } // namespace data } // namespace torch