#include "shm_socket.h"
#include "hashtable.h"
#include "logger_factory.h"
#include <map>
#include "bus_error.h"

#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <pthread.h>
#include <type_traits>

static Logger *logger = LoggerFactory::getLogger();

// Concrete queue/map types, derived from the field declarations of shm_socket_t so that every
// allocation in this file matches the header exactly (including any template arguments).
using shm_sock_queue_t = std::remove_pointer<decltype(shm_socket_t::queue)>::type;
using shm_sock_remote_queue_t = std::remove_pointer<decltype(shm_socket_t::remoteQueue)>::type;
using shm_sock_msg_queue_t = std::remove_pointer<decltype(shm_socket_t::messageQueue)>::type;
using shm_sock_accept_queue_t = std::remove_pointer<decltype(shm_socket_t::acceptQueue)>::type;
using shm_sock_client_map_t = std::remove_pointer<decltype(shm_socket_t::clientSocketMap)>::type;

/* Debug helper that would dump a message header; its body and all call sites are commented out. */
static void print_msg(const char *head, shm_msg_t &msg) {
  // err_msg(0, "%s: key=%d, type=%d\n", head, msg.key, msg.type);
}

static pthread_once_t _once_ = PTHREAD_ONCE_INIT;
static pthread_key_t _tmp_recv_socket_key_;

static void *_server_run_msg_rev(void *_socket);
static void *_client_run_msg_rev(void *_socket);
static int _shm_close_dgram_socket(shm_socket_t *socket);
static int _shm_close_stream_socket(shm_socket_t *socket, bool notifyRemote);
static void _destrory_tmp_recv_socket_(void *tmp_socket);
static void _create_tmp_recv_socket_key(void);

shm_sock_remote_queue_t *_attach_remote_queue(int key);

/* Locks socket->mutex; any failure terminates the process through err_exit(s, tag). */
static void _shm_socket_lock(shm_socket_t *socket, const char *tag) {
  int s = pthread_mutex_lock(&(socket->mutex));
  if (s != 0)
    err_exit(s, tag);
}

/* Unlocks socket->mutex; any failure terminates the process through err_exit(s, tag). */
static void _shm_socket_unlock(shm_socket_t *socket, const char *tag) {
  int s = pthread_mutex_unlock(&(socket->mutex));
  if (s != 0)
    err_exit(s, tag);
}

/**
 * Checks whether socket->key is already in use.
 *
 * A key counts as "in use" when mm_get_by_key() returns something other than NULL or
 * (void *)1 and the socket was not force-bound. NOTE(review): (void *)1 is treated as free here;
 * presumably a reserved/placeholder marker from mm_get_by_key — confirm.
 *
 * @return 0 if the key is in use (bus_errno is set to ESHM_BUS_KEY_INUSED and an error is
 *         logged), 1 if the socket may use the key.
 */
static inline int _shm_socket_check_key(shm_socket_t *socket) {
  void *tmp_ptr = mm_get_by_key(socket->key);
  if (tmp_ptr != NULL && tmp_ptr != (void *)1 && !socket->force_bind) {
    bus_errno = ESHM_BUS_KEY_INUSED;
    logger->error("%s. key = %d ", bus_strerror(bus_errno), socket->key);
    return 0;
  }
  return 1;
}

/**
 * Makes socket->key usable: allocates a fresh key from `hashtable` when the socket is unbound
 * (key == -1), otherwise verifies the bound key is not taken.
 *
 * @return 0 on success, ESHM_BUS_KEY_INUSED (bus_errno set) if the bound key is taken.
 */
static int _shm_socket_ensure_key(shm_socket_t *socket, hashtable_t *hashtable) {
  if (socket->key == -1) {
    socket->key = hashtable_alloc_key(hashtable);
    return 0;
  }
  if (!_shm_socket_check_key(socket)) {
    bus_errno = ESHM_BUS_KEY_INUSED;
    return ESHM_BUS_KEY_INUSED;
  }
  return 0;
}

