#pragma once // This header provides C++ wrappers around commonly used CUDA API functions. // The benefit of using C++ here is that we can raise an exception in the // event of an error, rather than explicitly pass around error codes. This // leads to more natural APIs. // // The naming convention used here matches the naming convention of torch.cuda #include #include #include #include namespace c10 { namespace cuda { inline DeviceIndex device_count() noexcept { int count; // NB: In the past, we were inconsistent about whether or not this reported // an error if there were driver problems are not. Based on experience // interacting with users, it seems that people basically ~never want this // function to fail; it should just return zero if things are not working. // Oblige them. cudaError_t err = cudaGetDeviceCount(&count); if (err != cudaSuccess) { // Clear out the error state, so we don't spuriously trigger someone else. // (This shouldn't really matter, since we won't be running very much CUDA // code in this regime.) cudaError_t last_err = cudaGetLastError(); (void)last_err; return 0; } return static_cast(count); } inline DeviceIndex current_device() { int cur_device; C10_CUDA_CHECK(cudaGetDevice(&cur_device)); return static_cast(cur_device); } inline void set_device(DeviceIndex device) { C10_CUDA_CHECK(cudaSetDevice(static_cast(device))); } }} // namespace c10::cuda