#pragma once #include #include #include namespace torch { namespace data { namespace datasets { /// A dataset that wraps another dataset in a shared pointer and implements the /// `BatchDataset` API, delegating all calls to the shared instance. This is /// useful when you want all worker threads in the dataloader to access the same /// dataset instance. The dataset must take care of synchronization and /// thread-safe access itself. /// /// Use `torch::data::datasets::make_shared_dataset()` to create a new /// `SharedBatchDataset` like you would a `std::shared_ptr`. template class SharedBatchDataset : public BatchDataset< SharedBatchDataset, typename UnderlyingDataset::BatchType, typename UnderlyingDataset::BatchRequestType> { public: using BatchType = typename UnderlyingDataset::BatchType; using BatchRequestType = typename UnderlyingDataset::BatchRequestType; /// Constructs a new `SharedBatchDataset` from a `shared_ptr` to the /// `UnderlyingDataset`. /* implicit */ SharedBatchDataset( std::shared_ptr shared_dataset) : dataset_(std::move(shared_dataset)) {} /// Calls `get_batch` on the underlying dataset. BatchType get_batch(BatchRequestType request) override { return dataset_->get_batch(std::move(request)); } /// Returns the `size` from the underlying dataset. optional size() const override { return dataset_->size(); } /// Accesses the underlying dataset. UnderlyingDataset& operator*() { return *dataset_; } /// Accesses the underlying dataset. const UnderlyingDataset& operator*() const { return *dataset_; } /// Accesses the underlying dataset. UnderlyingDataset* operator->() { return dataset_.get(); } /// Accesses the underlying dataset. const UnderlyingDataset* operator->() const { return dataset_.get(); } /// Calls `reset()` on the underlying dataset. void reset() { dataset_->reset(); } private: std::shared_ptr dataset_; }; /// Constructs a new `SharedBatchDataset` by creating a /// `shared_ptr`. All arguments are forwarded to /// `make_shared`. template SharedBatchDataset make_shared_dataset(Args&&... args) { return std::make_shared(std::forward(args)...); } } // namespace datasets } // namespace data } // namespace torch