/**
 * Lazily creates the receive queue of a DGRAM socket, allocating a key first when unbound.
 * Runs under socket->mutex so concurrent first calls create the queue only once. The mutex is
 * released on every path, including the key-in-use failure.
 *
 * @return 0 on success, ESHM_BUS_KEY_INUSED if the bound key is taken.
 */
static int _shm_dgram_ensure_queue(shm_socket_t *socket, const char *lock_tag,
                                   const char *unlock_tag) {
  int rv = 0;
  hashtable_t *hashtable = mm_get_hashtable();

  _shm_socket_lock(socket, lock_tag);
  if (socket->queue == NULL) {
    rv = _shm_socket_ensure_key(socket, hashtable);
    if (rv == 0)
      socket->queue = new shm_sock_queue_t(socket->key, 16);
  }
  _shm_socket_unlock(socket, unlock_tag);
  return rv;
}

/**
 * Releases a server-side connection object created by shm_accept(). It deletes the
 * remoteQueue and messageQueue handles, destroys the mutex, and frees the struct. It does NOT
 * notify the client; callers that want that must push a close message first.
 */
static void _free_accepted_client_socket(shm_socket_t *client_socket) {
  delete client_socket->remoteQueue;
  client_socket->remoteQueue = NULL;
  delete client_socket->messageQueue;
  client_socket->messageQueue = NULL;
  pthread_mutex_destroy(&(client_socket->mutex));
  free(client_socket);
}

/** Removes the shared-memory queues bound to the given keys; returns the queue layer's count. */
size_t shm_socket_remove_keys(int keys[], size_t length) {
  return shm_sock_queue_t::remove_queues(keys, length);
}

/**
 * Allocates a new, unbound socket of the given type (key = -1, status SHM_CONN_CLOSED) with an
 * error-checking mutex. Allocation or mutex-initialisation failures terminate via err_exit().
 */
shm_socket_t *shm_open_socket(shm_socket_type_t socket_type) {
  int s;
  pthread_mutexattr_t mtxAttr;

  logger->debug("shm_open_socket\n");
  shm_socket_t *socket = (shm_socket_t *)calloc(1, sizeof(shm_socket_t));
  if (socket == NULL)
    err_exit(errno, "shm_open_socket: calloc");

  socket->socket_type = socket_type;
  socket->key = -1;
  socket->force_bind = false;
  socket->dispatch_thread = 0;
  socket->status = SHM_CONN_CLOSED;

  s = pthread_mutexattr_init(&mtxAttr);
  if (s != 0)
    err_exit(s, "pthread_mutexattr_init");
  s = pthread_mutexattr_settype(&mtxAttr, PTHREAD_MUTEX_ERRORCHECK);
  if (s != 0)
    err_exit(s, "pthread_mutexattr_settype");
  s = pthread_mutex_init(&(socket->mutex), &mtxAttr);
  if (s != 0)
    err_exit(s, "pthread_mutex_init");
  s = pthread_mutexattr_destroy(&mtxAttr);
  if (s != 0)
    err_exit(s, "pthread_mutexattr_destroy");

  return socket;
}

/**
 * Closes a socket of any type and frees it.
 *
 * @return the type-specific close result, or -1 for a NULL socket or an unknown socket type.
 */
int shm_close_socket(shm_socket_t *socket) {
  int ret = -1;  // unknown socket types have nothing to close and report failure
  int s;

  if (socket == NULL)
    return -1;

  logger->debug("shm_close_socket\n");
  switch (socket->socket_type) {
  case SHM_SOCKET_STREAM:
    ret = _shm_close_stream_socket(socket, true);
    break;
  case SHM_SOCKET_DGRAM:
    ret = _shm_close_dgram_socket(socket);
    break;
  default:
    break;
  }

  s = pthread_mutex_destroy(&(socket->mutex));
  if (s != 0)
    err_exit(s, "shm_close_socket");

  free(socket);
  return ret;
}

/** Binds the socket to `key`; the key is checked for conflicts when the queue is created. */
int shm_socket_bind(shm_socket_t *socket, int key) {
  socket->key = key;
  return 0;
}

/** Binds the socket to `key`, skipping the key-in-use check. */
int shm_socket_force_bind(shm_socket_t *socket, int key) {
  socket->force_bind = true;
  socket->key = key;
  return 0;
}

