#pragma once
|
|
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "caffe2/core/logging.h"
|
|
namespace caffe2 {
|
|
/**
 * Base class for an observer in the Observer pattern.
 *
 * Each observer is bound to one subject of type T at construction. The
 * notification hooks (Start/Stop) and debugInfo() are virtual with trivial
 * default implementations, so derived observers override only what they need.
 */
template <class T>
class ObserverBase {
 public:
  /// Binds the observer to `subject`. The pointer is stored as-is and is not
  /// owned: the destructor never deletes it.
  explicit ObserverBase(T* subject) : subject_{subject} {}

  /// Start hook; the default implementation does nothing.
  virtual void Start() {}

  /// Stop hook; the default implementation does nothing.
  virtual void Stop() {}

  /// Returns a description of this observer; the base class returns a fixed
  /// placeholder string.
  virtual std::string debugInfo() {
    const char* const kPlaceholder = "Not implemented.";
    return kPlaceholder;
  }

  /// Virtual so derived observers are destroyed correctly through a base
  /// pointer.
  virtual ~ObserverBase() noexcept = default;

  /// Returns the (non-owned) subject passed to the constructor.
  T* subject() const {
    return subject_;
  }

  /// Creates a copy of this observer bound to `subject`, with `rnn_order`
  /// forwarded to overrides. The base implementation returns nullptr, i.e.
  /// it produces no copy. NOTE(review): presumably used when RNN step nets
  /// are cloned — confirm at the call sites.
  virtual std::unique_ptr<ObserverBase<T>> rnnCopy(T* subject, int rnn_order)
      const {
    return std::unique_ptr<ObserverBase<T>>{};
  }

 protected:
  /// Non-owning pointer to the observed object.
  T* subject_;
};
|
|
/**
|
* Inherit to make your class observable.
|
*/
|
template <class T>
|
class Observable {
|
public:
|
Observable() = default;
|
|
Observable(Observable&&) = default;
|
Observable& operator =(Observable&&) = default;
|
|
virtual ~Observable() = default;
|
|
C10_DISABLE_COPY_AND_ASSIGN(Observable);
|
|
using Observer = ObserverBase<T>;
|
|
/* Returns a reference to the observer after addition. */
|
const Observer* AttachObserver(std::unique_ptr<Observer> observer) {
|
CAFFE_ENFORCE(observer, "Couldn't attach a null observer.");
|
std::unordered_set<const Observer*> observers;
|
for (auto& ob : observers_list_) {
|
observers.insert(ob.get());
|
}
|
|
const auto* observer_ptr = observer.get();
|
if (observers.count(observer_ptr)) {
|
return observer_ptr;
|
}
|
observers_list_.push_back(std::move(observer));
|
UpdateCache();
|
|
return observer_ptr;
|
}
|
|
/**
|
* Returns a unique_ptr to the removed observer. If not found, return a
|
* nullptr
|
*/
|
std::unique_ptr<Observer> DetachObserver(const Observer* observer_ptr) {
|
for (auto it = observers_list_.begin(); it != observers_list_.end(); ++it) {
|
if (it->get() == observer_ptr) {
|
auto res = std::move(*it);
|
observers_list_.erase(it);
|
UpdateCache();
|
return res;
|
}
|
}
|
return nullptr;
|
}
|
|
virtual size_t NumObservers() {
|
return num_observers_;
|
}
|
|
private:
|
inline static void StartObserver(Observer* observer) {
|
try {
|
observer->Start();
|
} catch (const std::exception& e) {
|
LOG(ERROR) << "Exception from observer: " << e.what();
|
} catch (...) {
|
LOG(ERROR) << "Exception from observer: unknown";
|
}
|
}
|
|
inline static void StopObserver(Observer* observer) {
|
try {
|
observer->Stop();
|
} catch (const std::exception& e) {
|
LOG(ERROR) << "Exception from observer: " << e.what();
|
} catch (...) {
|
LOG(ERROR) << "Exception from observer: unknown";
|
}
|
}
|
|
void UpdateCache() {
|
num_observers_ = observers_list_.size();
|
if (num_observers_ != 1) {
|
// we cannot take advantage of the cache
|
return;
|
}
|
observer_cache_ = observers_list_[0].get();
|
}
|
|
public:
|
void StartAllObservers() {
|
// do not access observers_list_ unless necessary
|
if (num_observers_ == 0) {
|
return;
|
} else if (num_observers_ == 1) {
|
StartObserver(observer_cache_);
|
} else {
|
for (auto& observer : observers_list_) {
|
StartObserver(observer.get());
|
}
|
}
|
}
|
|
void StopAllObservers() {
|
// do not access observers_list_ unless necessary
|
if (num_observers_ == 0) {
|
return;
|
} else if (num_observers_ == 1) {
|
StopObserver(observer_cache_);
|
} else {
|
for (auto& observer : observers_list_) {
|
StopObserver(observer.get());
|
}
|
}
|
}
|
|
private:
|
// an on-stack cache for fast iteration;
|
// ideally, inside StartAllObservers and StopAllObservers,
|
// we should never access observers_list_
|
Observer* observer_cache_;
|
size_t num_observers_ = 0;
|
|
protected:
|
std::vector<std::unique_ptr<Observer>> observers_list_;
|
};
|
|
} // namespace caffe2
|