#pragma once #include #include #include #include namespace torch { namespace serialize { class OutputArchive; class InputArchive; } // namespace serialize } // namespace torch namespace torch { namespace data { namespace samplers { /// A `Sampler` that selects a subset of indices to sample from and defines a /// sampling behavior. In a distributed setting, this selects a subset of the /// indices depending on the provided num_replicas and rank parameters. The /// `Sampler` performs a rounding operation based on the `allow_duplicates` /// parameter to decide the local sample count. template > class DistributedSampler : public Sampler { public: DistributedSampler( size_t size, size_t num_replicas = 1, size_t rank = 0, bool allow_duplicates = true) : size_(size), num_replicas_(num_replicas), rank_(rank), epoch_(0), allow_duplicates_(allow_duplicates) {} /// Set the epoch for the current enumeration. This can be used to alter the /// sample selection and shuffling behavior. void set_epoch(size_t epoch) { epoch_ = epoch; } size_t epoch() const { return epoch_; } protected: size_t local_sample_count() { if (allow_duplicates_) { return (size_ + num_replicas_ - 1) / num_replicas_; } else { return size_ / num_replicas_; } } size_t size_; size_t num_replicas_; size_t rank_; size_t epoch_; bool allow_duplicates_; }; /// Select samples randomly. The sampling order is shuffled at each `reset()` /// call. class TORCH_API DistributedRandomSampler : public DistributedSampler<> { public: DistributedRandomSampler( size_t size, size_t num_replicas = 1, size_t rank = 0, bool allow_duplicates = true); /// Resets the `DistributedRandomSampler` to a new set of indices. void reset(optional new_size = nullopt) override; /// Returns the next batch of indices. optional> next(size_t batch_size) override; /// Serializes the `DistributedRandomSampler` to the `archive`. void save(serialize::OutputArchive& archive) const override; /// Deserializes the `DistributedRandomSampler` from the `archive`. void load(serialize::InputArchive& archive) override; /// Returns the current index of the `DistributedRandomSampler`. size_t index() const noexcept; private: void populate_indices(); size_t begin_index_; size_t end_index_; size_t sample_index_; std::vector all_indices_; }; /// Select samples sequentially. class TORCH_API DistributedSequentialSampler : public DistributedSampler<> { public: DistributedSequentialSampler( size_t size, size_t num_replicas = 1, size_t rank = 0, bool allow_duplicates = true); /// Resets the `DistributedSequentialSampler` to a new set of indices. void reset(optional new_size = nullopt) override; /// Returns the next batch of indices. optional> next(size_t batch_size) override; /// Serializes the `DistributedSequentialSampler` to the `archive`. void save(serialize::OutputArchive& archive) const override; /// Deserializes the `DistributedSequentialSampler` from the `archive`. void load(serialize::InputArchive& archive) override; /// Returns the current index of the `DistributedSequentialSampler`. size_t index() const noexcept; private: void populate_indices(); size_t begin_index_; size_t end_index_; size_t sample_index_; std::vector all_indices_; }; } // namespace samplers } // namespace data } // namespace torch