/**
 * Puts a STREAM socket into the listening state. It creates the receive queue, the accept queue
 * and the client map, then starts the server dispatch thread.
 *
 * @return 0 on success, ESHM_BUS_KEY_INUSED if the bound key is taken, or the pthread_create
 *         error code.
 */
int shm_listen(shm_socket_t *socket) {
  int rv;

  if (socket->socket_type != SHM_SOCKET_STREAM) {
    logger->error("can not invoke shm_listen method with a socket which is not a "
                  "SHM_SOCKET_STREAM socket");
    exit(1);
  }

  if ((rv = _shm_socket_ensure_key(socket, mm_get_hashtable())) != 0)
    return rv;

  socket->queue = new shm_sock_queue_t(socket->key, 16);
  socket->acceptQueue = new shm_sock_accept_queue_t(16);
  socket->clientSocketMap = new shm_sock_client_map_t();
  socket->status = SHM_CONN_LISTEN;

  rv = pthread_create(&(socket->dispatch_thread), NULL, _server_run_msg_rev, (void *)socket);
  if (rv != 0) {
    logger->error(rv, "shm_listen: pthread_create");
    socket->dispatch_thread = 0;  // contents are unspecified after a failed pthread_create
    return rv;
  }
  return 0;
}

/**
 * Accepts a client's request to establish a new connection.
 *
 * Blocks on the accept queue (fed by the server dispatch thread), creates a per-client socket,
 * registers it in clientSocketMap and replies with SHM_SOCKET_OPEN_REPLY.
 *
 * @return the connected client socket, or NULL if the client's queue cannot be attached or
 *         the reply cannot be delivered within 1 second.
 */
shm_socket_t *shm_accept(shm_socket_t *socket) {
  if (socket->socket_type != SHM_SOCKET_STREAM) {
    logger->error("can not invoke shm_accept method with a socket which is not a "
                  "SHM_SOCKET_STREAM socket");
    exit(1);
  }

  shm_msg_t src;
  if (socket->acceptQueue->pop(src) != 0) {
    err_exit(errno, "shm_accept");
    return NULL;
  }

  int client_key = src.key;
  shm_socket_t *client_socket = shm_open_socket(socket->socket_type);
  client_socket->key = socket->key;
  // Queue that the dispatch thread fills with this client's messages.
  client_socket->messageQueue = new shm_sock_msg_queue_t(16);
  // Attach to the client's own queue so we can reply to it.
  client_socket->remoteQueue = _attach_remote_queue(client_key);
  if (client_socket->remoteQueue == NULL) {
    logger->error("shm_accept: attach to client queue %d failed", client_key);
    _free_accepted_client_socket(client_socket);
    return NULL;
  }

  /*
   * shm_accept runs on the user's thread while _server_run_msg_rev runs on the dispatch thread.
   * Accept must finish preparing the client's resources before the client can send anything,
   * to avoid a race, so the client is registered before the reply is sent. The map is shared
   * with the dispatch thread, hence the lock.
   */
  _shm_socket_lock(socket, "shm_accept : pthread_mutex_lock");
  socket->clientSocketMap->insert({client_key, client_socket});
  _shm_socket_unlock(socket, "shm_accept : pthread_mutex_unlock");

  // Send open_reply in response to the client's connect request.
  struct timespec timeout = {1, 0};
  shm_msg_t msg;
  msg.key = socket->key;
  msg.size = 0;
  msg.type = SHM_SOCKET_OPEN_REPLY;

  if (client_socket->remoteQueue->push_timeout(msg, &timeout) != 0) {
    logger->error("shm_accept: 发送open_reply失败");
    return NULL;
  }

  client_socket->status = SHM_CONN_ESTABLISHED;
  return client_socket;
}

/**
 * Connects a STREAM socket to the listening socket bound to `key`. It sends SHM_SOCKET_OPEN,
 * waits for SHM_SOCKET_OPEN_REPLY, then starts the client dispatch thread.
 *
 * @return 0 on success; -1 on failure; ESHM_BUS_KEY_INUSED if this socket's bound key is taken;
 *         or the pthread_create error code.
 */
