#pragma once #include #include #include #include #include #include #include #include namespace torch { namespace data { namespace detail { // For increased safety and more separated logic, this implementation of // `Iterator` consists of a `ValidIterator` and a `SentinelIterator`. A // `ValidIterator` yields new batches until the `DataLoader` is exhausted. While // the `DataLoader` is not exhausted, `ValidIterator`s compare equal if they are // the same object. When the `ValidIterator` becomes exhauted, it compares equal // to the `SentinelIterator`, but not before. Half the code here is to implement // double dispatch for the comparison. Got damnit, C++. template struct ValidIterator; template struct SentinelIterator; /// Base class for the `ValidIterator` and `SentinelIterator` template struct IteratorImpl { virtual ~IteratorImpl() = default; virtual void next() = 0; virtual Batch& get() = 0; virtual bool operator==(const IteratorImpl& other) const = 0; virtual bool operator==(const ValidIterator& other) const = 0; virtual bool operator==(const SentinelIterator& other) const = 0; }; template struct ValidIterator : public IteratorImpl { using BatchProducer = std::function()>; explicit ValidIterator(BatchProducer next_batch) : next_batch_(std::move(next_batch)) {} /// Fetches the next batch. void next() override { // If we didn't get the very first batch yet, get it now. lazy_initialize(); TORCH_CHECK( batch_.has_value(), "Attempted to increment iterator past the end"); // Increment to the next batch. batch_ = next_batch_(); } /// Returns the current batch. The precondition for this operation to not /// throw an exception is that it has been compared to the `SentinelIterator` /// and did not compare equal. Batch& get() override { // If we didn't get the very first batch yet, get it now. lazy_initialize(); TORCH_CHECK( batch_.has_value(), "Attempted to dereference iterator that was past the end"); return batch_.value(); } /// Does double dispatch. bool operator==(const IteratorImpl& other) const override { return other == *this; } /// A `ValidIterator` is equal to the `SentinelIterator` iff. the /// `ValidIterator` has reached the end of the dataloader. bool operator==(const SentinelIterator& /* unused */) const override { lazy_initialize(); return !batch_; } /// Returns true if the memory address of `other` equals that of `this`. bool operator==(const ValidIterator& other) const override { return &other == this; } /// Gets the very first batch if it has not yet been fetched. void lazy_initialize() const { if (!initialized_) { batch_ = next_batch_(); initialized_ = true; } } BatchProducer next_batch_; mutable optional batch_; mutable bool initialized_ = false; }; template struct SentinelIterator : public IteratorImpl { void next() override { AT_ERROR( "Incrementing the DataLoader's past-the-end iterator is not allowed"); } Batch& get() override { AT_ERROR( "Dereferencing the DataLoader's past-the-end iterator is not allowed"); } /// Does double dispatch. bool operator==(const IteratorImpl& other) const override { return other == *this; } /// Calls the comparison operator between `ValidIterator` and /// `SentinelIterator`. bool operator==(const ValidIterator& other) const override { return other == *this; } /// Sentinel iterators always compare equal. bool operator==(const SentinelIterator& other) const override { return true; } }; } // namespace detail template class Iterator { public: // Type aliases to make the class recognized as a proper iterator. using difference_type = std::ptrdiff_t; using value_type = Batch; using pointer = Batch*; using reference = Batch&; using iterator_category = std::input_iterator_tag; explicit Iterator(std::unique_ptr> impl) : impl_(std::move(impl)) {} /// Increments the iterator. /// Only permitted for valid iterators (not past the end). Iterator& operator++() { impl_->next(); return *this; } /// Returns the current batch. /// Only permitted for valid iterators (not past the end). Batch& operator*() { return impl_->get(); } /// Returns a pointer to the current batch. /// Only permitted for valid iterators (not past the end). Batch* operator->() { return &impl_->get(); } /// Compares two iterators for equality. bool operator==(const Iterator& other) const { return *impl_ == *other.impl_; } /// Compares two iterators for inequality. bool operator!=(const Iterator& other) const { return !(*this == other); } private: /// Points either to a `ValidIterator` or to a `SentinelIterator`. std::shared_ptr> impl_; }; } // namespace data } // namespace torch