#pragma once #include #include #include #include #include #include #include namespace torch { namespace data { /// A dataloader for stateless datasets. /// /// This dataloader follows the traditional PyTorch dataloader design, whereby a /// (posssibly) stateful sampler produces *batch requests* for a stateless /// dataset, which acts as a simple batch request to batch mapping. The batch /// request will often be an array of indices, and if the dataset is a simple /// image dataset, the dataset would produce the images at those indices. template class StatelessDataLoader : public DataLoaderBase< Dataset, typename Dataset::BatchType, typename Sampler::BatchRequestType> { public: using super = DataLoaderBase< Dataset, typename Dataset::BatchType, typename Sampler::BatchRequestType>; using typename super::BatchRequestType; /// Constructs the `StatelessDataLoader` from a `dataset`, a `sampler` and /// some `options`. StatelessDataLoader( Dataset dataset, Sampler sampler, DataLoaderOptions options) : super(std::move(options)), sampler_(std::move(sampler)) { for (size_t w = 0; w < this->options_.workers; ++w) { // Here we copy the dataset into the worker thread closure. Each worker // has its own copy of the dataset. This means the dataset must be // trivially copiable, or else we don't expect more than one worker to // be in use. this->workers_.emplace_back( [this, dataset]() mutable { this->worker_thread(dataset); }); } if (this->options_.workers == 0) { this->main_thread_dataset_ = torch::make_unique(std::move(dataset)); } } private: /// Resets the internal state of the dataloader and the sampler. void reset() override { sampler_.reset(); // Call the base class method last because it calls `prefetch()` super::reset(); } /// Queries the sampler for the next batch request (possibly progressing its /// internal state). optional get_batch_request() override { auto indices = sampler_.next(this->options_.batch_size); if (!indices || (indices->size() < this->options_.batch_size && this->options_.drop_last)) { return nullopt; } AT_ASSERT(indices->size() > 0); return indices; } /// The `Sampler` used to produce batch requests. Sampler sampler_; }; } // namespace data } // namespace torch