int shm_connect(shm_socket_t *socket, int key) {
  int rv;

  if (socket->socket_type != SHM_SOCKET_STREAM) {
    logger->error("can not invoke shm_connect method with a socket which is not "
                  "a SHM_SOCKET_STREAM socket");
    exit(1);
  }

  hashtable_t *hashtable = mm_get_hashtable();
  if (hashtable_get(hashtable, key) == NULL) {
    logger->error("shm_connect:connect at key %d failed!", key);
    return -1;
  }

  if ((rv = _shm_socket_ensure_key(socket, hashtable)) != 0)
    return rv;

  socket->queue = new shm_sock_queue_t(socket->key, 16);
  if ((socket->remoteQueue = _attach_remote_queue(key)) == NULL) {
    logger->error("connect to %d failted", key);
    return -1;
  }
  socket->messageQueue = new shm_sock_msg_queue_t(16);

  // Send the open request. If it is not delivered, no reply will ever come, so fail now
  // instead of blocking forever in pop() below.
  struct timespec timeout = {1, 0};
  shm_msg_t msg;
  msg.key = socket->key;
  msg.size = 0;
  msg.type = SHM_SOCKET_OPEN;
  if (socket->remoteQueue->push_timeout(msg, &timeout) != 0) {
    logger->error("shm_connect: send open request to %d failed", key);
    return -1;
  }

  // Receive the open reply.
  if (socket->queue->pop(msg) != 0) {
    logger->error("connect failted!");
    return -1;
  }
  if (msg.type != SHM_SOCKET_OPEN_REPLY) {
    logger->error("shm_connect: 不匹配的应答信息!");
    exit(1);
  }

  // The server is now ready to receive this client's messages; the connection is complete.
  socket->status = SHM_CONN_ESTABLISHED;
  rv = pthread_create(&(socket->dispatch_thread), NULL, _client_run_msg_rev, (void *)socket);
  if (rv != 0) {
    logger->error(rv, "shm_connect: pthread_create");
    socket->dispatch_thread = 0;
    return rv;
  }
  return 0;
}

/**
 * Sends `size` bytes over a connected STREAM socket. The payload is copied into an mm_malloc'd
 * buffer, which is released here if the push fails.
 *
 * @return 0 on success, -1 if the socket has no peer or the push failed.
 */
int shm_send(shm_socket_t *socket, const void *buf, const int size) {
  if (socket->socket_type != SHM_SOCKET_STREAM) {
    logger->error("shm_socket.shm_send: can not invoke shm_send method with a socket which is not a "
                  "SHM_SOCKET_STREAM socket");
    exit(1);
  }

  if (socket->remoteQueue == NULL) {
    err_msg(errno, "当前客户端无连接!");
    return -1;
  }

  shm_msg_t dest;
  dest.type = SHM_COMMON_MSG;
  dest.key = socket->key;
  dest.size = size;
  dest.buf = mm_malloc(size);
  memcpy(dest.buf, buf, size);

  if (socket->remoteQueue->push(dest) == 0)
    return 0;

  // The peer never received the message, so its payload is still ours to release.
  logger->error(errno, "connection has been closed!");
  mm_free(dest.buf);
  return -1;
}

/**
 * Receives one message from a connected STREAM socket.
 *
 * On success *buf is a malloc'd copy of the payload (the caller frees it) and *size is its
 * length. @return 0 on success, -1 if the socket is not connected or the pop failed.
 */
int shm_recv(shm_socket_t *socket, void **buf, int *size) {
  if (socket->socket_type != SHM_SOCKET_STREAM) {
    logger->error("shm_socket.shm_recv: can not invoke shm_recv method in a %d type socket which is "
                  "not a SHM_SOCKET_STREAM socket ",
                  socket->socket_type);
    exit(1);
  }

  if (socket->messageQueue == NULL)
    return -1;

  shm_msg_t src;
  if (socket->messageQueue->pop(src) != 0)
    return -1;

  void *_buf = malloc(src.size);
  memcpy(_buf, src.buf, src.size);
  *buf = _buf;
  *size = src.size;
  mm_free(src.buf);
  return 0;
}

