#pragma once
|
|
// This file provides implementations of InlineDeviceGuard and InlineOptionalDeviceGuard.
|
|
#include <c10/core/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/VirtualGuardImpl.h>
#include <c10/util/Optional.h>
#include <c10/util/C++17.h>

#include <type_traits>
|
|
namespace c10 {
|
namespace impl {
|
|
|
|
/**
|
* A DeviceGuard is an RAII class that sets a device to some value
|
* on construction, and resets the device to its original value on
|
* destruction.
|
*
|
* InlineDeviceGuard is a helper class for implementing DeviceGuards.
|
* It is templated over a DeviceGuardImpl (anything that implements
|
* DeviceGuardImplInterface). There are two primary ways to instantiate
|
* InlineDeviceGuard:
|
*
|
* - With a concrete implementation of DeviceGuardImpl, e.g., CUDAGuardImpl.
|
* This is the best way to use InlineDeviceGuard, as all calls are
|
* devirtualized, giving you code as efficient as straight line
|
* calls to cudaGetDevice/cudaSetDevice.
|
*
|
* - With VirtualGuardImpl, which does a virtual dispatch to a DeviceGuardImpl
|
* retrieved from a DeviceType registry. We have explicitly instantiated
|
* InlineDeviceGuard this way as c10::DeviceGuard.
|
*
|
* If you are in a hurry, you can use InlineDeviceGuard directly:
|
*
|
* using CUDAGuard = impl::InlineDeviceGuard<CUDAGuardImpl>;
|
*
|
* However, you can provide a better user experience if you explicitly write a
|
* wrapper class that itself contains the template instantiation:
|
*
|
* class CUDAGuard {
|
* public:
|
* // ... the API ...
|
* private:
|
* impl::InlineDeviceGuard<CUDAGuardImpl> guard_;
|
* };
|
*
|
* The wrapper class provides a good place to write documentation, and helps
|
* avoid weird template instantiation errors when a user incorrectly uses the
|
* class.
|
*
|
* If you need to test this class, consider instantiating it with FakeGuardImpl.
|
*/
|
template <typename T>
|
class InlineDeviceGuard {
|
public:
|
// Note [Omitted default constructor from RAII]
|
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
// In principle, we could add a default constructor to
|
// DeviceGuard which reads the current device and promises to
|
// restore to that device on exit. However, most cases where you
|
// would have written this, you probably meant to actually just
|
// use OptionalDeviceGuard (since you don't actually need the
|
// restore to happen if you don't ever actually set the device).
|
// We remove the constructor here to encourage you to think about
|
// what you actually want to happen.
|
explicit InlineDeviceGuard() = delete;
|
|
/// Set the current device to the passed Device.
|
explicit InlineDeviceGuard(Device device)
|
: impl_(device.type())
|
, original_device_(device.index() == -1 ? impl_.getDevice() : impl_.exchangeDevice(device))
|
, current_device_(device.index() == -1 ? original_device_ : device)
|
{}
|
|
/// Set the current device index to the passed DeviceIndex. (The
|
/// device type is inferred from the template parameter T).
|
template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
|
explicit InlineDeviceGuard(DeviceIndex device_index)
|
: InlineDeviceGuard(Device(U::static_type, device_index)) {}
|
|
/// Construct an InlineDeviceGuard using VirtualGuardImpl with an explicit
|
/// DeviceGuardImplInterface pointer.
|
template <typename U=T, typename=typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value>::type>
|
explicit InlineDeviceGuard(Device device, const DeviceGuardImplInterface* impl)
|
: impl_(VirtualGuardImpl(impl ? impl : getDeviceGuardImpl(device.type())))
|
, original_device_(device.index() == -1 ? impl_.getDevice() : impl_.exchangeDevice(device))
|
, current_device_(device.index() == -1 ? original_device_ : device)
|
{}
|
|
/// Copy is disallowed
|
InlineDeviceGuard(const InlineDeviceGuard<T>&) = delete;
|
InlineDeviceGuard<T>& operator=(const InlineDeviceGuard<T>&) = delete;
|
|
/// Move is disallowed, as DeviceGuard does not have an uninitialized state,
|
/// which is required for moves on types with nontrivial destructors.
|
InlineDeviceGuard(InlineDeviceGuard<T>&& other) = delete;
|
InlineDeviceGuard& operator=(InlineDeviceGuard<T>&& other) = delete;
|
|
~InlineDeviceGuard() {
|
impl_.uncheckedSetDevice(original_device_);
|
}
|
|
/// Sets the device to the given one.
|
template <typename U=T, typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value, int>::type = 0>
|
void set_device(at::Device device) {
|
AT_ASSERT((U::static_type == DeviceType::HIP && device.type() == DeviceType::CUDA) ||
|
device.type() == U::static_type);
|
auto index = device.index();
|
if (index == -1) return;
|
impl_.setDevice(device);
|
current_device_ = device;
|
}
|
|
/// Resets the currently set device to its original device, and then sets the
|
/// current device to the passed device. This is effectively equivalent to
|
/// set_device when a guard supports only a single device type.
|
template <typename U=T>
|
typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value >::type
|
reset_device(at::Device device) {
|
set_device(device);
|
}
|
|
/// Resets the currently set device to its original device, and then sets the
|
/// current device to the passed device (for a possibly different device
|
/// type).
|
///
|
/// This method is named reset_device to highlight the fact that previous
|
/// device settings from this guard are NOT preserved, even if the device
|
/// has a different device type. For example:
|
///
|
/// // CUDA device is 0
|
/// DeviceGuard g(Device(kCUDA, 1));
|
/// g.reset_device(Device(kHIP, 2));
|
/// // CUDA device is 0 (!!)
|
///
|
/// NOTE: this implementation may skip some device setting if it can prove
|
/// that it is unnecessary.
|
///
|
/// Optional argument is for testing only.
|
template <typename U=T>
|
typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value >::type
|
reset_device(at::Device device, const impl::DeviceGuardImplInterface* impl = nullptr) {
|
auto index = device.index();
|
if (index == -1) return;
|
if (device.type() == original_device_.type()) {
|
AT_ASSERT(impl == nullptr || impl->type() == device.type());
|
impl_.setDevice(device);
|
current_device_ = device;
|
} else {
|
// Destruct and reconstruct the DeviceGuard in place
|
impl_.setDevice(original_device_);
|
impl_ = !impl ? VirtualGuardImpl(device.type()) : VirtualGuardImpl(impl);
|
original_device_ = impl_.exchangeDevice(device);
|
current_device_ = device;
|
}
|
}
|
|
/// Sets the device index to the given one. The device type is inferred
|
/// from the original device type.
|
void set_index(DeviceIndex index) {
|
reset_device(Device(original_device_.type(), index));
|
}
|
|
/// Returns the device that was set at the time the most recent
|
/// reset_device(), or otherwise the device at construction time.
|
Device original_device() const {
|
return original_device_;
|
}
|
|
/// Returns the most recent device that was set using this device guard,
|
/// either from construction, or via set_device/reset_device/set_index.
|
Device current_device() const {
|
return current_device_;
|
}
|
|
protected:
|
T impl_;
|
|
private:
|
Device original_device_;
|
Device current_device_;
|
};
|
|
/**
|
* An OptionalDeviceGuard is an RAII class that sets a device to some value on
|
* initialization, and resets the device to its original value on destruction.
|
*
|
* InlineOptionalDeviceGuard is a helper class for implementing
|
* OptionalDeviceGuards. See guidance in InlineDeviceGuard on how to
|
* use this. See OptionalDeviceGuard for user-oriented usage notes.
|
*/
|
template <typename T>
|
class InlineOptionalDeviceGuard {
|
public:
|
// Note [Explicit initialization of optional fields]
|
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
// Explicit initialization of optional fields
|
// required to workaround an nvcc bug; see https://github.com/pytorch/pytorch/issues/12117
|
|
/// Creates an uninitialized OptionalDeviceGuard.
|
explicit InlineOptionalDeviceGuard()
|
: guard_() // See Note [Explicit initialization of optional fields]
|
{}
|
|
/// Set the current device to the passed Device, if it is not nullopt.
|
explicit InlineOptionalDeviceGuard(optional<Device> device_opt)
|
: guard_() { // See Note [Explicit initialization of optional fields]
|
if (device_opt.has_value()) {
|
guard_.emplace(device_opt.value());
|
}
|
}
|
|
/// Set the current device to the passed DeviceIndex, if it is not nullopt.
|
template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
|
explicit InlineOptionalDeviceGuard(optional<DeviceIndex> device_index_opt)
|
: guard_() { // See Note [Explicit initialization of optional fields]
|
if (device_index_opt.has_value()) {
|
guard_.emplace(device_index_opt.value());
|
}
|
}
|
|
/// All constructors of DeviceGuard are valid for OptionalDeviceGuard
|
/// and result in initialized OptionalDeviceGuard.
|
template <typename... Args>
|
explicit InlineOptionalDeviceGuard(Args&&... args)
|
: guard_(in_place, std::forward<Args>(args)...) {}
|
|
// TODO: Consider readding Tensor and TensorList constructors here, when
|
// Tensor moves to c10. (These are only valid on OptionalDeviceGuard,
|
// because a Tensor may be undefined, in which case we need an uninitialized
|
// tensor guard.)
|
|
// Note [Move construction for RAII guards is tricky]
|
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
// In principle, move construction is useful for terminating
|
// the lifetime of a `OptionalDeviceGuard` early; for example:
|
//
|
// // current device is d0
|
// OptionalDeviceGuard g1(d1);
|
// // current device is d1
|
// {
|
// OptionalDeviceGuard g2(std::move(g1));
|
// }
|
// // current device is d0!!
|
//
|
// However, it's difficult to implement the move constructor
|
// in a way that works in all situations. For example, consider
|
// the following example:
|
//
|
// OptionalDeviceGuard g1(d1);
|
// {
|
// OptionalDeviceGuard g2(d2);
|
// {
|
// OptionalDeviceGuard g3(std::move(g1)); // !!!
|
// }
|
// }
|
//
|
// What should the current device be while g3 in scope... and what
|
// should it be after it goes out of scope? What about g2?
|
// There don't seem to be satisfactory answers for these questions.
|
//
|
// It's in principle possible to raise an error when this occurs
|
// by doing some extra thread-local bookkeeping. But why bother?
|
// Just don't provide the constructor.
|
InlineOptionalDeviceGuard(InlineOptionalDeviceGuard<T>&& other) = delete;
|
|
// Note [Move assignment for RAII guards is tricky]
|
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
// Move assignment is deleted, because you need to know which guard was
|
// defined "first", as that guard's original_device_ wins--with the current
|
// representation, we have no way of telling which is the case. (Move
|
// construction does not have this problem, as one guard is always
|
// uninitialized.)
|
//
|
// We can make this clear by way of a pair of examples:
|
//
|
// Example 1:
|
//
|
// // initial device is n0
|
// {
|
// CUDAGuard g1(n1);
|
// {
|
// CUDAGuard g2(n2);
|
// // current device should be n2
|
// g1 = std::move(g2);
|
// // current device should still be n2
|
// }
|
// // current device should still be n2
|
// }
|
// // current device should be n0
|
//
|
// Example 2 (flip the order of the two guards):
|
//
|
// // initial device is n0
|
// {
|
// CUDAGuard g2(n2);
|
// {
|
// CUDAGuard g1(n1);
|
// // current device should be n1
|
// g1 = std::move(g2);
|
// // current device should be n2
|
// }
|
// // current device should be n0 (since g2 has been vacated)
|
// }
|
//
|
// In both examples, we need g1 to restore to n0 after move assignment.
|
// However, in example 1, this is determined by the restore value of g1
|
// (prior to the move). In example 2, however, it is determined by the the
|
// restore value of g2(!!). We don't know which one should win, without having
|
// a way of telling which guard was allocated first.
|
//
|
// We could solve this with an extra thread-local variable. But no one is
|
// actually using move-assignment. So just get rid of it.
|
InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) = delete;
|
|
/// Sets the device to the given one. Initializes OptionalDeviceGuard if it
|
/// is not already initialized.
|
template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
|
void set_device(at::Device device) {
|
if (!guard_.has_value()) {
|
guard_.emplace(device);
|
} else {
|
guard_->set_device(device);
|
}
|
}
|
|
/// Resets the currently set device to its original device, and then sets the
|
/// current device to the passed device (for a possibly different device
|
/// type). Initializes OptionalDeviceGuard if it is not already initialized.
|
///
|
/// See notes on why this is called reset_device on InlineDeviceGuard.
|
///
|
/// Optional argument is for testing only.
|
template <typename U=T, typename=typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value>::type>
|
void reset_device(at::Device device, const DeviceGuardImplInterface* impl = nullptr) {
|
if (!guard_.has_value()) {
|
guard_.emplace(device, impl);
|
} else {
|
guard_->reset_device(device, impl);
|
}
|
}
|
|
/// Resets the currently set device to its original device, and then sets the
|
/// current device to the passed device. Initializes the guard if it is
|
/// not already initialized. This is effectively equivalent to set_device
|
/// when a guard supports only a single device type.
|
template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type>
|
void reset_device(at::Device device) {
|
if (!guard_.has_value()) {
|
guard_.emplace(device);
|
} else {
|
guard_->reset_device(device);
|
}
|
}
|
|
/// Sets the device index to the given one. The device type is statically
|
/// known.
|
template <typename U=T, typename=typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value >::type>
|
void set_index(DeviceIndex index) {
|
if (!guard_.has_value()) {
|
guard_.emplace(index);
|
} else {
|
guard_->set_index(index);
|
}
|
}
|
|
/// Returns the device that was set immediately prior to initialization of the,
|
/// guard, or nullopt if the guard is uninitialized.
|
optional<Device> original_device() const {
|
return guard_.has_value() ? make_optional(guard_->original_device()) : nullopt;
|
}
|
|
/// Returns the most recent device that was set using this device guard,
|
/// either from construction, or via set_device, if the guard is initialized,
|
/// or nullopt if the guard is uninitialized.
|
optional<Device> current_device() const {
|
return guard_.has_value() ? make_optional(guard_->current_device()) : nullopt;
|
}
|
|
/// Restore the original device, resetting this guard to uninitialized state.
|
void reset() {
|
guard_.reset();
|
}
|
|
private:
|
optional<InlineDeviceGuard<T>> guard_;
|
};
|
|
}} // namespace c10::impl
|