#ifndef CAFFE2_UTILS_ZMQ_HELPER_H_ #define CAFFE2_UTILS_ZMQ_HELPER_H_ #include #include "caffe2/core/logging.h" namespace caffe2 { class ZmqContext { public: explicit ZmqContext(int io_threads) : ptr_(zmq_ctx_new()) { CAFFE_ENFORCE(ptr_ != nullptr, "Failed to create zmq context."); int rc = zmq_ctx_set(ptr_, ZMQ_IO_THREADS, io_threads); CAFFE_ENFORCE_EQ(rc, 0); rc = zmq_ctx_set(ptr_, ZMQ_MAX_SOCKETS, ZMQ_MAX_SOCKETS_DFLT); CAFFE_ENFORCE_EQ(rc, 0); } ~ZmqContext() { int rc = zmq_ctx_destroy(ptr_); CAFFE_ENFORCE_EQ(rc, 0); } void* ptr() { return ptr_; } private: void* ptr_; C10_DISABLE_COPY_AND_ASSIGN(ZmqContext); }; class ZmqMessage { public: ZmqMessage() { int rc = zmq_msg_init(&msg_); CAFFE_ENFORCE_EQ(rc, 0); } ~ZmqMessage() { int rc = zmq_msg_close(&msg_); CAFFE_ENFORCE_EQ(rc, 0); } zmq_msg_t* msg() { return &msg_; } void* data() { return zmq_msg_data(&msg_); } size_t size() { return zmq_msg_size(&msg_); } private: zmq_msg_t msg_; C10_DISABLE_COPY_AND_ASSIGN(ZmqMessage); }; class ZmqSocket { public: explicit ZmqSocket(int type) : context_(1), ptr_(zmq_socket(context_.ptr(), type)) { CAFFE_ENFORCE(ptr_ != nullptr, "Faild to create zmq socket."); } ~ZmqSocket() { int rc = zmq_close(ptr_); CAFFE_ENFORCE_EQ(rc, 0); } void Bind(const string& addr) { int rc = zmq_bind(ptr_, addr.c_str()); CAFFE_ENFORCE_EQ(rc, 0); } void Unbind(const string& addr) { int rc = zmq_unbind(ptr_, addr.c_str()); CAFFE_ENFORCE_EQ(rc, 0); } void Connect(const string& addr) { int rc = zmq_connect(ptr_, addr.c_str()); CAFFE_ENFORCE_EQ(rc, 0); } void Disconnect(const string& addr) { int rc = zmq_disconnect(ptr_, addr.c_str()); CAFFE_ENFORCE_EQ(rc, 0); } int Send(const string& msg, int flags) { int nbytes = zmq_send(ptr_, msg.c_str(), msg.size(), flags); if (nbytes) { return nbytes; } else if (zmq_errno() == EAGAIN) { return 0; } else { LOG(FATAL) << "Cannot send zmq message. Error number: " << zmq_errno(); return 0; } } int SendTillSuccess(const string& msg, int flags) { CAFFE_ENFORCE(msg.size(), "You cannot send an empty message."); int nbytes = 0; do { nbytes = Send(msg, flags); } while (nbytes == 0); return nbytes; } int Recv(ZmqMessage* msg) { int nbytes = zmq_msg_recv(msg->msg(), ptr_, 0); if (nbytes >= 0) { return nbytes; } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) { return 0; } else { LOG(FATAL) << "Cannot receive zmq message. Error number: " << zmq_errno(); return 0; } } int RecvTillSuccess(ZmqMessage* msg) { int nbytes = 0; do { nbytes = Recv(msg); } while (nbytes == 0); return nbytes; } private: ZmqContext context_; void* ptr_; }; } // namespace caffe2 #endif // CAFFE2_UTILS_ZMQ_HELPER_H_