// CUDA tensor communication declarations for namespace torch::cuda.
//
// Everything declared here is marked TORCH_API:
//   tensor_list2d        - alias for a std::vector-based container type.
//   broadcast            - takes one tensor and a list of device indices
//                          (at::IntArrayRef); returns a std::vector.
//   broadcast_coalesced  - takes a tensor list, a list of device indices and a
//                          size_t buffer_size; returns a tensor_list2d.
//   scatter              - takes one tensor and a list of device indices, plus
//                          optional chunk_sizes (default c10::nullopt), a
//                          dimension dim (default 0) and optional streams
//                          (default c10::nullopt); returns a std::vector.
//   gather               - takes a tensor list, a dimension dim and an optional
//                          destination_index; returns a single at::Tensor.
//
// NOTE(review): this text appears to have lost everything that was written
// between angle brackets. All seven #include directives have no target, and
// the template argument lists are missing (e.g. `std::vector>`,
// `c10::optional>>&`, `c10::optional destination_index`). Presumably the
// vectors hold at::Tensor and tensor_list2d is a vector of such vectors, but
// that cannot be confirmed from this text. TODO: restore the include targets
// and template arguments from the canonical header.
//
// NOTE(review): the whole header sits on one physical line after
// `#pragma once`. A preprocessor directive extends to the end of its line, so
// as written none of the declarations below are seen as ordinary code. The
// line breaks need to be restored together with the stripped text.
//
// NOTE(review): the semantics of each function (which device receives what,
// how chunk_sizes and streams are used, what an empty destination_index means)
// are defined by the implementation, which is not visible here. Verify against
// the corresponding .cpp before relying on them.
#pragma once #include #include #include #include #include #include #include namespace torch { namespace cuda { using tensor_list2d = std::vector>; TORCH_API std::vector broadcast(const at::Tensor& tensor, at::IntArrayRef devices); TORCH_API tensor_list2d broadcast_coalesced(at::TensorList tensors, at::IntArrayRef devices, size_t buffer_size); TORCH_API std::vector scatter( const at::Tensor& tensor, at::IntArrayRef devices, const c10::optional>& chunk_sizes = c10::nullopt, int64_t dim = 0, const c10::optional>>& streams = c10::nullopt); TORCH_API at::Tensor gather( at::TensorList tensors, int64_t dim, c10::optional destination_index); }}