#pragma once
|
|
#include <c10/util/ArrayRef.h>
|
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAStream.h>
|
#include <ATen/cuda/CUDAContext.h>
|
|
#include <vector>
|
|
namespace at { namespace cuda {
|
|
// TODO: Implement this generically in c10. You'll need some way to get
|
// the number of GPUs from the GuardImpl, in that case.
|
class CUDAMultiStreamGuard final {
|
public:
|
/// Calls `set_stream` on each of the streams in the list.
|
/// This may be useful if you need to set different streams
|
/// for different devices.
|
explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams) : CUDAMultiStreamGuard() {
|
for (const auto& s : streams) {
|
setCurrentCUDAStream(s);
|
}
|
}
|
|
CUDAMultiStreamGuard() {
|
const size_t device_count = getNumGPUs();
|
original_streams_.reserve(device_count);
|
for (size_t device = 0; device < device_count; ++device) {
|
original_streams_.push_back(getCurrentCUDAStream(device));
|
}
|
}
|
|
CUDAMultiStreamGuard(const CUDAGuard&) = delete;
|
CUDAMultiStreamGuard& operator=(const CUDAGuard&) = delete;
|
|
// See Note [Move construction for RAII guards is tricky]
|
CUDAMultiStreamGuard(CUDAGuard&& other) = delete;
|
|
// See Note [Move assignment for RAII guards is tricky]
|
CUDAMultiStreamGuard& operator=(CUDAGuard&& other) = delete;
|
|
ArrayRef<CUDAStream> original_streams() const {
|
return original_streams_;
|
}
|
|
/// Resets the CUDA stream on each device to the one that was active upon
|
/// construction.
|
~CUDAMultiStreamGuard() {
|
for (const auto& s : original_streams_) {
|
setCurrentCUDAStream(s);
|
}
|
}
|
|
private:
|
/// The original streams that were active on all devices.
|
std::vector<CUDAStream> original_streams_;
|
};
|
|
}} // namespace at::cuda
|