/**
 * Connectionless (datagram) send of `size` bytes to the socket bound to `key`.
 *
 * The sender's own receive queue is created on first use, so replies can be addressed back.
 * With SHM_MSG_NOWAIT in `flags` the push does not block. Otherwise, a non-NULL `timeout`
 * bounds the wait, and a NULL `timeout` blocks.
 *
 * @return 0 on success; EBUS_CLOSED if no queue exists at `key`; ESHM_BUS_KEY_INUSED if this
 *         socket's bound key is taken; EBUS_TIMEOUT on timeout; -1 on other push failures.
 */
int shm_sendto(shm_socket_t *socket, const void *buf, const int size, const int key,
               const struct timespec *timeout, const int flags) {
  int rv;

  if (socket->socket_type != SHM_SOCKET_DGRAM) {
    logger->error("shm_socket.shm_sendto: Can't invoke shm_sendto method in a %d type socket which is "
                  "not a SHM_SOCKET_DGRAM socket ",
                  socket->socket_type);
    exit(1);  // misuse is a failure, consistent with every other entry point
  }

  if ((rv = _shm_dgram_ensure_queue(socket, "shm_sendto : pthread_mutex_lock",
                                    "shm_sendto : pthread_mutex_unlock")) != 0)
    return rv;

  shm_sock_remote_queue_t *remoteQueue = _attach_remote_queue(key);
  if (remoteQueue == NULL) {
    bus_errno = EBUS_CLOSED;
    logger->error("sendto key %d failed, %s", key, bus_strerror(bus_errno));
    return EBUS_CLOSED;
  }

  shm_msg_t dest;
  dest.type = SHM_COMMON_MSG;
  dest.key = socket->key;
  dest.size = size;
  dest.buf = mm_malloc(size);
  memcpy(dest.buf, buf, size);

  // Parenthesised: `flags & SHM_MSG_NOWAIT != 0` would parse as `flags & (SHM_MSG_NOWAIT != 0)`.
  if ((flags & SHM_MSG_NOWAIT) != 0)
    rv = remoteQueue->push_nowait(dest);
  else if (timeout != NULL)
    rv = remoteQueue->push_timeout(dest, timeout);
  else
    rv = remoteQueue->push(dest);

  delete remoteQueue;
  if (rv == 0)
    return 0;

  mm_free(dest.buf);
  if (rv == EBUS_TIMEOUT) {
    bus_errno = EBUS_TIMEOUT;
    logger->error(errno, "sendto key %d failed, %s", key, bus_strerror(EBUS_TIMEOUT));
    return EBUS_TIMEOUT;
  }
  return -1;
}

/**
 * Connectionless (datagram) receive on a DGRAM socket.
 *
 * Wait semantics follow shm_sendto (SHM_MSG_NOWAIT / timeout / blocking). On success, each
 * non-NULL out-parameter is filled in. *buf receives a malloc'd copy of the payload, which the
 * caller frees; *size receives its length; *key receives the sender's key.
 *
 * @return 0 on success, ESHM_BUS_KEY_INUSED if this socket's bound key is taken, -1 if the pop
 *         failed or timed out.
 */
int shm_recvfrom(shm_socket_t *socket, void **buf, int *size, int *key, struct timespec *timeout,
                 int flags) {
  int rv;

  if (socket->socket_type != SHM_SOCKET_DGRAM) {
    logger->error("shm_socket.shm_recvfrom: Can't invoke shm_recvfrom method in a %d type socket which "
                  "is not a SHM_SOCKET_DGRAM socket ",
                  socket->socket_type);
    exit(1);
  }

  if ((rv = _shm_dgram_ensure_queue(socket, "shm_recvfrom : pthread_mutex_lock",
                                    "shm_recvfrom : pthread_mutex_unlock")) != 0)
    return rv;

  shm_msg_t src;
  if ((flags & SHM_MSG_NOWAIT) != 0)
    rv = socket->queue->pop_nowait(src);
  else if (timeout != NULL)
    rv = socket->queue->pop_timeout(src, timeout);
  else
    rv = socket->queue->pop(src);

  if (rv != 0)
    return -1;

  if (buf != NULL) {
    void *_buf = malloc(src.size);
    memcpy(_buf, src.buf, src.size);
    *buf = _buf;
  }
  if (size != NULL)
    *size = src.size;
  if (key != NULL)
    *key = src.key;
  mm_free(src.buf);
  return 0;
}

