#pragma once

#include <torch/arg.h>
#include <torch/data/datasets/stateful.h>
#include <torch/data/samplers.h>
#include <torch/data/worker_exception.h>
#include <torch/serialize.h>

#include <queue>
#include <thread>

namespace torch {
namespace data {
namespace datasets {

/// Interface for chunk reader, which performs data chunking and reading of
/// entire chunks.
///
/// A chunk could be an entire file, such as an audio data file or an image,
/// or part of a file in the case of a large text-file split based on seek
/// positions.
template <
    typename ExampleType_,
    typename ChunkType_ = std::vector<ExampleType_>>
class ChunkDataReader {
 public:
  virtual ~ChunkDataReader() = default;

  using ChunkType = ChunkType_;
  using ExampleType = ExampleType_;

  /// Read an entire chunk.
  virtual ChunkType read_chunk(size_t chunk_index) = 0;

  /// Returns the number of chunks available in this reader.
  virtual size_t chunk_count() = 0;

  /// This will clear any internal state associated with this reader.
  virtual void reset() = 0;
};
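
// A minimal sketch of a custom reader (illustrative only; this
// `DummyChunkDataReader` and its hard-coded chunk sizes are hypothetical, and
// a real reader would typically deserialize a file or a file region). Note
// that `ChunkDataset` below additionally expects the reader to expose a
// `BatchType` alias:
//
//   class DummyChunkDataReader : public ChunkDataReader<int> {
//    public:
//     // ChunkDataset refers to `typename ChunkReader::BatchType`.
//     using BatchType = ChunkType;
//
//     // Each chunk is a std::vector<int> filled with the chunk's index.
//     ChunkType read_chunk(size_t chunk_index) override {
//       return ChunkType(chunk_sizes_[chunk_index], static_cast<int>(chunk_index));
//     }
//
//     size_t chunk_count() override { return chunk_sizes_.size(); }
//
//     void reset() override {} // no internal state to clear
//
//    private:
//     std::vector<size_t> chunk_sizes_ = {10, 20, 30};
//   };
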
namespace detail {

/// BatchDataBuffer manages a queue of UnwrappedBatchData. After a new chunk
/// is loaded, BatchDataBuffer splits it into small batches and pushes them
/// into the queue. When get_batch is called from the data loader, it pops
/// cached batches and returns them. If the cache is empty, it either waits
/// for more chunks to load or returns null if all chunks are loaded.
template <
    typename UnwrappedBatch,
    typename ExampleSampler = samplers::RandomSampler>
class BatchDataBuffer {
 public:
  using UnwrappedBatchType = UnwrappedBatch;
  using BatchType = torch::optional<UnwrappedBatchType>;
  using BatchRequestType = typename ExampleSampler::BatchRequestType;

  BatchDataBuffer(
      size_t batch_size,
      ExampleSampler& example_sampler,
      size_t queue_capacity)
      : batch_size_(batch_size),
        example_sampler_(example_sampler),
        queue_capacity_(queue_capacity) {}

  /// Return batch data from the queue. Called from the ChunkDataset main
  /// thread.
  BatchType get_batch() {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_read_.wait(lock, [this] {
      // wait till there is available data in the queue or if all chunks are
      // loaded (i.e. the dataset is exhausted for this epoch)
      return (
          this->total_example_count_in_queue_ >= batch_size_ || this->stop_);
    });
    if (batch_queue_.empty()) {
      AT_ASSERT(stop_);
      // All batches have been retrieved. Return an empty batch.
      return nullopt;
    }

    UnwrappedBatchData batch = std::move(batch_queue_.front());
    batch_queue_.pop();
    if (batch.exception) {
      throw WorkerException(batch.exception);
    }

    total_example_count_in_queue_ -= batch.batch_data.size();
    lock.unlock();
    cv_write_.notify_all();

    return batch.batch_data;
  }

  /// Push preloaded chunks to the batch queue. Called from the ChunkDataset
  /// worker threads.
  void add_chunk_data(UnwrappedBatchType data) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      // stop loading if we have preloaded enough data.
      return this->total_example_count_in_queue_ < this->queue_capacity_ ||
          this->stop_;
    });
    if (stop_) {
      // When stop_ is true, it means no further chunk loading is necessary.
      // Return without any further processing.
      return;
    }

    auto data_size = data.size();
    auto remaining_size = data_size;
    example_sampler_.reset(data_size);

    auto fill_batch = [&](size_t example_count, UnwrappedBatchType& batch) {
      auto batch_example_indices = this->example_sampler_.next(example_count);
      AT_ASSERT(
          batch_example_indices &&
          batch_example_indices.value().size() == example_count);
      BatchRequestType& indices = batch_example_indices.value();
      for (size_t i : indices) {
        TORCH_CHECK(i < data_size, "Index out of range");
        batch.emplace_back(std::move(data[i]));
      }
      remaining_size -= example_count;
    };

    if (!batch_queue_.empty()) {
      // if the queue has existing data, and the last batch doesn't have
      // enough examples to fill a batch_size batch, add more examples to
      // this batch first.
      auto& batch = batch_queue_.back();
      size_t current_count = batch.batch_data.size();
      if (current_count < batch_size_) {
        auto example_count =
            std::min(remaining_size, batch_size_ - current_count);
        fill_batch(example_count, batch.batch_data);
      }
    }

    // If we still have data remaining after filling the last pushed batch,
    // add them to the queue too.
    while (remaining_size > 0) {
      UnwrappedBatchType current_batch;

      // Allocate the batch memory ahead of time.
      current_batch.reserve(batch_size_);

      auto example_count = std::min(remaining_size, batch_size_);
      fill_batch(example_count, current_batch);
      batch_queue_.emplace(std::move(current_batch));
    }
    total_example_count_in_queue_ += data_size;
    lock.unlock();
    cv_read_.notify_all();
  }

  /// Push exceptions thrown during preloading into the batch queue. Called
  /// from the ChunkDataset worker threads.
  void add_chunk_data(std::exception_ptr e_ptr) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      // stop loading if we have preloaded enough data.
      return (
          this->total_example_count_in_queue_ < this->queue_capacity_ ||
          this->stop_);
    });
    if (stop_) {
      // When stop_ is true, the current thread needs to be torn down and the
      // batch buffer will be discarded, so there is no need to enqueue any
      // new exceptions.
      return;
    }
    batch_queue_.emplace(e_ptr);
    lock.unlock();
    cv_read_.notify_all();
  }

  void stop() {
    {
      // Hold the lock before changing stop_ to prevent a race condition that
      // can cause a deadlock.
      // To be more specific, the condition variable cv_write_ waits on the
      // predicate stop_ in add_chunk_data(). The wait happens in two steps:
      // 1) while still holding the lock, check if the predicate is true;
      // 2) if it is, proceed; otherwise, release the lock and wait until
      // notified. Without holding the lock, cv_write_'s notification can
      // happen in between steps 1) and 2). In that case, as cv_write_ is not
      // in waiting status yet, the notification is lost and cv_write_ will
      // sleep forever.
      // By taking the lock before changing the predicate stop_, it is
      // ensured that updating and evaluating stop_ always happen in a
      // synchronized way.
      std::lock_guard<std::mutex> lock(queue_mutex_);
      stop_ = true;
    }

    // notify all writers, wake them from wait to exit the current method.
    cv_write_.notify_all();
    // notify all readers too.
    cv_read_.notify_all();
  }

  /// The batch size is needed to create batches from the chunk data. Similar
  /// to a regular dataloader, where batches are created with prefetches,
  /// BatchDataBuffer performs the batch creation using the provided batch
  /// size.
  size_t batch_size_ = 0;

  /// count of total examples stored in the queue
  size_t total_example_count_in_queue_ = 0;

  /// struct that contains a raw unwrapped batch unit. An unwrapped batch
  /// unit is the raw data without the 'optional' wrapper. It can be a
  /// collection of images, utterances, etc.
  struct UnwrappedBatchData {
    explicit UnwrappedBatchData(UnwrappedBatchType data)
        : batch_data(std::move(data)) {}

    explicit UnwrappedBatchData(std::exception_ptr e) : exception(e) {}

    /// batch data to return
    UnwrappedBatchType batch_data;

    /// exception pointer which captures any abnormal exceptions while
    /// creating the batch.
    std::exception_ptr exception;
  };

  /// local cache to store example batches from loaded chunks
  std::queue<UnwrappedBatchData> batch_queue_;

  // sync batch_queue_ update.
  std::mutex queue_mutex_;

  std::condition_variable cv_read_;
  std::condition_variable cv_write_;

  ExampleSampler& example_sampler_;

  // configurable maximum number of elements the queue can hold at one time.
  size_t queue_capacity_;

  // When set to true, it wakes the writer threads from the wait and exits
  // the current function call. This is needed when ChunkDataset::reset() is
  // called while the previous epoch is not exhausted yet. When ChunkDataset
  // is waiting for its preloaders to finish previous work before tearing
  // down the threads, a preloader could still be waiting on the condition
  // variable, thus causing the program to hang. This boolean is used to
  // break this waiting condition.
  bool stop_ = false;
};
} // namespace detail
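
// A worked example of the batch assembly in BatchDataBuffer::add_chunk_data
// above (illustrative numbers only): with batch_size_ = 10, a chunk of 25
// examples is split into batches of sizes 10, 10 and 5. If another chunk of
// 8 examples arrives before the trailing batch of 5 is consumed,
// add_chunk_data() first tops that batch up to 10 and then queues the
// remaining 3 examples as a new partial batch.
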
/// Options to configure a `ChunkDataset`.
struct ChunkDatasetOptions {
  ChunkDatasetOptions() = delete;
  ChunkDatasetOptions(
      size_t preloader_count,
      size_t batch_size,
      size_t cache_size = 2048,
      size_t cross_chunk_shuffle_count = 1)
      : preloader_count_(preloader_count),
        batch_size_(batch_size),
        cache_size_(cache_size),
        cross_chunk_shuffle_count_(cross_chunk_shuffle_count) {
    TORCH_CHECK(
        preloader_count_ > 0,
        "Preloader count is 0. At least one preloader needs to be specified.");
    TORCH_CHECK(
        batch_size_ > 0,
        "Batch size is 0. A positive batch size needs to be specified.");
    TORCH_CHECK(
        cache_size_ > 0,
        "Cache size is 0. A positive cache size needs to be specified.");
    TORCH_CHECK(
        cache_size_ >= batch_size_,
        "Cache size is less than batch size. Cache needs to be large enough "
        "to hold at least one batch.");
    TORCH_CHECK(
        cross_chunk_shuffle_count_ > 0,
        "cross_chunk_shuffle_count needs to be greater than 0.");
  }

  /// The number of worker threads to preload chunk data.
  TORCH_ARG(size_t, preloader_count);

  /// The size of each batch.
  TORCH_ARG(size_t, batch_size);

  /// The capacity of the queue for batch caching.
  TORCH_ARG(size_t, cache_size) = 2048;

  // The number of chunks to perform cross-chunk shuffling. Default to 1,
  // meaning no cross-chunk shuffling. When it is equal to n (n > 1), n
  // random chunks will be loaded at once and example shuffling will be
  // performed across all those n chunks.
  // Note: Usually the default config (1 chunk shuffle + example shuffle) is
  // good enough to generate randomly distributed data. Use this parameter
  // only if you know cross-shuffle is needed in your case. Also, there is a
  // performance penalty when this value is greater than 1, as we need to do
  // an extra merge between multiple chunks before performing example
  // sampling.
  TORCH_ARG(size_t, cross_chunk_shuffle_count) = 1;
};
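
// For instance (illustrative values only), the following configures two
// preloader threads, batches of 100 examples, a cache of up to 500 examples,
// and shuffling across two chunks at a time:
//
//   ChunkDatasetOptions options(
//       /*preloader_count=*/2,
//       /*batch_size=*/100,
//       /*cache_size=*/500,
//       /*cross_chunk_shuffle_count=*/2);
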
/// A stateful dataset that supports hierarchical sampling and prefetching of
/// entire chunks.
///
/// Unlike a regular dataset, a chunk dataset requires two samplers to
/// operate and keeps an internal state. `ChunkSampler` selects which chunk
/// to load next, while the `ExampleSampler` determines the order of Examples
/// that are returned in each `get_batch` call. The hierarchical sampling
/// approach used here is inspired by this paper:
/// http://martin.zinkevich.org/publications/nips2010.pdf
template <
    typename ChunkReader,
    typename ChunkSampler = samplers::RandomSampler,
    typename ExampleSampler = samplers::RandomSampler>
class ChunkDataset final
    : public StatefulDataset<
          ChunkDataset<ChunkReader, ChunkSampler, ExampleSampler>,
          typename ChunkReader::BatchType,
          size_t> {
 public:
  using BatchType = torch::optional<typename ChunkReader::BatchType>;
  using UnwrappedBatchType = typename ChunkReader::BatchType;
  using BatchRequestType = size_t;
  using ChunkSamplerType = ChunkSampler;
  using ExampleSamplerType = ExampleSampler;

  ChunkDataset(
      ChunkReader chunk_reader,
      ChunkSampler chunk_sampler,
      ExampleSampler example_sampler,
      ChunkDatasetOptions options,
      std::function<void(UnwrappedBatchType&)> preprocessing_policy =
          std::function<void(UnwrappedBatchType&)>())
      : chunk_reader_(std::move(chunk_reader)),
        chunk_sampler_(std::move(chunk_sampler)),
        example_sampler_(std::move(example_sampler)),
        options_(std::move(options)),
        preprocessing_policy_(preprocessing_policy),
        quit_worker_(false),
        running_preloaders_(0),
        load_checkpoint_(false) {}

  virtual ~ChunkDataset() {
    // stop batch buffer first.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    free_workers();
  }

  /// Default get_batch method of BatchDataset. This method returns Example
  /// batches created from the preloaded chunks. The implementation is
  /// dataset agnostic and does not need overriding in different chunk
  /// datasets.
  BatchType get_batch(size_t batch_size) override {
    TORCH_CHECK(
        batch_buffer_ != nullptr,
        "Dataset needs to call reset() before calling get_batch().");

    TORCH_CHECK(
        batch_size == options_.batch_size(),
        "The requested batch size does not match with the initialized batch size.\n"
        " The requested batch size is ",
        batch_size,
        ", while the dataset is created with batch size equal to ",
        options_.batch_size());
    return batch_buffer_->get_batch();
  }

  /// Helper method around get_batch as `batch_size` is not strictly
  /// necessary.
  BatchType get_batch() {
    return get_batch(options_.batch_size());
  }

  /// This will clear any internal state and start the internal prefetching
  /// mechanism for the chunk dataset.
  void reset() override {
    // We need this to support partial data reads via dataloader iterator.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    // free workers from previous reset if there are any.
    free_workers();
    preload_threads_.clear();

    if (!load_checkpoint_) {
      chunk_reader_.reset();
      chunk_sampler_.reset(chunk_reader_.chunk_count());
      load_checkpoint_ = false;
    }

    // Throw out any existing cached batch in the buffer and re-create a new
    // chunk buffer.
    batch_buffer_ = torch::make_unique<
        detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>(
        options_.batch_size(), example_sampler_, options_.cache_size());

    // create new workers for this new epoch.
    quit_worker_ = false;

    AT_ASSERT(running_preloaders_ == 0);
    running_preloaders_ = options_.preloader_count();
    for (size_t i = 0; i < options_.preloader_count(); ++i) {
      preload_threads_.emplace_back([this, i]() { this->preloader(i); });
    }
  }

  /// size is not used for chunk dataset.
  optional<size_t> size() const override {
    return torch::nullopt;
  }

  // provide a reference to the chunk sampler. Used mainly in distributed
  // data loading to set the epoch number for the sampler.
  ChunkSamplerType& chunk_sampler() {
    return chunk_sampler_;
  }

  void save(serialize::OutputArchive& archive) const override {
    std::lock_guard<std::mutex> lock(chunk_index_guard_);
    chunk_sampler_.save(archive);
  }

  void load(serialize::InputArchive& archive) override {
    std::lock_guard<std::mutex> lock(chunk_index_guard_);
    chunk_sampler_.load(archive);
    load_checkpoint_ = true;
  }

 private:
  /// running on worker thread to preload chunk data.
  void preloader(size_t id) {
    while (!quit_worker_.load()) {
      try {
        std::vector<size_t> chunk_idx;
        {
          std::lock_guard<std::mutex> lock(chunk_index_guard_);
          if (auto chunk_sampler_result = chunk_sampler_.next(
                  this->options_.cross_chunk_shuffle_count())) {
            chunk_idx = chunk_sampler_result.value();
          } else {
            break;
          }
        }
        UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_idx[0]);
        for (size_t i = 1; i < chunk_idx.size(); ++i) {
          auto chunk_data = chunk_reader_.read_chunk(chunk_idx[i]);
          std::move(
              chunk_data.begin(), chunk_data.end(), std::back_inserter(data));
        }
        if (preprocessing_policy_) {
          preprocessing_policy_(data);
        }
        if (!data.empty()) { // skip empty chunks.
          batch_buffer_->add_chunk_data(std::move(data));
        }
      } catch (...) {
        batch_buffer_->add_chunk_data(std::current_exception());
      }
    }
    AT_ASSERT(running_preloaders_.load() > 0);
    --running_preloaders_;
    if (running_preloaders_.load() == 0) {
      // all preloaders are completed, so we can notify the batch_buffer.
      batch_buffer_->stop();
    }
  }

  /// Block the current thread until the workers finish execution and exit.
  void free_workers() {
    if (!quit_worker_.load()) {
      quit_worker_ = true;
      for (auto& worker_thread : preload_threads_) {
        worker_thread.join();
      }
    }
  }

 private:
  // Templated class that defines what a chunk is and how to read chunk data.
  // When a chunk is returned by chunk_reader_, ChunkDataset splits it into
  // batches and caches them in batch_buffer_.
  ChunkReader chunk_reader_;

  // chunk sampler to shuffle different chunks
  ChunkSamplerType chunk_sampler_;

  // example sampler to shuffle examples in a specific chunk
  ExampleSamplerType example_sampler_;

  // batch data buffer which holds chunk data from the preloading threads.
  std::shared_ptr<
      detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>
      batch_buffer_;

  // worker thread pool
  std::vector<std::thread> preload_threads_;

  /// The options the Dataset was configured with.
  const ChunkDatasetOptions options_;

  // function pointer wrapper to apply custom processing over chunk data.
  // This is considered an advanced parameter for developers who want to
  // apply a pre-process to the chunk data before sampling into minibatches.
  // Different from the collate function, this policy is applied on the chunk
  // level, instead of the minibatch level. When a chunk of data is loaded
  // (multiple chunks if cross_chunk_shuffle_count_ is greater than 1), this
  // policy is applied to the full loaded data. It is useful if developers
  // want to perform pre-processing (like bucketing) on the chunk data before
  // the example sampler samples the data. By default it's an empty pointer
  // and no action will be taken.
  std::function<void(UnwrappedBatchType&)> preprocessing_policy_;

  // indicate whether the worker threads can be torn down
  std::atomic<bool> quit_worker_;

  // keep track of running preloaders to notify the batch buffer. A value of
  // 0 indicates that the chunk loading is completed.
  std::atomic<size_t> running_preloaders_;

  // mutex to synchronize chunk sampler next() calls.
  mutable std::mutex chunk_index_guard_;

  // boolean value to indicate whether we need to load the checkpoint for
  // chunk_sampler_.
  bool load_checkpoint_;
};

} // namespace datasets
} // namespace data
} // namespace torch
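
// End-to-end usage sketch (illustrative only, assuming `using namespace
// torch::data;` and a hypothetical reader like the `DummyChunkDataReader`
// sketched above). The dataset is wrapped via `datasets::make_shared_dataset`
// so the data loader can share its state:
//
//   DummyChunkDataReader reader;
//
//   // Optional chunk-level preprocessing policy; here a hypothetical lambda
//   // that sorts each loaded chunk before example sampling.
//   auto policy = [](std::vector<int>& chunk) {
//     std::sort(chunk.begin(), chunk.end());
//   };
//
//   auto dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
//       DummyChunkDataReader,
//       samplers::SequentialSampler,
//       samplers::SequentialSampler>>(
//       reader,
//       samplers::SequentialSampler(0),
//       samplers::SequentialSampler(0),
//       datasets::ChunkDatasetOptions(
//           /*preloader_count=*/1, /*batch_size=*/10),
//       policy);
//
//   auto data_loader = torch::data::make_data_loader(
//       dataset, DataLoaderOptions(/*batch_size=*/10).workers(0));
//
//   // Each `batch` is an unwrapped ChunkType (std::vector<int> here) with
//   // up to `batch_size` examples; iteration ends once all chunks are
//   // consumed for the epoch.
//   for (auto& batch : *data_loader) {
//     // ... consume batch ...
//   }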