#pragma once #include #include #include #include #include #include #include #include namespace torch { namespace autograd { //TODO: change it to TORCH_API when we merge the libs struct AT_CUDA_API Scatter : public Node { explicit Scatter( std::vector devices, const c10::optional>& chunk_sizes = c10::nullopt, int64_t dim = 0, const c10::optional>>& streams = c10::nullopt, bool unsqueeze_scalars = false); ~Scatter() override; variable_list apply(variable_list&& inputs) override; std::vector devices_; c10::optional> chunk_sizes_; int64_t dim_; c10::optional>> streams_; bool unsqueeze_scalars_; }; struct AT_CUDA_API Gather : public Node { explicit Gather(const at::Device& destination_device, int64_t dim = 0); ~Gather() override; variable_list apply(variable_list&& inputs) override; at::Device destination_device_; int64_t dim_; }; } // namespace autograd } // namespace torch