#pragma once
|
|
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
|
namespace c10 {
|
namespace impl {
|
|
/**
|
* An implementation of DeviceGuardImplInterface which delegates
|
* to virtual dispatch on the DeviceGuardImpl registry.
|
*/
|
class VirtualGuardImpl final : public DeviceGuardImplInterface {
|
public:
|
VirtualGuardImpl(DeviceType device_type)
|
: impl_(getDeviceGuardImpl(device_type)) {}
|
// This constructor exists purely for testing
|
VirtualGuardImpl(const DeviceGuardImplInterface* impl)
|
: impl_(impl) {}
|
|
// Copying and moving is OK!
|
|
DeviceType type() const override {
|
return impl_->type();
|
}
|
Device exchangeDevice(Device d) const override {
|
return impl_->exchangeDevice(d);
|
}
|
Device getDevice() const override {
|
return impl_->getDevice();
|
}
|
void setDevice(Device d) const override {
|
impl_->setDevice(d);
|
}
|
void uncheckedSetDevice(Device d) const noexcept override {
|
impl_->uncheckedSetDevice(d);
|
}
|
Stream getStream(Device d) const noexcept override {
|
return impl_->getStream(d);
|
}
|
Stream getDefaultStream(Device d) const override {
|
return impl_->getDefaultStream(d);
|
}
|
Stream exchangeStream(Stream s) const noexcept override {
|
return impl_->exchangeStream(s);
|
}
|
DeviceIndex deviceCount() const noexcept override {
|
return impl_->deviceCount();
|
}
|
|
// Event functions
|
void record(void** event,
|
const Stream& stream,
|
const DeviceIndex device_index,
|
const EventFlag flag) const override {
|
impl_->record(event, stream, device_index, flag);
|
}
|
void block(
|
void* event,
|
const Stream& stream) const override {
|
impl_->block(event, stream);
|
}
|
bool queryEvent(void* event) const override {
|
return impl_->queryEvent(event);
|
}
|
void destroyEvent(
|
void* event,
|
const DeviceIndex device_index) const noexcept override {
|
impl_->destroyEvent(event, device_index);
|
}
|
|
private:
|
const DeviceGuardImplInterface* impl_ = nullptr;
|
};
|
|
}} // namespace c10::impl
|