#ifndef CAFFE2_CORE_EVENT_H_ #define CAFFE2_CORE_EVENT_H_ #include #include #include "caffe2/core/common.h" #include "caffe2/core/logging.h" #include "caffe2/proto/caffe2_pb.h" namespace caffe2 { constexpr int MaxDeviceTypes = DeviceTypeProto::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES; class Event; enum EventStatus { EVENT_INITIALIZED = 0, EVENT_SCHEDULED = 1, EVENT_SUCCESS = 2, EVENT_FAILED = 3, }; // For the following functions, void* shall be interpreted as the corresponding // context object corresponding to the device type associated with the // functions. // Initializes event typedef void (*EventCreateFunction)(const DeviceOption& option, Event*); // Called on event to signal that CPU part of operation is finished, // Optionally accepts error message from CPU part. // Should be called no more than once per event typedef void (*EventRecordFunction)(Event*, const void*, const char*); // Waits and returns as soon as possible in order schedule next operation, // e.g. for CUDA->CUDA waits only for CPU part of CUDA op, // for CUDA->CPU waits till the CUDA op is fully completed. // Prepares context to synchronize device part of operation. // Can be called concurrently from multiple threads typedef void (*EventWaitFunction)(const Event*, void*); // Waits till operation is fully finished, // can be called concurrently from multiple threads typedef void (*EventFinishFunction)(const Event*); // Queries current status of operation, // can be called concurrently from multiple threads typedef EventStatus (*EventQueryFunction)(const Event*); typedef const std::string& (*EventErrorMessageFunction)(const Event*); typedef void (*EventSetFinishedFunction)(const Event*, const char*); typedef void (*EventResetFunction)(Event*); // Sets callback that is called when event is finished typedef std::function EventCallbackFunction; typedef void (*EventSetCallbackFunction)(Event*, EventCallbackFunction); class CAFFE2_API Event { public: explicit Event(const DeviceOption& option) : event_(), type_(option.device_type()), option_(option) { CAFFE_ENFORCE_LT(type_, MaxDeviceTypes); CAFFE_ENFORCE(event_creator_[type_]); event_creator_[type_](option, this); } // Nothing needs to be done in the destructor, as the event creator should // set the proper destruction process for the unique_ptr. ~Event() {} void Record( DeviceType recorder_type, const void* context, const char* err_msg = nullptr) { auto recorder_index = TypeToProto(recorder_type); CAFFE_ENFORCE_EQ( recorder_index, type_, "You are trying to record with a wrong device type."); CAFFE_ENFORCE(event_recorder_[recorder_index]); event_recorder_[recorder_index](this, context, err_msg); } void Wait(DeviceType waiter_type, void* context) const { auto waiter_index = TypeToProto(waiter_type); CAFFE_ENFORCE(event_waiter_[waiter_index][type_]); event_waiter_[waiter_index][type_](this, context); } void Finish() const { CAFFE_ENFORCE(event_finisher_[type_]); event_finisher_[type_](this); } EventStatus Query() const { CAFFE_ENFORCE(event_querier_[type_]); return event_querier_[type_](this); } const std::string& ErrorMessage() const { CAFFE_ENFORCE(event_err_msg_getter_[type_]); return event_err_msg_getter_[type_](this); } void Reset() { CAFFE_ENFORCE(event_resetter_[type_]); event_resetter_[type_](this); #ifdef CAFFE2_USE_EXCEPTION_PTR caught_exception_ = nullptr; exception_timestamp_ = 0; #endif // CAFFE2_USE_EXCEPTION_PTR } const DeviceOption& GetDeviceOption() const { return option_; } bool IsScheduled() const { return Query() == EventStatus::EVENT_SCHEDULED; } bool IsFinished() const { auto status = Query(); return status == EventStatus::EVENT_SUCCESS || status == EventStatus::EVENT_FAILED; } void SetFinished(const char* err_msg = nullptr) { CAFFE_ENFORCE(event_finished_setter_[type_]); return event_finished_setter_[type_](this, err_msg); } bool SupportsCallback() const { return event_callback_setter_[type_] != nullptr; } void SetCallback(EventCallbackFunction callback) { CAFFE_ENFORCE( event_callback_setter_[type_], "Event does not support callbacks"); event_callback_setter_[type_](this, callback); } // If parent op has succeeded, then we can run any child op; // If parent op is in scheduled state, we need to check that: // - child op supports async scheduling // - there's a way to setup synchronization between async parent and // child - both child and parent should use the same type of device, // non-blocking synchronization between different device types is not // supported // If parent op is in another state (initialized or failed) then scheduling // is not possible bool CanSchedule(const Event& child_event, bool supports_async) const { return CanSchedule(type_, Query(), child_event.GetType(), supports_async); } static bool CanSchedule( int parent_type, EventStatus parent_status, int child_type, bool child_supports_async) { if (parent_status == EventStatus::EVENT_SUCCESS) { return true; } if (parent_status == EventStatus::EVENT_SCHEDULED) { return (parent_type == child_type) && child_supports_async; } return false; } int GetType() const { return type_; } void SetFinishedWithException(const char* err_msg = nullptr) { #ifdef CAFFE2_USE_EXCEPTION_PTR if (!caught_exception_) { caught_exception_ = std::current_exception(); typedef std::chrono::high_resolution_clock clock; exception_timestamp_ = clock::now().time_since_epoch() / std::chrono::milliseconds(1); } CAFFE_ENFORCE(caught_exception_, "No exception found"); #else VLOG(1) << "No support for exceptions in Event"; #endif // CAFFE2_USE_EXCEPTION_PTR if (err_msg) { SetFinished(err_msg); } else { SetFinished("Error happened during an operator run"); } } bool HasException() const { #ifdef CAFFE2_USE_EXCEPTION_PTR return (bool)caught_exception_; #else VLOG(1) << "No support for exceptions in Event"; return false; #endif // CAFFE2_USE_EXCEPTION_PTR } int64_t ExceptionTimestamp() const { #ifdef CAFFE2_USE_EXCEPTION_PTR return exception_timestamp_; #else VLOG(1) << "No support for exceptions in Event"; return 0; #endif // CAFFE2_USE_EXCEPTION_PTR } void RethrowException() const { #ifdef CAFFE2_USE_EXCEPTION_PTR if (caught_exception_) { std::rethrow_exception(caught_exception_); } #else VLOG(1) << "No support for exceptions in Event"; #endif // CAFFE2_USE_EXCEPTION_PTR } // event_ is going to be accessed by the EventCreate/Record/Wait/Finish // functions, but one should not use it outside the own Event functionalities. // In the future we may move it to a private member. std::shared_ptr event_; private: int type_; DeviceOption option_; #ifdef CAFFE2_USE_EXCEPTION_PTR std::exception_ptr caught_exception_; int64_t exception_timestamp_{}; #endif // CAFFE2_USE_EXCEPTION_PTR static EventCreateFunction event_creator_[MaxDeviceTypes]; static EventRecordFunction event_recorder_[MaxDeviceTypes]; static EventWaitFunction event_waiter_[MaxDeviceTypes] [MaxDeviceTypes]; static EventFinishFunction event_finisher_[MaxDeviceTypes]; static EventQueryFunction event_querier_[MaxDeviceTypes]; static EventErrorMessageFunction event_err_msg_getter_[MaxDeviceTypes]; static EventSetFinishedFunction event_finished_setter_[MaxDeviceTypes]; static EventResetFunction event_resetter_[MaxDeviceTypes]; static EventSetCallbackFunction event_callback_setter_[MaxDeviceTypes]; template friend struct EventCreateFunctionRegisterer; template friend struct EventRecordFunctionRegisterer; template friend struct EventWaitFunctionRegisterer; template friend struct EventFinishFunctionRegisterer; template friend struct EventQueryFunctionRegisterer; template friend struct EventErrorMessageFunctionRegisterer; template friend struct EventSetFinishedFunctionRegisterer; template friend struct EventSetCallbackFunctionRegisterer; template friend struct EventResetFunctionRegisterer; }; template struct EventCreateFunctionRegisterer { explicit EventCreateFunctionRegisterer(EventCreateFunction f) { auto d = TypeToProto(t); Event::event_creator_[d] = f; } }; #define REGISTER_EVENT_CREATE_FUNCTION(t, f) \ namespace { \ static EventCreateFunctionRegisterer g_event_create_##d(f); \ } template struct EventRecordFunctionRegisterer { explicit EventRecordFunctionRegisterer(EventRecordFunction f) { auto d = TypeToProto(t); Event::event_recorder_[d] = f; } }; #define REGISTER_EVENT_RECORD_FUNCTION(t, f) \ namespace { \ static EventRecordFunctionRegisterer g_event_record_##d(f); \ } template struct EventWaitFunctionRegisterer { explicit EventWaitFunctionRegisterer(EventWaitFunction f) { auto waiter_index = TypeToProto(waiter_type); auto event_index = TypeToProto(event_type); Event::event_waiter_[waiter_index][event_index] = f; } }; #define REGISTER_EVENT_WAIT_FUNCTION(w, d, f) \ namespace { \ static EventWaitFunctionRegisterer g_event_wait_##w##_##d(f); \ } template struct EventQueryFunctionRegisterer { explicit EventQueryFunctionRegisterer(EventQueryFunction f) { auto d = TypeToProto(t); Event::event_querier_[d] = f; } }; #define REGISTER_EVENT_QUERY_FUNCTION(t, f) \ namespace { \ static EventQueryFunctionRegisterer g_event_query_##d(f); \ } template struct EventErrorMessageFunctionRegisterer { explicit EventErrorMessageFunctionRegisterer(EventErrorMessageFunction f) { auto d = TypeToProto(t); Event::event_err_msg_getter_[d] = f; } }; #define REGISTER_EVENT_ERROR_MESSAGE_FUNCTION(t, f) \ namespace { \ static EventErrorMessageFunctionRegisterer g_event_err_msg_##d(f); \ } template struct EventSetFinishedFunctionRegisterer { explicit EventSetFinishedFunctionRegisterer(EventSetFinishedFunction f) { auto d = TypeToProto(t); Event::event_finished_setter_[d] = f; } }; #define REGISTER_EVENT_SET_FINISHED_FUNCTION(t, f) \ namespace { \ static EventSetFinishedFunctionRegisterer g_event_set_finished_##d(f); \ } template struct EventSetCallbackFunctionRegisterer { explicit EventSetCallbackFunctionRegisterer(EventSetCallbackFunction f) { auto d = TypeToProto(t); Event::event_callback_setter_[d] = f; } }; #define REGISTER_EVENT_SET_CALLBACK_FUNCTION(t, f) \ namespace { \ static EventSetCallbackFunctionRegisterer g_event_set_callback_##d(f); \ } template struct EventFinishFunctionRegisterer { explicit EventFinishFunctionRegisterer(EventFinishFunction f) { auto d = TypeToProto(t); Event::event_finisher_[d] = f; } }; #define REGISTER_EVENT_FINISH_FUNCTION(t, f) \ namespace { \ static EventFinishFunctionRegisterer g_event_finish_##d(f); \ } template struct EventResetFunctionRegisterer { explicit EventResetFunctionRegisterer(EventResetFunction f) { auto d = TypeToProto(t); Event::event_resetter_[d] = f; } }; #define REGISTER_EVENT_RESET_FUNCTION(t, f) \ namespace { \ static EventResetFunctionRegisterer g_event_reset_##d(f); \ } } // namespace caffe2 #endif // CAFFE2_CORE_EVENT_H_