#pragma once #include #include #include #include #include #include #include #include namespace torch { namespace data { namespace datasets { template class MapDataset; template MapDataset map(D, T); // NOLINT } // namespace datasets } // namespace data } // namespace torch namespace torch { namespace data { namespace datasets { namespace detail { template struct is_optional : std::false_type {}; template struct is_optional> : std::true_type {}; } // namespace detail /// A dataset that can yield data only in batches. template < typename Self, typename Batch = std::vector>, typename BatchRequest = ArrayRef> class BatchDataset { public: using SelfType = Self; using BatchType = Batch; using BatchRequestType = BatchRequest; constexpr static bool is_stateful = detail::is_optional::value; virtual ~BatchDataset() = default; /// Returns a batch of data given an index. virtual Batch get_batch(BatchRequest request) = 0; /// Returns the size of the dataset, or an empty optional if it is unsized. virtual optional size() const = 0; /// Creates a `MapDataset` that applies the given `transform` to this dataset. template MapDataset map(TransformType transform) & { return datasets::map(static_cast(*this), std::move(transform)); } /// Creates a `MapDataset` that applies the given `transform` to this dataset. template MapDataset map(TransformType transform) && { return datasets::map( std::move(static_cast(*this)), std::move(transform)); } }; /// A dataset that can yield data in batches, or as individual examples. /// /// A `Dataset` is a `BatchDataset`, because it supports random access and /// therefore batched access is implemented (by default) by calling the random /// access indexing function for each index in the requested batch of indices. /// This can be customized. template > class Dataset : public BatchDataset> { public: using ExampleType = SingleExample; /// Returns the example at the given index. virtual ExampleType get(size_t index) = 0; /// Returns a batch of data. /// The default implementation calls `get()` for every requested index /// in the batch. std::vector get_batch(ArrayRef indices) override { std::vector batch; batch.reserve(indices.size()); for (const auto i : indices) { batch.push_back(get(i)); } return batch; } }; /// A `StreamDataset` reprsents a dataset that is a potentially infinite stream. /// It takes as batch index only a number, which is the batch size, and yields /// that many elements from the stream. template >> using StreamDataset = BatchDataset; } // namespace datasets } // namespace data } // namespace torch