#pragma once #include namespace c10 { namespace impl { /** * An implementation of DeviceGuardImplInterface which delegates * to virtual dispatch on the DeviceGuardImpl registry. */ class VirtualGuardImpl final : public DeviceGuardImplInterface { public: VirtualGuardImpl(DeviceType device_type) : impl_(getDeviceGuardImpl(device_type)) {} // This constructor exists purely for testing VirtualGuardImpl(const DeviceGuardImplInterface* impl) : impl_(impl) {} // Copying and moving is OK! DeviceType type() const override { return impl_->type(); } Device exchangeDevice(Device d) const override { return impl_->exchangeDevice(d); } Device getDevice() const override { return impl_->getDevice(); } void setDevice(Device d) const override { impl_->setDevice(d); } void uncheckedSetDevice(Device d) const noexcept override { impl_->uncheckedSetDevice(d); } Stream getStream(Device d) const noexcept override { return impl_->getStream(d); } Stream getDefaultStream(Device d) const override { return impl_->getDefaultStream(d); } Stream exchangeStream(Stream s) const noexcept override { return impl_->exchangeStream(s); } DeviceIndex deviceCount() const noexcept override { return impl_->deviceCount(); } // Event functions void record(void** event, const Stream& stream, const DeviceIndex device_index, const EventFlag flag) const override { impl_->record(event, stream, device_index, flag); } void block( void* event, const Stream& stream) const override { impl_->block(event, stream); } bool queryEvent(void* event) const override { return impl_->queryEvent(event); } void destroyEvent( void* event, const DeviceIndex device_index) const noexcept override { impl_->destroyEvent(event, device_index); } private: const DeviceGuardImplInterface* impl_ = nullptr; }; }} // namespace c10::impl