#ifndef CAFFE2_OPERATORS_COUNTER_OPS_H
#define CAFFE2_OPERATORS_COUNTER_OPS_H

#include <atomic>

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"

namespace caffe2 {

// Thread-safe counter backed by std::atomic<T>.
template <typename T>
class CAFFE2_API Counter {
 public:
  explicit Counter(T count) : count_(count) {}

  // Decrements the counter; returns true once the count has been exhausted.
  bool countDown() {
    if (count_-- > 0) {
      return false;
    }
    return true;
  }

  // Increments the counter and returns the value it had before the increment.
  T countUp() {
    return count_++;
  }

  T retrieve() const {
    return count_.load();
  }

  T checkIfDone() const {
    return (count_.load() <= 0);
  }

  // Sets the counter to init_count and returns the previous value.
  T reset(T init_count) {
    return count_.exchange(init_count);
  }

 private:
  std::atomic<T> count_;
};

// TODO(jiayq): deprecate these ops & consolidate them with IterOp/AtomicIterOp

// Creates a Counter<T> blob initialized to the "init_count" argument.
template <typename T, class Context>
class CreateCounterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit CreateCounterOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        init_count_(this->template GetSingleArgument<T>("init_count", 0)) {
    CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
  }

  bool RunOnDevice() override {
    *this->template Output<std::unique_ptr<Counter<T>>>(0) =
        std::unique_ptr<Counter<T>>(new Counter<T>(init_count_));
    return true;
  }

 private:
  T init_count_ = 0;
};

// Resets an existing counter to "init_count"; optionally outputs the
// previous value.
template <typename T, class Context>
class ResetCounterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit ResetCounterOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        init_count_(this->template GetSingleArgument<T>("init_count", 0)) {
    CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
  }

  bool RunOnDevice() override {
    auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto previous = counterPtr->reset(init_count_);
    if (OutputSize() == 1) {
      auto* output = Output(0);
      output->Resize();
      *output->template mutable_data<T>() = previous;
    }
    return true;
  }

 private:
  T init_count_;
};

// Decrements the counter and outputs true once it has been exhausted.
// Will always use TensorCPU regardless of the Context.
template <typename T, class Context>
class CountDownOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit CountDownOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto* output = Output(0);
    output->Resize(std::vector<int>{});
    *output->template mutable_data<bool>() = counterPtr->countDown();
    return true;
  }
};

// Outputs true iff the counter value is <= 0.
// Will always use TensorCPU regardless of the Context.
template <typename T, class Context>
class CheckCounterDoneOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit CheckCounterDoneOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto* output = Output(0);
    output->Resize(std::vector<int>{});
    *output->template mutable_data<bool>() = counterPtr->checkIfDone();
    return true;
  }
};

// Increments the counter and outputs the value it had before the increment.
// Will always use TensorCPU regardless of the Context.
template <typename T, class Context>
class CountUpOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit CountUpOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto* output = Output(0);
    output->Resize(std::vector<int>{});
    *output->template mutable_data<T>() = counterPtr->countUp();
    return true;
  }
};

// Outputs the current counter value without modifying it.
// Will always use TensorCPU regardless of the Context.
template <typename T, class Context>
class RetrieveCountOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit RetrieveCountOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  bool RunOnDevice() override {
    auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
    auto* output = Output(0);
    output->Resize(std::vector<int>{});
    *output->template mutable_data<T>() = counterPtr->retrieve();
    return true;
  }
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_COUNTER_OPS_H
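// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; not part of the original header). It shows
// the countdown semantics these ops expose: with an initial count of N,
// countDown() returns false N times and true afterwards, which is the boolean
// CountDownOp writes into its scalar TensorCPU output. The snippet is a plain
// C++ sketch against the Counter<T> class above, not a Caffe2 net:
//
//   caffe2::Counter<int64_t> c(2);
//   c.countDown();  // false, remaining count is now 1
//   c.countDown();  // false, remaining count is now 0
//   c.countDown();  // true  -> a driving loop would stop here
//   c.reset(2);     // returns -1 (the exhausted value) and rearms the counter
// ---------------------------------------------------------------------------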