#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace at { class Tensor; class CAFFE2_API Context { public: Context(); Generator & defaultGenerator(Device device) { DeviceType device_type = device.type(); initCUDAIfNeeded(device_type); initHIPIfNeeded(device_type); if (device_type == at::kCPU) { return *at::detail::getDefaultCPUGenerator(); } else if (device_type == at::kCUDA) { return *at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index()); } else { AT_ERROR(DeviceTypeName(device_type), " device type not enabled."); } } Device getDeviceFromPtr(void* data, DeviceType device_type) { initCUDAIfNeeded(device_type); initHIPIfNeeded(device_type); if (device_type == at::kCPU) { return DeviceType::CPU; } else if (device_type == at::kCUDA) { return at::detail::getCUDAHooks().getDeviceFromPtr(data); } else { AT_ERROR(DeviceTypeName(device_type), " device type not enabled."); } } bool isPinnedPtr(void* data) { return detail::getCUDAHooks().isPinnedPtr(data); } bool hasOpenMP() const; bool hasMKL() const; bool hasLAPACK() const; bool hasMKLDNN() const; bool hasMAGMA() const { return detail::getCUDAHooks().hasMAGMA(); } bool hasCUDA() const { return detail::getCUDAHooks().hasCUDA(); } bool hasHIP() const { return detail::getHIPHooks().hasHIP(); } bool hasXLA() const { return c10::impl::hasDeviceGuardImpl(at::DeviceType::XLA); } // defined in header so that getNonVariableType has ability to inline // call_once check. getNonVariableType is called fairly frequently THCState* lazyInitCUDA() { std::call_once(thc_init,[&] { thc_state = detail::getCUDAHooks().initCUDA(); }); return thc_state.get(); } THHState* lazyInitHIP() { std::call_once(thh_init,[&] { thh_state = detail::getHIPHooks().initHIP(); }); return thh_state.get(); } const at::cuda::NVRTC& getNVRTC() { return detail::getCUDAHooks().nvrtc(); } THCState* getTHCState() { // AT_ASSERT(thc_state); return thc_state.get(); } THHState* getTHHState() { return thh_state.get(); } bool setFlushDenormal(bool on); // NB: This method is *purely* whether or not a user requested // that CuDNN was enabled, it doesn't actually say anything about // whether or not CuDNN is actually usable. Use cudnn_is_acceptable // to test this instead bool userEnabledCuDNN() const; void setUserEnabledCuDNN(bool e); bool userEnabledMkldnn() const; void setUserEnabledMkldnn(bool e); bool benchmarkCuDNN() const; void setBenchmarkCuDNN(bool); bool deterministicCuDNN() const; void setDeterministicCuDNN(bool); at::QEngine qEngine() const; void setQEngine(at::QEngine e); const std::vector& supportedQEngines() const; private: void initCUDAIfNeeded(DeviceType p) { if (p == DeviceType::CUDA) { lazyInitCUDA(); } } void initHIPIfNeeded(DeviceType p) { if (p == DeviceType::HIP) { lazyInitHIP(); } } std::once_flag thc_init; std::once_flag thh_init; bool enabled_cudnn = true; bool deterministic_cudnn = false; bool benchmark_cudnn = false; bool enabled_mkldnn = true; c10::optional quantized_engine = c10::nullopt; std::unique_ptr thc_state; std::unique_ptr thh_state; }; CAFFE2_API Context& globalContext(); static inline void init() { globalContext(); } CAFFE2_API Allocator* getCPUAllocator(); static inline DeprecatedTypeProperties& getNonVariableDeprecatedTypeProperties(Backend p, ScalarType s) { return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( p, s, /*is_variable*/false); } static inline DeprecatedTypeProperties& CPU(ScalarType s) { return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( Backend::CPU, s, /*is_variable*/false); } static inline DeprecatedTypeProperties& CUDA(ScalarType s) { return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( Backend::CUDA, s, /*is_variable*/false); } static inline DeprecatedTypeProperties& HIP(ScalarType s) { return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( Backend::HIP, s, /*is_variable*/false); } static inline bool hasCUDA() { return globalContext().hasCUDA(); } static inline bool hasHIP() { return globalContext().hasHIP(); } static inline bool hasXLA() { return globalContext().hasXLA(); } // Despite its name, this function returns the number of *CUDA* GPUs. static inline size_t getNumGPUs() { // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS // FUNCTION. If you are interested in interrogating the number of // devices for a specific device type, add that function to the // relevant library (e.g., similar to at::cuda::device_count()) if (hasCUDA() && hasHIP()) { throw std::runtime_error( "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades " "to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually " "means HIP. Rebuild PyTorch with one or the other disabled."); } else if (hasCUDA()) { return detail::getCUDAHooks().getNumGPUs(); } else if (hasHIP()) { return detail::getHIPHooks().getNumGPUs(); } else { return 0; } } static inline bool hasOpenMP() { return globalContext().hasOpenMP(); } static inline bool hasMKL() { return globalContext().hasMKL(); } static inline bool hasLAPACK() { return globalContext().hasLAPACK(); } static inline bool hasMAGMA() { return globalContext().hasMAGMA(); } static inline bool hasMKLDNN() { return globalContext().hasMKLDNN(); } static inline void manual_seed(uint64_t seed) { auto& gen = globalContext().defaultGenerator(DeviceType::CPU); { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen.mutex_); gen.set_current_seed(seed); } // NB: Sometimes we build with CUDA, but we don't have any GPUs // available. In that case, we must not seed CUDA; it will fail! int num_gpus = detail::getCUDAHooks().getNumGPUs(); if (hasCUDA() && num_gpus > 0) { for (int i = 0; i < num_gpus; i++) { auto& cuda_gen = globalContext().defaultGenerator(Device(at::kCUDA, i)); { // See Note [Acquire lock when using random generators] std::lock_guard lock(cuda_gen.mutex_); cuda_gen.set_current_seed(seed); } } } } } // namespace at