#pragma once #include namespace c10 { namespace impl { /** * A StreamGuard is an RAII class that changes the current device * to the device corresponding to some stream, and changes the * default stream on that device to be this stream. * * InlineStreamGuard is a helper class for implementing StreamGuards. * See InlineDeviceGuard for guidance on how to use this class. */ template class InlineStreamGuard : private InlineDeviceGuard { public: /// No default constructor, see Note [Omitted default constructor from RAII] explicit InlineStreamGuard() = delete; /// Set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream. explicit InlineStreamGuard(Stream stream) : InlineDeviceGuard(stream.device()) , original_stream_of_original_device_(this->impl_.getStream(original_device())) , original_stream_of_current_device_(this->impl_.exchangeStream(stream)) , current_stream_(stream) {} /// This constructor exists purely for testing template ::value>::type> explicit InlineStreamGuard(Stream stream, const DeviceGuardImplInterface* impl) : InlineDeviceGuard(stream.device(), impl ? impl : getDeviceGuardImpl(stream.device_type())) , original_stream_of_original_device_(this->impl_.getStream(original_device())) , original_stream_of_current_device_(this->impl_.exchangeStream(stream)) , current_stream_(stream) {} /// Copy is disallowed InlineStreamGuard(const InlineStreamGuard&) = delete; InlineStreamGuard& operator=(const InlineStreamGuard&) = delete; /// Move is disallowed, as StreamGuard does not have an uninitialized state, /// which is required for moves on types with nontrivial destructors. InlineStreamGuard(InlineStreamGuard&& other) = delete; InlineStreamGuard& operator=(InlineStreamGuard&& other) = delete; ~InlineStreamGuard() { this->impl_.exchangeStream(original_stream_of_current_device_); } /// Resets the currently set stream to the original stream and /// the currently set device to the original device. Then, /// set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream. /// /// NOTE: this implementation may skip some stream/device setting if /// it can prove that it is unnecessary. /// /// WARNING: reset_stream does NOT preserve previously set streams on /// different devices. If you need to set streams on multiple devices /// on CUDA, use CUDAMultiStreamGuard instead. void reset_stream(Stream stream) { // TODO: make a version that takes an impl argument. Unfortunately, // that will require SFINAE because impl is only valid for the // VirtualGuardImpl specialization. if (stream.device() == this->current_device()) { this->impl_.exchangeStream(stream); current_stream_ = stream; } else { // Destruct and reconstruct the StreamGuard in-place this->impl_.exchangeStream(original_stream_of_current_device_); this->reset_device(stream.device()); original_stream_of_current_device_ = this->impl_.exchangeStream(stream); current_stream_ = stream; } } // It's not clear if set_device should also reset the current stream // if the device is unchanged; therefore, we don't provide it. // The situation is somewhat clearer with reset_device, but it's still // a pretty weird thing to do, so haven't added this either. /// Returns the stream of the original device prior to this guard. Subtly, /// the stream returned here is the original stream of the *original* /// device; i.e., it's the stream that your computation *would* have /// been put on, if it hadn't been for this meddling stream guard. /// This is usually what you want. Stream original_stream() const { return original_stream_of_original_device_; } /// Returns the most recent stream that was set using this device guard, /// either from construction, or via set_stream. Stream current_stream() const { return current_stream_; } /// Returns the most recent device that was set using this device guard, /// either from construction, or via set_device/reset_device/set_index. Device current_device() const { return InlineDeviceGuard::current_device(); } /// Returns the device that was set at the most recent reset_stream(), /// or otherwise the device at construction time. Device original_device() const { return InlineDeviceGuard::original_device(); } private: Stream original_stream_of_original_device_; // what the user probably cares about Stream original_stream_of_current_device_; // what we need to restore Stream current_stream_; }; /** * An OptionalStreamGuard is an RAII class that sets a device to some value on * initialization, and resets the device to its original value on destruction. * See InlineOptionalDeviceGuard for more guidance on how to use this class. */ template class InlineOptionalStreamGuard { public: /// Creates an uninitialized stream guard. explicit InlineOptionalStreamGuard() : guard_() // See Note [Explicit initialization of optional fields] {} /// Set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream, /// if the passed stream is not nullopt. explicit InlineOptionalStreamGuard(optional stream_opt) : guard_() { if (stream_opt.has_value()) { guard_.emplace(stream_opt.value()); } } /// All constructors of StreamGuard are valid for OptionalStreamGuard template explicit InlineOptionalStreamGuard(Args&&... args) : guard_(in_place, std::forward(args)...) {} // See Note [Move construction for RAII guards is tricky] InlineOptionalStreamGuard(InlineOptionalStreamGuard&& other) = delete; // See Note [Move assignment for RAII guards is tricky] InlineOptionalStreamGuard& operator=(InlineOptionalStreamGuard&& other) = delete; /// Resets the currently set stream to the original stream and /// the currently set device to the original device. Then, /// set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream. /// Initializes the OptionalStreamGuard if it was not previously initialized. void reset_stream(Stream stream) { if (guard_.has_value()) { guard_->reset_stream(stream); } else { guard_.emplace(stream); } } /// Returns the stream that was set at the time the guard was most recently /// initialized, or nullopt if the guard is uninitialized. optional original_stream() const { return guard_.has_value() ? make_optional(guard_->original_stream()) : nullopt; } /// Returns the most recent stream that was set using this stream guard, /// either from construction, or via reset_stream, if the guard is initialized, /// or nullopt if the guard is uninitialized. optional current_stream() const { return guard_.has_value() ? make_optional(guard_->current_stream()) : nullopt; } /// Restore the original device and stream, resetting this guard to uninitialized state. void reset() { guard_.reset(); } private: optional> guard_; }; }} // namespace c10::impl