lichao
2021-05-27 026bbfaf2b5d73a26b8e2fa49158883ef64c211b
box/tcp_connection.cpp
@@ -17,78 +17,198 @@
 */
#include "tcp_connection.h"
#include "log.h"
#include "msg.h"
#include "node_center.h"
#include "shm_socket.h"
namespace
{
template <class C>
auto Buffer(C &c) { return boost::asio::buffer(c.data(), c.size()); }
auto Buffer(C &c, size_t offset = 0) { return boost::asio::buffer(c.data() + offset, c.size() - offset); }
using boost::asio::async_read;
using boost::asio::async_write;
typedef std::function<void()> VoidHandler;
typedef std::function<void(size_t)> SizeHandler;
template <class T, class... Param>
auto TcpCallback(T &conn, std::function<void(Param...)> const &func)
{
   auto self(conn.shared_from_this());
   return [self, func](bserror_t ec, Param... size) {
      if (!ec) {
         func(size...);
      } else {
         self->OnError(ec);
      }
   };
}
template <class T>
auto TcpCB(T &conn, VoidHandler const &func) { return TcpCallback(conn, func); }
template <class T>
auto TcpCBSize(T &conn, SizeHandler const &func) { return TcpCallback(conn, func); }
template <class T>
auto TcpCBSize(T &conn, VoidHandler const &func)
{
   return TcpCBSize(conn, [func](size_t) { func(); });
}
bool CheckData(std::vector<char> &buffer, const uint32_t len, BHMsgHead &head, std::string &body_content)
{
   const char *p = buffer.data();
   LOG_DEBUG() << "msg len " << len;
   if (4 > len) { return false; }
   uint32_t head_len = Get32(p);
   LOG_DEBUG() << "head_len " << head_len;
   if (head_len > 1024 * 4) {
      throw std::runtime_error("unexpected tcp reply data.");
   }
   auto before_body = 4 + head_len + 4;
   if (before_body > len) {
      if (before_body > buffer.size()) {
         buffer.resize(before_body);
      }
      return false;
   }
   if (!head.ParseFromArray(p + 4, head_len)) {
      throw std::runtime_error("tcp recv invalid reply head.");
   }
   uint32_t body_len = Get32(p + 4 + head_len);
   buffer.resize(before_body + body_len);
   if (buffer.size() > len) { return false; }
   body_content.assign(p + before_body, body_len);
   return true;
}
} // namespace
TcpRequest1::TcpRequest1(boost::asio::io_context &io, tcp::endpoint const &addr, std::string request) :
    socket_(io), remote_(addr), request_(std::move(request)) {}
void TcpRequest1::Connect()
/// request -----------------------------------------------------------
void TcpRequest1::OnError(bserror_t ec)
{
   auto self = shared_from_this();
   socket_.async_connect(remote_, [this, self](bserror_t ec) {
      if (!ec) {
         SendRequest();
      } else {
         LOG_ERROR() << "connect error " << ec;
         Close();
      }
   });
   LOG_ERROR() << "tcp client error: " << ec;
   Close();
}
void TcpRequest1::Start()
{
   auto readReply = [this]() {
      recv_buffer_.resize(1000);
      recv_len_ = 0;
      socket_.async_read_some(Buffer(recv_buffer_), TcpCBSize(*this, [this](size_t size) { OnRead(size); }));
   };
   auto request = [this, readReply]() { async_write(socket_, Buffer(request_), TcpCBSize(*this, readReply)); };
   socket_.async_connect(remote_, TcpCB(*this, request));
}
void TcpRequest1::Close()
{
   LOG_DEBUG() << "client close";
   socket_.close();
}
void TcpRequest1::OnRead(size_t size)
{
   LOG_DEBUG() << "reply data: " << recv_buffer_.data() + recv_len_;
   recv_len_ += size;
   BHMsgHead head;
   std::string body_content;
   bool recv_done = false;
   try {
      recv_done = CheckData(recv_buffer_, recv_len_, head, body_content);
   } catch (std::exception &e) {
      LOG_ERROR() << e.what();
      Close();
      return;
   }
   if (recv_done) {
      // just pass to client, no check, client will check it anyway.
      LOG_DEBUG() << "route size: " << head.route_size();
      if (head.route_size() < 1) { return; }
      auto &back = head.route(head.route_size() - 1);
      MQInfo dest = {back.mq_id(), back.abs_addr()};
      head.mutable_route()->RemoveLast();
      LOG_DEBUG() << "tcp got reply, pass to shm: " << dest.id_ << ", " << dest.offset_;
      MsgRequestTopicReply reply;
      if (reply.ParseFromString(body_content)) {
         LOG_DEBUG() << "err msg: " << reply.errmsg().errstring();
         LOG_DEBUG() << "content : " << reply.data();
      }
      Close();
      return;
      shm_socket_.Send(dest, std::string(recv_buffer_.data(), recv_buffer_.size()));
   } else { // read again
      LOG_DEBUG() << "not complete, read again " << recv_buffer_.size();
      socket_.async_read_some(Buffer(recv_buffer_, recv_len_), TcpCBSize(*this, [this](size_t size) { OnRead(size); }));
   }
}
/// reply --------------------------------------------------------------
void TcpReply1::OnError(bserror_t ec) { Close(); }
void TcpReply1::Close()
{
   LOG_DEBUG() << "server close.";
   socket_.close();
}
void TcpRequest1::SendRequest()
{
   LOG_INFO() << "client sending request " << request_;
   auto self = shared_from_this();
   async_write(socket_, Buffer(request_), [this, self](bserror_t ec, size_t) {
      if (!ec) {
         ReadReply();
      } else {
         Close();
      }
   });
}
void TcpRequest1::ReadReply()
{
   buffer_.resize(1000);
   auto self = shared_from_this();
   socket_.async_read_some(Buffer(buffer_), [this, self](bserror_t ec, size_t size) {
      if (!ec) {
         printf("reply data: %s\n", buffer_.data());
      } else {
         Close();
      }
   });
}
TcpReply1::TcpReply1(tcp::socket sock) :
    socket_(std::move(sock)) {}