/* Thread-specific-data destructor: closes the per-thread temporary receive socket. */
static void _destrory_tmp_recv_socket_(void *tmp_socket) {
  int rv;
  if (tmp_socket == NULL)
    return;

  logger->debug("%lu destroy tmp socket\n", (unsigned long)pthread_self());
  shm_close_socket((shm_socket_t *)tmp_socket);
  rv = pthread_setspecific(_tmp_recv_socket_key_, NULL);
  if (rv != 0) {
    logger->error(rv, "shm_sendandrecv : pthread_setspecific");
    exit(1);
  }
}

/* One-time creation of the thread-specific key; registers the destructor above. */
static void _create_tmp_recv_socket_key(void) {
  int s = pthread_key_create(&_tmp_recv_socket_key_, _destrory_tmp_recv_socket_);
  if (s != 0) {
    logger->error(s, "pthread_key_create");
    abort(); /* dump core and terminate */
  }
}

/**
 * Request/reply over DGRAM sockets using a per-thread temporary socket, so each thread
 * receives replies on its own exclusive queue. `socket` is only checked for its type.
 *
 * @return 0 on success, otherwise the error from shm_sendto / shm_recvfrom.
 */
int _shm_sendandrecv_thread_local(shm_socket_t *socket, const void *send_buf, const int send_size,
                                  const int send_key, void **recv_buf, int *recv_size,
                                  struct timespec *timeout, int flags) {
  int recv_key;
  int rv;

  if (socket->socket_type != SHM_SOCKET_DGRAM) {
    logger->error("shm_socket.shm_sendandrecv: Can't invoke shm_sendandrecv method in a %d type socket "
                  "which is not a SHM_SOCKET_DGRAM socket ",
                  socket->socket_type);
    exit(1);
  }

  rv = pthread_once(&_once_, _create_tmp_recv_socket_key);
  if (rv != 0) {
    logger->error(rv, "shm_sendandrecv pthread_once");
    exit(1);
  }

  shm_socket_t *tmp_socket = (shm_socket_t *)pthread_getspecific(_tmp_recv_socket_key_);
  if (tmp_socket == NULL) {
    // First call from this thread: create its private socket and remember it.
    logger->debug("%lu create tmp socket\n", (unsigned long)pthread_self());
    tmp_socket = shm_open_socket(SHM_SOCKET_DGRAM);
    rv = pthread_setspecific(_tmp_recv_socket_key_, tmp_socket);
    if (rv != 0) {
      logger->error(rv, "shm_sendandrecv : pthread_setspecific");
      exit(1);
    }
  }

  if ((rv = shm_sendto(tmp_socket, send_buf, send_size, send_key, timeout, flags)) != 0)
    return rv;
  return shm_recvfrom(tmp_socket, recv_buf, recv_size, &recv_key, timeout, flags);
}

/**
 * Request/reply over DGRAM sockets using a brand-new temporary socket for each call, which is
 * closed before returning. `socket` is only checked for its type.
 *
 * @return 0 on success, otherwise the error from shm_sendto / shm_recvfrom.
 */
int _shm_sendandrecv_alloc_new(shm_socket_t *socket, const void *send_buf, const int send_size,
                               const int send_key, void **recv_buf, int *recv_size,
                               struct timespec *timeout, int flags) {
  int recv_key;
  int rv;

  if (socket->socket_type != SHM_SOCKET_DGRAM) {
    logger->error("shm_socket.shm_sendandrecv: Can't invoke shm_sendandrecv method in a %d type socket "
                  "which is not a SHM_SOCKET_DGRAM socket ",
                  socket->socket_type);
    exit(1);
  }

  shm_socket_t *tmp_socket = shm_open_socket(SHM_SOCKET_DGRAM);
  if ((rv = shm_sendto(tmp_socket, send_buf, send_size, send_key, timeout, flags)) == 0)
    rv = shm_recvfrom(tmp_socket, recv_buf, recv_size, &recv_key, timeout, flags);
  shm_close_socket(tmp_socket);
  return rv;
}

