#pragma once #include #include #include namespace at { // Thread local debug information is propagated across the forward // (including async fork tasks) and backward passes and is supposed // to be utilized by the user's code to pass extra information from // the higher layers (e.g. model id) down to the operator callbacks // (e.g. used for logging) class CAFFE2_API ThreadLocalDebugInfoBase { public: ThreadLocalDebugInfoBase() {} virtual ~ThreadLocalDebugInfoBase() {} }; CAFFE2_API std::shared_ptr getThreadLocalDebugInfo() noexcept; // Sets thread local debug information, returns the previously set // debug information CAFFE2_API std::shared_ptr setThreadLocalDebugInfo( std::shared_ptr info) noexcept; class CAFFE2_API DebugInfoGuard { public: explicit DebugInfoGuard( std::shared_ptr info) { prev_info_ = setThreadLocalDebugInfo(std::move(info)); } ~DebugInfoGuard() { setThreadLocalDebugInfo(std::move(prev_info_)); } private: std::shared_ptr prev_info_; }; } // namespace at