#pragma once #include #include #include #include #include namespace torch { namespace data { namespace transforms { template > struct Stack; /// A `Collation` for `Example` types that stacks all data /// tensors into one tensor, and all target (label) tensors into one tensor. template <> struct Stack> : public Collation> { Example<> apply_batch(std::vector> examples) override { std::vector data, targets; data.reserve(examples.size()); targets.reserve(examples.size()); for (auto& example : examples) { data.push_back(std::move(example.data)); targets.push_back(std::move(example.target)); } return {torch::stack(data), torch::stack(targets)}; } }; /// A `Collation` for `Example` types that stacks all data /// tensors into one tensor. template <> struct Stack : public Collation> { TensorExample apply_batch(std::vector examples) override { std::vector data; data.reserve(examples.size()); for (auto& example : examples) { data.push_back(std::move(example.data)); } return torch::stack(data); } }; } // namespace transforms } // namespace data } // namespace torch