void TcpReply1::Start()
{
   LOG_INFO() << "server session reading...";
   recv_buffer_.resize(1000);
   auto self(shared_from_this());
   socket_.async_read_some(Buffer(recv_buffer_), [this, self](bserror_t ec, size_t size) {
      LOG_INFO() << "server read : " << recv_buffer_.data();
      // fake reply
      if (!ec) {
         send_buffer_ = std::string(recv_buffer_.data(), size) + " reply";
         async_write(socket_, Buffer(send_buffer_), [this, self](bserror_t ec, size_t size) {
            socket_.close();
         });
   socket_.async_read_some(Buffer(recv_buffer_), TcpCBSize(*this, [this](size_t size) { OnRead(size); }));
}
void TcpReply1::OnRead(size_t size)
{
   recv_len_ += size;
   BHMsgHead head;
   std::string body_content;
   bool recv_done = false;
   try {
      recv_done = CheckData(recv_buffer_, recv_len_, head, body_content);
   } catch (std::exception &e) {
      LOG_ERROR() << e.what();
      Close();
      return;
   }
   auto ParseBody = [&](auto &req) {
      const char *p = recv_buffer_.data();
      uint32_t size = Get32(p);
      p += 4;
      p += size;
      size = Get32(p);
      p += 4;
      return req.ParseFromArray(p, size);
   };
   if (recv_done) {
      LOG_DEBUG() << "request data: " << size;
      auto self(shared_from_this());
      MQInfo remote = {head.dest().mq_id(), head.dest().abs_addr()};
      if (remote.id_ && remote.offset_) {
         auto onRecv = [this, self](ShmSocket &sock, MsgI &imsg, BHMsgHead &head) {
            send_buffer_ = imsg.content();
            async_write(socket_, Buffer(send_buffer_), TcpCBSize(*this, [this]() { Close(); }));
         };
         auto &scenter = *pscenter_;
         if (!scenter->ProxyMsg(remote, head, body_content, onRecv)) {
            send_buffer_ = "fake reply";
            async_write(socket_, Buffer(send_buffer_), TcpCBSize(*this, [this]() { Close(); }));
         }
      } else {
         socket_.close();
         LOG_DEBUG() << "no address";
         send_buffer_ = "no address";
         async_write(socket_, Buffer(send_buffer_), TcpCBSize(*this, [this]() { Close(); }));
      }
   });
}
   } else { // read again
      LOG_DEBUG() << "not complete, read again " << recv_buffer_.size();
      socket_.async_read_some(Buffer(recv_buffer_, recv_len_), TcpCBSize(*this, [this](size_t size) { OnRead(size); }));
   }
};