/**
 * Request/reply directly on `socket`. This is not safe when several threads share the socket,
 * because a thread may receive another thread's reply.
 *
 * @return 0 on success, otherwise the error from shm_sendto / shm_recvfrom.
 */
int shm_sendandrecv_unsafe(shm_socket_t *socket, const void *send_buf, const int send_size,
                           const int send_key, void **recv_buf, int *recv_size,
                           struct timespec *timeout, int flags) {
  int recv_key;
  int rv;

  if (socket->socket_type != SHM_SOCKET_DGRAM) {
    logger->error("shm_socket.shm_sendandrecv_unsafe : Can't invoke shm_sendandrecv method in a %d type socket "
                  "which is not a SHM_SOCKET_DGRAM socket ",
                  socket->socket_type);
    exit(1);
  }

  if ((rv = shm_sendto(socket, send_buf, send_size, send_key, timeout, flags)) != 0)
    return rv;
  return shm_recvfrom(socket, recv_buf, recv_size, &recv_key, timeout, flags);
}

/** Request/reply over DGRAM sockets; delegates to the thread-local implementation. */
int shm_sendandrecv(shm_socket_t *socket, const void *send_buf, const int send_size,
                    const int send_key, void **recv_buf, int *recv_size, struct timespec *timeout,
                    int flags) {
  return _shm_sendandrecv_thread_local(socket, send_buf, send_size, send_key, recv_buf, recv_size,
                                       timeout, flags);
}

// ============================================================================================================

/**
 * Attaches to the existing queue bound to `key` without creating it.
 *
 * @return a new handle that the caller must delete, or NULL if no queue is registered under
 *         `key`.
 */
shm_sock_remote_queue_t *_attach_remote_queue(int key) {
  hashtable_t *hashtable = mm_get_hashtable();
  if (hashtable_get(hashtable, key) == NULL)
    return NULL;
  return new shm_sock_remote_queue_t(key, 0);
}

/**
 * Handles a client's SHM_SOCKET_CLOSE on the server side. The client is removed from
 * clientSocketMap and its connection object is released.
 *
 * NOTE(review): the same object was returned to the user by shm_accept(); if the user still
 * holds it, or later passes it to shm_close_socket(), that is a use-after-free. The ownership
 * contract needs to be settled at the API level.
 */
void _server_close_conn_to_client(shm_socket_t *socket, int key) {
  shm_socket_t *client_socket = NULL;

  _shm_socket_lock(socket, "_server_close_conn_to_client : pthread_mutex_lock");
  auto iter = socket->clientSocketMap->find(key);
  if (iter != socket->clientSocketMap->end()) {
    client_socket = iter->second;
    socket->clientSocketMap->erase(iter);
  }
  _shm_socket_unlock(socket, "_server_close_conn_to_client : pthread_mutex_unlock");

  if (client_socket != NULL)
    _free_accepted_client_socket(client_socket);
}

/**
 * Server dispatch thread: every message arriving on the listening socket's queue is sorted
 * here. OPEN requests go to the accept queue, CLOSE tears down that client, and data messages
 * are routed to the matching client's messageQueue.
 */
static void *_server_run_msg_rev(void *_socket) {
  pthread_detach(pthread_self());
  shm_socket_t *socket = (shm_socket_t *)_socket;
  struct timespec timeout = {1, 0};
  shm_msg_t src;

  while (socket->queue->pop(src) == 0) {
    switch (src.type) {
    case SHM_SOCKET_OPEN:
      socket->acceptQueue->push_timeout(src, &timeout);
      break;
    case SHM_SOCKET_CLOSE:
      _server_close_conn_to_client(socket, src.key);
      break;
    case SHM_COMMON_MSG: {
      shm_socket_t *client_socket = NULL;

      // The map is also modified by shm_accept() on the user's thread.
      _shm_socket_lock(socket, "_server_run_msg_rev : pthread_mutex_lock");
      auto iter = socket->clientSocketMap->find(src.key);
      if (iter != socket->clientSocketMap->end())
        client_socket = iter->second;
      _shm_socket_unlock(socket, "_server_run_msg_rev : pthread_mutex_unlock");

      // An undeliverable message (unknown sender or full queue) would otherwise leak its
      // shared-memory payload.
      if (client_socket == NULL || client_socket->messageQueue->push_timeout(src, &timeout) != 0)
        mm_free(src.buf);
      break;
    }
    default:
      logger->error("shm_socket._server_run_msg_rev: undefined message type.");
    }
  }

  return NULL;
}

