#pragma once #include #include #include #include #include #include namespace torch { namespace data { namespace datasets { namespace detail { template using optional_if_t = typename std::conditional, T>::type; } // namespace detail /// A `MapDataset` is a dataset that applies a transform to a source dataset. template class MapDataset : public BatchDataset< MapDataset, detail::optional_if_t< SourceDataset::is_stateful, typename AppliedTransform::OutputBatchType>, typename SourceDataset::BatchRequestType> { public: using DatasetType = SourceDataset; using TransformType = AppliedTransform; using BatchRequestType = typename SourceDataset::BatchRequestType; using OutputBatchType = detail::optional_if_t< SourceDataset::is_stateful, typename AppliedTransform::OutputBatchType>; MapDataset(DatasetType dataset, TransformType transform) : dataset_(std::move(dataset)), transform_(std::move(transform)) {} /// Gets a batch from the source dataset and applies the transform to it, /// returning the result. OutputBatchType get_batch(BatchRequestType indices) override { return get_batch_impl(std::move(indices)); } /// Returns the size of the source dataset. optional size() const noexcept override { return dataset_.size(); } /// Calls `reset()` on the underlying dataset. /// NOTE: Stateless datasets do not have a reset() method, so a call to this /// method will only compile for stateful datasets (which have a reset() /// method). void reset() { dataset_.reset(); } /// Returns the underlying dataset. const SourceDataset& dataset() noexcept { return dataset_; } /// Returns the transform being applied. const AppliedTransform& transform() noexcept { return transform_; } private: /// The implementation of `get_batch()` for the stateless case, which simply /// applies the transform to the output of `get_batch()` from the dataset. template < typename D = SourceDataset, typename = torch::disable_if_t> OutputBatchType get_batch_impl(BatchRequestType indices) { return transform_.apply_batch(dataset_.get_batch(std::move(indices))); } /// The implementation of `get_batch()` for the stateful case. Here, we follow /// the semantics of `Optional.map()` in many functional languages, which /// applies a transformation to the optional's content when the optional /// contains a value, and returns a new optional (of a different type) if the /// original optional returned by `get_batch()` was empty. template torch::enable_if_t get_batch_impl( BatchRequestType indices) { if (auto batch = dataset_.get_batch(std::move(indices))) { return transform_.apply_batch(std::move(*batch)); } return nullopt; } /// The underlying dataset being transformed. SourceDataset dataset_; // The transformation that is applied to batches received from the dataset. AppliedTransform transform_; }; /// Creates a `MapDataset` with the given dataset and transform. template MapDataset map( DatasetType dataset, TransformType transform) { static_assert( std::is_same< typename std::conditional< DatasetType::is_stateful, typename DatasetType::BatchType::value_type, typename DatasetType::BatchType>::type, typename TransformType::InputBatchType>::value, "BatchType type of dataset does not match input type of transform"); return {std::move(dataset), std::move(transform)}; } } // namespace datasets } // namespace data } // namespace torch