#ifndef CAFFE2_OPERATORS_COUNTER_OPS_H
|
#define CAFFE2_OPERATORS_COUNTER_OPS_H
|
|
#include <atomic>
|
|
#include "caffe2/core/context.h"
|
#include "caffe2/core/logging.h"
|
#include "caffe2/core/operator.h"
|
|
namespace caffe2 {
|
template <typename T>
|
class CAFFE2_API Counter {
|
public:
|
explicit Counter(T count) : count_(count) {}
|
bool countDown() {
|
if (count_-- > 0) {
|
return false;
|
}
|
return true;
|
}
|
|
T countUp() {
|
return count_++;
|
}
|
|
T retrieve() const {
|
return count_.load();
|
}
|
|
T checkIfDone() const {
|
return (count_.load() <= 0);
|
}
|
|
T reset(T init_count) {
|
return count_.exchange(init_count);
|
}
|
|
private:
|
std::atomic<T> count_;
|
};
|
|
// TODO(jiayq): deprecate these ops & consolidate them with IterOp/AtomicIterOp
|
|
template <typename T, class Context>
|
class CreateCounterOp final : public Operator<Context> {
|
public:
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
template <class... Args>
|
explicit CreateCounterOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...),
|
init_count_(this->template GetSingleArgument<T>("init_count", 0)) {
|
CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
|
}
|
|
bool RunOnDevice() override {
|
*this->template Output<std::unique_ptr<Counter<T>>>(0) =
|
std::unique_ptr<Counter<T>>(new Counter<T>(init_count_));
|
return true;
|
}
|
|
private:
|
T init_count_ = 0;
|
};
|
|
template <typename T, class Context>
|
class ResetCounterOp final : public Operator<Context> {
|
public:
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
template <class... Args>
|
explicit ResetCounterOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...),
|
init_count_(this->template GetSingleArgument<T>("init_count", 0)) {
|
CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
|
}
|
|
bool RunOnDevice() override {
|
auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
|
auto previous = counterPtr->reset(init_count_);
|
if (OutputSize() == 1) {
|
auto* output = Output(0);
|
output->Resize();
|
*output->template mutable_data<T>() = previous;
|
}
|
return true;
|
}
|
|
private:
|
T init_count_;
|
};
|
|
// Intended to always use TensorCPU, regardless of the Context.
|
template <typename T, class Context>
|
class CountDownOp final : public Operator<Context> {
|
public:
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
template <class... Args>
|
explicit CountDownOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...) {}
|
|
bool RunOnDevice() override {
|
auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
|
auto* output = Output(0);
|
output->Resize(std::vector<int>{});
|
*output->template mutable_data<bool>() = counterPtr->countDown();
|
return true;
|
}
|
};
|
|
// Intended to always use TensorCPU, regardless of the Context.
|
template <typename T, class Context>
|
class CheckCounterDoneOp final : public Operator<Context> {
|
public:
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
template <class... Args>
|
explicit CheckCounterDoneOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...) {}
|
|
bool RunOnDevice() override {
|
auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
|
auto* output = Output(0);
|
output->Resize(std::vector<int>{});
|
*output->template mutable_data<bool>() = counterPtr->checkIfDone();
|
return true;
|
}
|
};
|
|
// Intended to always use TensorCPU, regardless of the Context.
|
template <typename T, class Context>
|
class CountUpOp final : public Operator<Context> {
|
public:
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
template <class... Args>
|
explicit CountUpOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...) {}
|
|
bool RunOnDevice() override {
|
auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
|
auto* output = Output(0);
|
output->Resize(std::vector<int>{});
|
*output->template mutable_data<T>() = counterPtr->countUp();
|
return true;
|
}
|
};
|
|
// Intended to always use TensorCPU, regardless of the Context.
|
template <typename T, class Context>
|
class RetrieveCountOp final : public Operator<Context> {
|
public:
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
template <class... Args>
|
explicit RetrieveCountOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...) {}
|
|
bool RunOnDevice() override {
|
auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
|
auto* output = Output(0);
|
output->Resize(std::vector<int>{});
|
*output->template mutable_data<T>() = counterPtr->retrieve();
|
return true;
|
}
|
};
|
|
} // namespace caffe2
|
#endif // CAFFE2_OPERATORS_COUNTER_OPS_H
|