/* Handles the server's SHM_SOCKET_CLOSE on the client side (no notification sent back). */
void _client_close_conn_to_server(shm_socket_t *socket) {
  _shm_close_stream_socket(socket, false);
}

/**
 * Client dispatch thread: every message arriving on the connected socket's queue is sorted
 * here. Data messages go to messageQueue; a CLOSE from the server tears the connection down.
 */
static void *_client_run_msg_rev(void *_socket) {
  pthread_detach(pthread_self());
  shm_socket_t *socket = (shm_socket_t *)_socket;
  struct timespec timeout = {1, 0};
  shm_msg_t src;

  while (socket->queue->pop(src) == 0) {
    switch (src.type) {
    case SHM_SOCKET_CLOSE:
      // Closing deletes socket->queue, so this thread must stop reading from it immediately.
      _client_close_conn_to_server(socket);
      return NULL;
    case SHM_COMMON_MSG:
      if (socket->messageQueue->push_timeout(src, &timeout) != 0)
        mm_free(src.buf);
      break;
    default:
      logger->error("shm_socket._client_run_msg_rev: undefined message type.");
    }
  }

  return NULL;
}

/**
 * Tears down a STREAM socket: marks it closed, stops its dispatch thread and optionally sends
 * SHM_SOCKET_CLOSE to the peer. It then deletes its queues and, for a listening socket,
 * notifies and releases every accepted client. Idempotent; always returns 0.
 *
 * The dispatch thread is detached, so it cannot be joined. It is cancelled before its queues
 * are deleted to narrow the window in which it could still touch them. When this runs on the
 * dispatch thread itself (remote close), the thread must not cancel itself.
 */
static int _shm_close_stream_socket(shm_socket_t *socket, bool notifyRemote) {
  socket->status = SHM_CONN_CLOSED;

  if (socket->dispatch_thread != 0) {
    if (!pthread_equal(socket->dispatch_thread, pthread_self()))
      pthread_cancel(socket->dispatch_thread);
    socket->dispatch_thread = 0;  // never cancel a stale id on a second close
  }

  // Message telling the other side that the connection is being closed.
  struct timespec timeout = {1, 0};
  shm_msg_t close_msg;
  close_msg.key = socket->key;
  close_msg.size = 0;
  close_msg.type = SHM_SOCKET_CLOSE;

  if (notifyRemote && socket->remoteQueue != NULL)
    socket->remoteQueue->push_timeout(close_msg, &timeout);

  delete socket->queue;
  socket->queue = NULL;
  delete socket->remoteQueue;
  socket->remoteQueue = NULL;
  delete socket->messageQueue;
  socket->messageQueue = NULL;
  delete socket->acceptQueue;
  socket->acceptQueue = NULL;

  if (socket->clientSocketMap != NULL) {
    for (auto &entry : *socket->clientSocketMap) {
      shm_socket_t *client_socket = entry.second;
      if (client_socket->remoteQueue != NULL)
        client_socket->remoteQueue->push_timeout(close_msg, &timeout);
      _free_accepted_client_socket(client_socket);
    }
    delete socket->clientSocketMap;
    socket->clientSocketMap = NULL;
  }

  return 0;
}

/* Tears down a DGRAM socket by deleting its receive queue (if it was ever created). */
static int _shm_close_dgram_socket(shm_socket_t *socket) {
  if (socket->queue != NULL) {
    delete socket->queue;
    socket->queue = NULL;
  }
  return 0;
}