#ifndef CAFFE2_SGD_ITER_OP_H_
#define CAFFE2_SGD_ITER_OP_H_

#include <limits>
#include <mutex>

#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/stats.h"

namespace caffe2 {

inline void IncrementIter(TensorCPU* output) {
  CAFFE_ENFORCE_EQ(
      output->numel(),
      1,
      "The output of IterOp exists, but not of the right size.");
  int64_t* iter = output->template mutable_data<int64_t>();
  CAFFE_ENFORCE(*iter >= 0, "Previous iteration number is negative.");
  CAFFE_ENFORCE(
      *iter < std::numeric_limits<int64_t>::max(), "Overflow will happen!");
  (*iter)++;
}

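// A minimal usage sketch (illustrative, not part of the original header):
// IncrementIter expects a CPU tensor holding exactly one non-negative
// int64_t value and bumps it in place. Assuming caffe2::empty for tensor
// construction:
//
//   auto counter = caffe2::empty({1}, at::dtype<int64_t>().device(CPU));
//   counter.mutable_data<int64_t>()[0] = 41;
//   IncrementIter(&counter);  // counter now holds 42
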
// IterOp runs an iteration counter. I cannot think of a case where we would
// need to access the iter variable on device, so this will always produce a
// tensor on the CPU side. If the blob already exists and is a tensor<int64_t>
// object, we will simply increment it (this emulates the case when we want to
// resume training). Otherwise we will have the iter starting with 0.
template <class Context>
class IterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  IterOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  bool RunOnDevice() override {
    if (InputSize() == 0) {
      // Legacy nets hit this branch on every run, so keep the message at
      // VLOG rather than LOG(INFO) to avoid logging once per iteration.
      VLOG(1) << "[Input size is zero]";
      if (!OperatorBase::OutputIsTensorType(0, CPU)) {
        // This is the first run; set the iter to start with 0.
        LOG(ERROR) << "You are using an old definition of IterOp that will "
                      "be deprecated soon. More specifically, IterOp now "
                      "requires an explicit in-place input and output.";

        VLOG(1) << "Initializing iter counter.";
        auto* output = OperatorBase::OutputTensor(
            0, {1}, at::dtype<int64_t>().device(CPU));
        output->template mutable_data<int64_t>()[0] = 0;
      }
    }
    IncrementIter(OperatorBase::Output<Tensor>(0, CPU));
    return true;
  }
};

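// Example (a sketch; the blob name "iter" is illustrative): the preferred
// form runs Iter in place on a single-element int64 blob, which makes the
// counter resume correctly from a checkpoint:
//
//   OperatorDef def;
//   def.set_type("Iter");
//   def.add_input("iter");   // same blob as input and output
//   def.add_output("iter");
//
// Each run increments the blob's single int64 value by one; with the legacy
// input-free form the blob is first created holding 0.
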
// AtomicIterOp is the thread-safe variant of IterOp: it takes a mutex blob
// as input 0 and performs the increment while holding that lock, so several
// nets can safely share one counter. Each increment is also recorded in an
// exported "atomic_iter/stats/<input 1>" stat.
template <class Context>
class AtomicIterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  AtomicIterOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        stats_(std::string("atomic_iter/stats/") + operator_def.input(1)) {}

  bool RunOnDevice() override {
    auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(0);
    std::lock_guard<std::mutex> lg(*mutex);
    IncrementIter(OperatorBase::Output<Tensor>(0, CPU));
    CAFFE_EVENT(stats_, num_iter);
    return true;
  }

 private:
  struct AtomicIterOpStats {
    CAFFE_STAT_CTOR(AtomicIterOpStats);
    CAFFE_EXPORTED_STAT(num_iter);
  } stats_;
};

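// Example (a sketch; blob names are illustrative): AtomicIter expects the
// mutex blob, typically produced by a CreateMutex op, as input 0 and the
// counter blob as the in-place second input/first output:
//
//   OperatorDef def;
//   def.set_type("AtomicIter");
//   def.add_input("iteration_mutex");  // std::unique_ptr<std::mutex> blob
//   def.add_input("iter");
//   def.add_output("iter");
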
class MutexSerializer : public BlobSerializerBase {
 public:
  /**
   * Serializes a std::unique_ptr<std::mutex>. Note that this blob has to
   * contain std::unique_ptr<std::mutex>, otherwise this function produces a
   * fatal error.
   */
  void Serialize(
      const void* pointer,
      TypeMeta typeMeta,
      const string& name,
      BlobSerializerBase::SerializationAcceptor acceptor) override;
};

class MutexDeserializer : public BlobDeserializerBase {
 public:
  void Deserialize(const BlobProto& proto, Blob* blob) override;
};

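// These classes would typically be registered for the mutex blob type in the
// matching .cc file (a sketch, assuming caffe2's standard blob serialization
// registration macros):
//
//   REGISTER_BLOB_SERIALIZER(
//       (TypeMeta::Id<std::unique_ptr<std::mutex>>()),
//       MutexSerializer);
//   REGISTER_BLOB_DESERIALIZER(
//       std::unique_ptr<std::mutex>, MutexDeserializer);
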
} // namespace caffe2

#endif // CAFFE2_SGD_ITER_OP_H_