#ifndef CAFFE2_UTILS_ZMQ_HELPER_H_
|
#define CAFFE2_UTILS_ZMQ_HELPER_H_
|
|
#include <zmq.h>
|
|
#include "caffe2/core/logging.h"
|
|
namespace caffe2 {
|
|
class ZmqContext {
|
public:
|
explicit ZmqContext(int io_threads) : ptr_(zmq_ctx_new()) {
|
CAFFE_ENFORCE(ptr_ != nullptr, "Failed to create zmq context.");
|
int rc = zmq_ctx_set(ptr_, ZMQ_IO_THREADS, io_threads);
|
CAFFE_ENFORCE_EQ(rc, 0);
|
rc = zmq_ctx_set(ptr_, ZMQ_MAX_SOCKETS, ZMQ_MAX_SOCKETS_DFLT);
|
CAFFE_ENFORCE_EQ(rc, 0);
|
}
|
~ZmqContext() {
|
int rc = zmq_ctx_destroy(ptr_);
|
CAFFE_ENFORCE_EQ(rc, 0);
|
}
|
|
void* ptr() { return ptr_; }
|
|
private:
|
void* ptr_;
|
|
C10_DISABLE_COPY_AND_ASSIGN(ZmqContext);
|
};
|
|
class ZmqMessage {
|
public:
|
ZmqMessage() {
|
int rc = zmq_msg_init(&msg_);
|
CAFFE_ENFORCE_EQ(rc, 0);
|
}
|
|
~ZmqMessage() {
|
int rc = zmq_msg_close(&msg_);
|
CAFFE_ENFORCE_EQ(rc, 0);
|
}
|
|
zmq_msg_t* msg() { return &msg_; }
|
|
void* data() { return zmq_msg_data(&msg_); }
|
size_t size() { return zmq_msg_size(&msg_); }
|
|
private:
|
zmq_msg_t msg_;
|
C10_DISABLE_COPY_AND_ASSIGN(ZmqMessage);
|
};
|
|
class ZmqSocket {
|
public:
|
explicit ZmqSocket(int type)
|
: context_(1), ptr_(zmq_socket(context_.ptr(), type)) {
|
CAFFE_ENFORCE(ptr_ != nullptr, "Faild to create zmq socket.");
|
}
|
|
~ZmqSocket() {
|
int rc = zmq_close(ptr_);
|
CAFFE_ENFORCE_EQ(rc, 0);
|
}
|
|
void Bind(const string& addr) {
|
int rc = zmq_bind(ptr_, addr.c_str());
|
CAFFE_ENFORCE_EQ(rc, 0);
|
}
|
|
void Unbind(const string& addr) {
|
int rc = zmq_unbind(ptr_, addr.c_str());
|
CAFFE_ENFORCE_EQ(rc, 0);
|
}
|
|
void Connect(const string& addr) {
|
int rc = zmq_connect(ptr_, addr.c_str());
|
CAFFE_ENFORCE_EQ(rc, 0);
|
}
|
|
void Disconnect(const string& addr) {
|
int rc = zmq_disconnect(ptr_, addr.c_str());
|
CAFFE_ENFORCE_EQ(rc, 0);
|
}
|
|
int Send(const string& msg, int flags) {
|
int nbytes = zmq_send(ptr_, msg.c_str(), msg.size(), flags);
|
if (nbytes) {
|
return nbytes;
|
} else if (zmq_errno() == EAGAIN) {
|
return 0;
|
} else {
|
LOG(FATAL) << "Cannot send zmq message. Error number: "
|
<< zmq_errno();
|
return 0;
|
}
|
}
|
|
int SendTillSuccess(const string& msg, int flags) {
|
CAFFE_ENFORCE(msg.size(), "You cannot send an empty message.");
|
int nbytes = 0;
|
do {
|
nbytes = Send(msg, flags);
|
} while (nbytes == 0);
|
return nbytes;
|
}
|
|
int Recv(ZmqMessage* msg) {
|
int nbytes = zmq_msg_recv(msg->msg(), ptr_, 0);
|
if (nbytes >= 0) {
|
return nbytes;
|
} else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
|
return 0;
|
} else {
|
LOG(FATAL) << "Cannot receive zmq message. Error number: "
|
<< zmq_errno();
|
return 0;
|
}
|
}
|
|
int RecvTillSuccess(ZmqMessage* msg) {
|
int nbytes = 0;
|
do {
|
nbytes = Recv(msg);
|
} while (nbytes == 0);
|
return nbytes;
|
}
|
|
private:
|
ZmqContext context_;
|
void* ptr_;
|
};
|
|
} // namespace caffe2
|
|
|
#endif // CAFFE2_UTILS_ZMQ_HELPER_H_
|