#pragma once #include #include #include #include #include #include #include namespace c10 { namespace cuda { namespace impl { struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { static constexpr DeviceType static_type = DeviceType::CUDA; CUDAGuardImpl() {} explicit CUDAGuardImpl(DeviceType t) { TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA); } DeviceType type() const override { return DeviceType::CUDA; } Device exchangeDevice(Device d) const override { TORCH_INTERNAL_ASSERT(d.type() == DeviceType::CUDA); Device old_device = getDevice(); if (old_device.index() != d.index()) { C10_CUDA_CHECK(cudaSetDevice(d.index())); } return old_device; } Device getDevice() const override { int device; C10_CUDA_CHECK(cudaGetDevice(&device)); return Device(DeviceType::CUDA, device); } void setDevice(Device d) const override { TORCH_INTERNAL_ASSERT(d.type() == DeviceType::CUDA); C10_CUDA_CHECK(cudaSetDevice(d.index())); } void uncheckedSetDevice(Device d) const noexcept override { C10_CUDA_CHECK_WARN(cudaSetDevice(d.index())); } Stream getStream(Device d) const noexcept override { return getCurrentCUDAStream(d.index()).unwrap(); } Stream getDefaultStream(Device d) const override { return getDefaultCUDAStream(d.index()); } // NB: These do NOT set the current device Stream exchangeStream(Stream s) const noexcept override { CUDAStream cs(s); auto old_stream = getCurrentCUDAStream(s.device().index()); setCurrentCUDAStream(cs); return old_stream.unwrap(); } DeviceIndex deviceCount() const noexcept override { return device_count(); } // Event-related functions void createEvent( cudaEvent_t* cuda_event, const EventFlag flag) const { // Maps PyTorch's Event::Flag to CUDA flag auto cuda_flag = cudaEventDefault; switch (flag) { case EventFlag::PYTORCH_DEFAULT: case EventFlag::CUDA_EVENT_DISABLE_TIMING: cuda_flag = cudaEventDisableTiming; break; case EventFlag::BACKEND_DEFAULT: case EventFlag::CUDA_EVENT_DEFAULT: cuda_flag = cudaEventDefault; break; default: TORCH_CHECK(false, "CUDA event received unknown flag"); } C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag)); } void destroyEvent( void* event, const DeviceIndex device_index) const noexcept override { if (!event) return; auto cuda_event = static_cast(event); int orig_device; C10_CUDA_CHECK_WARN(cudaGetDevice(&orig_device)); C10_CUDA_CHECK_WARN(cudaSetDevice(device_index)); C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event)); C10_CUDA_CHECK_WARN(cudaSetDevice(orig_device)); } void record( void** event, const Stream& stream, const DeviceIndex device_index, const EventFlag flag) const override { TORCH_CHECK(device_index == -1 || device_index == stream.device_index(), "Event device index ", device_index, " does not match recording stream's device index ", stream.device_index(), "."); cudaEvent_t cuda_event = static_cast(*event); CUDAStream cuda_stream{stream}; // Moves to stream's device to record const auto orig_device = getDevice(); setDevice(stream.device()); // Creates the event (lazily) if (!cuda_event) createEvent(&cuda_event, flag); C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream)); // Makes the void* point to the (possibly just allocated) CUDA event *event = cuda_event; // Resets device setDevice(orig_device); } void block( void* event, const Stream& stream) const override { if (!event) return; cudaEvent_t cuda_event = static_cast(event); CUDAStream cuda_stream{stream}; const auto orig_device = getDevice(); setDevice(stream.device()); C10_CUDA_CHECK(cudaStreamWaitEvent( cuda_stream, cuda_event, /*flags (must be zero)=*/ 0)); setDevice(orig_device); } // May be called from any device bool queryEvent(void* event) const override { if (!event) return true; cudaEvent_t cuda_event = static_cast(event); const cudaError_t err = cudaEventQuery(cuda_event); if (err != cudaErrorNotReady) { C10_CUDA_CHECK(err); } return (err == cudaSuccess); } }; }}} // namespace c10::cuda::impl