#pragma once #include #include #include #include #include namespace torch { namespace data { namespace datasets { /// A dataset of tensors. /// Stores a single tensor internally, which is then indexed inside `get()`. struct TensorDataset : public Dataset { /// Creates a `TensorDataset` from a vector of tensors. explicit TensorDataset(const std::vector& tensors) : TensorDataset(torch::stack(tensors)) {} explicit TensorDataset(torch::Tensor tensor) : tensor(std::move(tensor)) {} /// Returns a single `TensorExample`. TensorExample get(size_t index) override { return tensor[index]; } /// Returns the number of tensors in the dataset. optional size() const override { return tensor.size(0); } Tensor tensor; }; } // namespace datasets } // namespace data } // namespace torch