#pragma once #include #include #include "caffe2/core/logging.h" namespace caffe2 { /** * Use this to implement a Observer using the Observer Pattern template. */ template class ObserverBase { public: explicit ObserverBase(T* subject) : subject_(subject) {} virtual void Start() {} virtual void Stop() {} virtual std::string debugInfo() { return "Not implemented."; } virtual ~ObserverBase() noexcept {}; T* subject() const { return subject_; } virtual std::unique_ptr> rnnCopy(T* subject, int rnn_order) const { return nullptr; }; protected: T* subject_; }; /** * Inherit to make your class observable. */ template class Observable { public: Observable() = default; Observable(Observable&&) = default; Observable& operator =(Observable&&) = default; virtual ~Observable() = default; C10_DISABLE_COPY_AND_ASSIGN(Observable); using Observer = ObserverBase; /* Returns a reference to the observer after addition. */ const Observer* AttachObserver(std::unique_ptr observer) { CAFFE_ENFORCE(observer, "Couldn't attach a null observer."); std::unordered_set observers; for (auto& ob : observers_list_) { observers.insert(ob.get()); } const auto* observer_ptr = observer.get(); if (observers.count(observer_ptr)) { return observer_ptr; } observers_list_.push_back(std::move(observer)); UpdateCache(); return observer_ptr; } /** * Returns a unique_ptr to the removed observer. If not found, return a * nullptr */ std::unique_ptr DetachObserver(const Observer* observer_ptr) { for (auto it = observers_list_.begin(); it != observers_list_.end(); ++it) { if (it->get() == observer_ptr) { auto res = std::move(*it); observers_list_.erase(it); UpdateCache(); return res; } } return nullptr; } virtual size_t NumObservers() { return num_observers_; } private: inline static void StartObserver(Observer* observer) { try { observer->Start(); } catch (const std::exception& e) { LOG(ERROR) << "Exception from observer: " << e.what(); } catch (...) { LOG(ERROR) << "Exception from observer: unknown"; } } inline static void StopObserver(Observer* observer) { try { observer->Stop(); } catch (const std::exception& e) { LOG(ERROR) << "Exception from observer: " << e.what(); } catch (...) { LOG(ERROR) << "Exception from observer: unknown"; } } void UpdateCache() { num_observers_ = observers_list_.size(); if (num_observers_ != 1) { // we cannot take advantage of the cache return; } observer_cache_ = observers_list_[0].get(); } public: void StartAllObservers() { // do not access observers_list_ unless necessary if (num_observers_ == 0) { return; } else if (num_observers_ == 1) { StartObserver(observer_cache_); } else { for (auto& observer : observers_list_) { StartObserver(observer.get()); } } } void StopAllObservers() { // do not access observers_list_ unless necessary if (num_observers_ == 0) { return; } else if (num_observers_ == 1) { StopObserver(observer_cache_); } else { for (auto& observer : observers_list_) { StopObserver(observer.get()); } } } private: // an on-stack cache for fast iteration; // ideally, inside StartAllObservers and StopAllObservers, // we should never access observers_list_ Observer* observer_cache_; size_t num_observers_ = 0; protected: std::vector> observers_list_; }; } // namespace caffe2