#pragma once #include #include #include #include namespace torch { namespace data { /// A dataloader for stateful datasets. /// /// A dataloader for stateful datatasets differs from one for stateless /// datasets one in that the dataset is shared among worker threads, and that /// this dataset is itself responsible for producing batches rather than /// depending on a sampler. The statefulness here actually refers to the /// dataset. The StatefulDataLoader simply alters the data loading algorithm to /// accomodate the stateful, shared nature of the dataset. Note that the dataset /// must be thread safe if more than one worker thread is used. /// /// A stateful dataloader is created by calling `make_data_loader` with a /// stateful dataset. template class StatefulDataLoader : public DataLoaderBase< Dataset, typename Dataset::BatchType::value_type, typename Dataset::BatchRequestType> { public: using super = DataLoaderBase< Dataset, typename Dataset::BatchType::value_type, typename Dataset::BatchRequestType>; using typename super::BatchRequestType; /// Constructs the `StatefulDataLoader` from a `dataset` and some `options`. StatefulDataLoader(Dataset dataset, DataLoaderOptions options) : super( std::move(options), torch::make_unique(std::move(dataset))) { for (size_t w = 0; w < this->options_.workers; ++w) { // As opposed to the stateless case, here all worker threads access the // same underlying dataset. this->workers_.emplace_back( [this] { this->worker_thread(*this->main_thread_dataset_); }); } } private: /// Resets the internal state of the dataloader and the dataset. void reset() override { this->main_thread_dataset_->reset(); // Call the base class method last because it calls `prefetch()` super::reset(); } /// For stateful datasets, the batch request is always the batch size. The /// dataset is responsible for determining what goes into the batch next. optional get_batch_request() override { return this->options_.batch_size; } }; } // namespace data } // namespace torch