lichao
2021-05-19 34cd75f77d0ca94dbdba4e6cc9451fe4d33e78b3
box/center.cpp
@@ -21,7 +21,7 @@
#include "log.h"
#include "shm.h"
#include <chrono>
#include <set>
#include <unordered_map>
using namespace std::chrono;
using namespace std::chrono_literals;
@@ -33,11 +33,119 @@
namespace
{
typedef std::string ProcId;
typedef size_t ProcIndex; // max local procs.
const int kMaxProcs = 65536;
// record all procs ever registered, always grow, never remove.
// mainly for node to request msg allocation.
// use index instead of MQId to save some bits.
class ProcRecords
{
public:
   struct ProcRec {
      ProcId proc_;
      MQId ssn_ = 0;
   };
   ProcRecords() { procs_.reserve(kMaxProcs); }
   ProcIndex Put(const ProcId &proc_id, const MQId ssn)
   {
      if (procs_.size() >= kMaxProcs) {
         return -1;
      }
      auto pos_isnew = proc_index_.emplace(proc_id, procs_.size());
      int index = pos_isnew.first->second;
      if (pos_isnew.second) {
         procs_.emplace_back(ProcRec{proc_id, ssn});
      } else { // update ssn
         procs_[index].ssn_ = ssn;
      }
      return index;
   }
   const ProcRec &Get(const ProcIndex index) const
   {
      static ProcRec empty_rec;
      return (index < procs_.size()) ? procs_[index] : empty_rec;
   }
private:
   std::unordered_map<ProcId, size_t> proc_index_;
   std::vector<ProcRec> procs_;
};
class MsgRecords
{
   typedef int64_t MsgId;
   typedef int64_t Offset;
public:
   void RecordMsg(const MsgI &msg) { msgs_.emplace(msg.id(), msg.Offset()); }
   void FreeMsg(MsgId id)
   {
      auto pos = msgs_.find(id);
      if (pos != msgs_.end()) {
         ShmMsg(pos->second).Free();
         msgs_.erase(pos);
      } else {
         LOG_TRACE() << "ignore late free request.";
      }
   }
   void AutoRemove()
   {
      auto now = NowSec();
      if (now < time_to_clean_) {
         return;
      }
      // LOG_FUNCTION;
      time_to_clean_ = now + 1;
      int64_t limit = std::max(10000ul, msgs_.size() / 10);
      int64_t n = 0;
      auto it = msgs_.begin();
      while (it != msgs_.end() && --limit > 0) {
         ShmMsg msg(it->second);
         auto Free = [&]() {
            msg.Free();
            it = msgs_.erase(it);
            ++n;
         };
         int n = now - msg.timestamp();
         if (n < 10) {
            ++it;
         } else if (msg.Count() == 0) {
            Free();
         } else if (n > 60) {
            Free();
         }
      }
      if (n > 0) {
         LOG_DEBUG() << "~~~~~~~~~~~~~~~~ auto release msgs: " << n;
      }
   }
   size_t size() const { return msgs_.size(); }
   void DebugPrint() const
   {
      LOG_DEBUG() << "msgs : " << size();
      int i = 0;
      int total_count = 0;
      for (auto &kv : msgs_) {
         MsgI msg(kv.second);
         total_count += msg.Count();
         LOG_TRACE() << "  " << i++ << ": msg id: " << kv.first << ", offset: " << kv.second << ", count: " << msg.Count() << ", size: " << msg.Size();
      }
      LOG_DEBUG() << "total count: " << total_count;
   }
private:
   std::unordered_map<MsgId, Offset> msgs_;
   int64_t time_to_clean_ = 0;
};
//TODO check proc_id
class NodeCenter
{
public:
   typedef std::string ProcId;
   typedef MQId Address;
   typedef bhome_msg::ProcInfo ProcInfo;
   typedef std::function<void(Address const)> Cleaner;
@@ -74,22 +182,24 @@
   typedef std::unordered_map<Address, std::set<Topic>> AddressTopics;
   struct NodeInfo {
      ProcState state_;             // state
      std::set<Address> addrs_;     // registered mqs
      ProcInfo proc_;               //
      AddressTopics services_;      // address: topics
      AddressTopics subscriptions_; // address: topics
      ProcState state_;               // state
      std::map<MQId, int64_t> addrs_; // registered mqs
      ProcInfo proc_;                 //
      AddressTopics services_;        // address: topics
      AddressTopics subscriptions_;   // address: topics
   };
   typedef std::shared_ptr<NodeInfo> Node;
   typedef std::weak_ptr<NodeInfo> WeakNode;
   struct TopicDest {
      Address mq_;
      MQId mq_id_;
      int64_t mq_abs_addr_;
      WeakNode weak_node_;
      bool operator<(const TopicDest &a) const { return mq_ < a.mq_; }
      bool operator<(const TopicDest &a) const { return mq_id_ < a.mq_id_; }
   };
   inline MQId SrcAddr(const BHMsgHead &head) { return head.route(0).mq_id(); }
   inline bool MatchAddr(std::set<Address> const &addrs, const Address &addr) { return addrs.find(addr) != addrs.end(); }
   inline int64_t SrcAbsAddr(const BHMsgHead &head) { return head.route(0).abs_addr(); }
   inline bool MatchAddr(std::map<Address, int64_t> const &addrs, const Address &addr) { return addrs.find(addr) != addrs.end(); }
   NodeCenter(const std::string &id, const Cleaner &cleaner, const int64_t offline_time, const int64_t kill_time) :
       id_(id), cleaner_(cleaner), offline_time_(offline_time), kill_time_(kill_time), last_check_time_(0) {}
@@ -102,40 +212,151 @@
   // center name, no relative to shm.
   const std::string &id() const { return id_; }
   void OnNodeInit(SharedMemory &shm, const int64_t msg)
   int64_t OnNodeInit(ShmSocket &socket, const int64_t val)
   {
      MQId ssn = msg;
      LOG_FUNCTION;
      SharedMemory &shm = socket.shm();
      MQId ssn = (val >> 4) & MaskBits(56);
      int reply = EncodeCmd(eCmdNodeInitReply);
      if (nodes_.find(ssn) != nodes_.end()) {
         return; // ignore in exists.
         return reply; // ignore if exists.
      }
      auto UpdateRegInfo = [&](Node &node) {
         for (int i = 0; i < 10; ++i) {
            node->addrs_.insert(ssn + i);
         }
         node->state_.timestamp_ = NowSec() - offline_time_;
         node->state_.UpdateState(NowSec(), offline_time_, kill_time_);
         // create sockets.
         try {
            auto CreateSocket = [](SharedMemory &shm, const MQId id) {
               ShmSocket tmp(shm, true, id, 16);
            };
            // alloc(-1), node, server, sub, request,
            for (int i = -1; i < 4; ++i) {
               CreateSocket(shm, ssn + i);
               node->addrs_.insert(ssn + i);
            }
            ShmSocket tmp(shm, true, ssn, 16);
            node->addrs_.emplace(ssn, tmp.AbsAddr());
            return true;
         } catch (...) {
            return false;
         }
      };
      auto PrepareProcInit = [&](Node &node) {
         bool r = false;
         ShmMsg init_msg;
         DEFER1(init_msg.Release());
         MsgProcInit body;
         auto head = InitMsgHead(GetType(body), id(), ssn);
         return init_msg.Make(GetAllocSize(CalcAllocIndex(900))) &&
                init_msg.Fill(ShmMsg::Serialize(head, body)) &&
                SendAllocMsg(socket, {ssn, node->addrs_[ssn]}, init_msg);
      };
      Node node(new NodeInfo);
      if (UpdateRegInfo(node)) {
      if (UpdateRegInfo(node) && PrepareProcInit(node)) {
         reply |= (node->addrs_[ssn] << 4);
         nodes_[ssn] = node;
         LOG_INFO() << "new node ssn (" << ssn << ") init";
      } else {
         ShmSocket::Remove(shm, ssn);
      }
      return reply;
   }
   void RecordMsg(const MsgI &msg)
   {
      msg.reset_managed(true);
      msgs_.RecordMsg(msg);
   }
   bool SendAllocReply(ShmSocket &socket, const MQInfo &dest, const int64_t reply, const MsgI &msg)
   {
      RecordMsg(msg);
      auto onExpireFree = [this, msg](const SendQ::Data &) { msgs_.FreeMsg(msg.id()); };
      return socket.Send(dest, reply, onExpireFree);
   }
   bool SendAllocMsg(ShmSocket &socket, const MQInfo &dest, const MsgI &msg)
   {
      RecordMsg(msg);
      return socket.Send(dest, msg);
   }
   void OnAlloc(ShmSocket &socket, const int64_t val)
   {
      // LOG_FUNCTION;
      // 8bit size, 4bit socket index, 16bit proc index, 28bit id, ,4bit cmd+flag
      int64_t msg_id = (val >> 4) & MaskBits(28);
      int proc_index = (val >> 32) & MaskBits(16);
      int socket_index = ((val) >> 48) & MaskBits(4);
      auto proc_rec(procs_.Get(proc_index));
      if (proc_rec.proc_.empty()) {
         return;
      }
      MQInfo dest = {proc_rec.ssn_ + socket_index, 0};
      auto FindMq = [&]() {
         auto pos = nodes_.find(proc_rec.ssn_);
         if (pos != nodes_.end()) {
            for (auto &&mq : pos->second->addrs_) {
               if (mq.first == dest.id_) {
                  dest.offset_ = mq.second;
                  return true;
               }
            }
         }
         return false;
      };
      if (!FindMq()) { return; }
      auto size = GetAllocSize((val >> 52) & MaskBits(8));
      MsgI new_msg;
      if (new_msg.Make(size)) {
         // 31bit proc index, 28bit id, ,4bit cmd+flag
         int64_t reply = (new_msg.Offset() << 32) | (msg_id << 4) | EncodeCmd(eCmdAllocReply0);
         SendAllocReply(socket, dest, reply, new_msg);
      } else {
         int64_t reply = (msg_id << 4) | EncodeCmd(eCmdAllocReply0); // send empty, ack failure.
         socket.Send(dest, reply);
      }
   }
   void OnFree(ShmSocket &socket, const int64_t val)
   {
      int64_t msg_id = (val >> 4) & MaskBits(31);
      msgs_.FreeMsg(msg_id);
   }
   bool OnCommand(ShmSocket &socket, const int64_t val)
   {
      assert(IsCmd(val));
      int cmd = DecodeCmd(val);
      switch (cmd) {
      case eCmdAllocRequest0: OnAlloc(socket, val); break;
      case eCmdFree: OnFree(socket, val); break;
      default: return false;
      }
      return true;
   }
   MsgProcInitReply ProcInit(const BHMsgHead &head, MsgProcInit &msg)
   {
      LOG_DEBUG() << "center got proc init.";
      auto pos = nodes_.find(head.ssn_id());
      if (pos == nodes_.end()) {
         return MakeReply<MsgProcInitReply>(eNotFound, "Node Not Initialised");
      }
      auto index = procs_.Put(head.proc_id(), head.ssn_id());
      auto reply(MakeReply<MsgProcInitReply>(eSuccess));
      reply.set_proc_index(index);
      auto &node = pos->second;
      try {
         for (int i = 0; i < msg.extra_mq_num(); ++i) {
            ShmSocket tmp(BHomeShm(), true, head.ssn_id() + i + 1, 16);
            node->addrs_.emplace(tmp.id(), tmp.AbsAddr());
            auto addr = reply.add_extra_mqs();
            addr->set_mq_id(tmp.id());
            addr->set_abs_addr(tmp.AbsAddr());
         }
         return reply;
      } catch (...) {
         LOG_ERROR() << "proc init create mq error";
         return MakeReply<MsgProcInitReply>(eError, "Create mq failed.");
      }
   }
@@ -150,24 +371,19 @@
         // when node restart, ssn will change,
         // and old node will be removed after timeout.
         auto UpdateRegInfo = [&](Node &node) {
            node->addrs_.insert(SrcAddr(head));
            for (auto &addr : msg.addrs()) {
               node->addrs_.insert(addr.mq_id());
            }
            node->proc_.Swap(msg.mutable_proc());
            node->state_.timestamp_ = head.timestamp();
            node->state_.UpdateState(NowSec(), offline_time_, kill_time_);
         };
         auto pos = nodes_.find(ssn);
         if (pos != nodes_.end()) { // update
            Node &node = pos->second;
            UpdateRegInfo(node);
         } else {
            Node node(new NodeInfo);
            UpdateRegInfo(node);
            nodes_[ssn] = node;
         if (pos == nodes_.end()) {
            return MakeReply(eInvalidInput, "invalid session.");
         }
         // update proc info
         Node &node = pos->second;
         UpdateRegInfo(node);
         LOG_DEBUG() << "node (" << head.proc_id() << ") ssn (" << ssn << ")";
         auto old = online_node_addr_map_.find(head.proc_id());
@@ -234,11 +450,11 @@
             auto src = SrcAddr(head);
             auto &topics = msg.topics().topic_list();
             node->services_[src].insert(topics.begin(), topics.end());
             TopicDest dest = {src, node};
             TopicDest dest = {src, SrcAbsAddr(head), node};
             for (auto &topic : topics) {
                service_map_[topic].insert(dest);
             }
             LOG_DEBUG() << "node " << node->proc_.proc_id() << " ssn " << *node->addrs_.begin() << " serve " << topics.size() << " topics:\n";
             LOG_DEBUG() << "node " << node->proc_.proc_id() << " ssn " << node->addrs_.begin()->first << " serve " << topics.size() << " topics:\n";
             for (auto &topic : topics) {
                LOG_DEBUG() << "\t" << topic;
             }
@@ -263,12 +479,52 @@
         return MakeReply(eSuccess);
      });
   }
   MsgQueryProcReply QueryProc(const BHMsgHead &head, const MsgQueryProc &req)
   {
      typedef MsgQueryProcReply Reply;
      auto query = [&](Node self) -> Reply {
         auto Add1 = [](Reply &reply, Node node) {
            auto info = reply.add_proc_list();
            *info->mutable_proc() = node->proc_;
            info->set_online(node->state_.flag_ == kStateNormal);
            for (auto &addr_topics : node->services_) {
               for (auto &topic : addr_topics.second) {
                  info->mutable_topics()->add_topic_list(topic);
               }
            }
         };
         if (!req.proc_id().empty()) {
            auto pos = online_node_addr_map_.find(req.proc_id());
            if (pos == online_node_addr_map_.end()) {
               return MakeReply<Reply>(eNotFound, "proc not found.");
            } else {
               auto node_pos = nodes_.find(pos->second);
               if (node_pos == nodes_.end()) {
                  return MakeReply<Reply>(eNotFound, "proc node not found.");
               } else {
                  auto reply = MakeReply<Reply>(eSuccess);
                  Add1(reply, node_pos->second);
                  return reply;
               }
            }
         } else {
            Reply reply(MakeReply<Reply>(eSuccess));
            for (auto &kv : nodes_) {
               Add1(reply, kv.second);
            }
            return reply;
         }
      };
      return HandleMsg<Reply>(head, query);
   }
   MsgQueryTopicReply QueryTopic(const BHMsgHead &head, const MsgQueryTopic &req)
   {
      typedef MsgQueryTopicReply Reply;
      auto query = [&](Node self) -> MsgQueryTopicReply {
      auto query = [&](Node self) -> Reply {
         auto pos = service_map_.find(req.topic());
         if (pos != service_map_.end() && !pos->second.empty()) {
            auto &clients = pos->second;
@@ -278,7 +534,8 @@
               if (dest_node && Valid(*dest_node)) {
                  auto node_addr = reply.add_node_address();
                  node_addr->set_proc_id(dest_node->proc_.proc_id());
                  node_addr->mutable_addr()->set_mq_id(dest.mq_);
                  node_addr->mutable_addr()->set_mq_id(dest.mq_id_);
                  node_addr->mutable_addr()->set_abs_addr(dest.mq_abs_addr_);
               }
            }
            return reply;
@@ -296,7 +553,7 @@
         auto src = SrcAddr(head);
         auto &topics = msg.topics().topic_list();
         node->subscriptions_[src].insert(topics.begin(), topics.end());
         TopicDest dest = {src, node};
         TopicDest dest = {src, SrcAbsAddr(head), node};
         for (auto &topic : topics) {
            subscribe_map_[topic].insert(dest);
         }
@@ -319,7 +576,7 @@
         };
         if (pos != node->subscriptions_.end()) {
            const TopicDest &dest = {src, node};
            const TopicDest &dest = {src, SrcAbsAddr(head), node};
            auto &topics = msg.topics().topic_list();
            // clear node sub records;
            for (auto &topic : topics) {
@@ -376,13 +633,14 @@
   void OnTimer()
   {
      CheckNodes();
      msgs_.AutoRemove();
   }
private:
   void CheckNodes()
   {
      auto now = NowSec();
      if (now - last_check_time_ < 1) { return; }
      if (now <= last_check_time_) { return; }
      last_check_time_ = now;
      auto it = nodes_.begin();
@@ -396,6 +654,7 @@
            ++it;
         }
      }
      msgs_.DebugPrint();
   }
   bool CanHeartbeat(const NodeInfo &node)
   {
@@ -414,7 +673,7 @@
   {
      auto EraseMapRec = [&node](auto &rec_map, auto &node_rec) {
         for (auto &addr_topics : node_rec) {
            TopicDest dest{addr_topics.first, node};
            TopicDest dest{addr_topics.first, 0, node}; // abs_addr is not used.
            for (auto &topic : addr_topics.second) {
               auto pos = rec_map.find(topic);
               if (pos != rec_map.end()) {
@@ -438,7 +697,7 @@
      }
      for (auto &addr : node->addrs_) {
         cleaner_(addr);
         cleaner_(addr.first);
      }
      node->addrs_.clear();
@@ -448,7 +707,10 @@
   std::unordered_map<Topic, Clients> service_map_;
   std::unordered_map<Topic, Clients> subscribe_map_;
   std::unordered_map<Address, Node> nodes_;
   std::unordered_map<std::string, Address> online_node_addr_map_;
   std::unordered_map<ProcId, Address> online_node_addr_map_;
   ProcRecords procs_; // To get a short index for msg alloc.
   MsgRecords msgs_;   // record all msgs alloced.
   Cleaner cleaner_; // remove mqs.
   int64_t offline_time_;
   int64_t kill_time_;
@@ -483,49 +745,59 @@
          msg, head, [&](auto &body) { return center->MsgTag(head, body); }, replyer); \
      return true;
auto MakeReplyer(ShmSocket &socket, BHMsgHead &head, const std::string &proc_id)
auto MakeReplyer(ShmSocket &socket, BHMsgHead &head, Synced<NodeCenter> &center)
{
   return [&](auto &&rep_body) {
      auto reply_head(InitMsgHead(GetType(rep_body), proc_id, head.ssn_id(), head.msg_id()));
      auto remote = head.route(0).mq_id();
      socket.Send(remote, reply_head, rep_body);
      auto reply_head(InitMsgHead(GetType(rep_body), center->id(), head.ssn_id(), head.msg_id()));
      MQInfo remote = {head.route(0).mq_id(), head.route(0).abs_addr()};
      MsgI msg;
      if (msg.Make(reply_head, rep_body)) {
         DEFER1(msg.Release(););
         center->SendAllocMsg(socket, remote, msg);
      }
   };
}
bool AddCenter(std::shared_ptr<Synced<NodeCenter>> center_ptr)
{
   auto OnNodeInit = [center_ptr](ShmSocket &socket, MsgI &msg) {
   // command
   auto OnCommand = [center_ptr](ShmSocket &socket, ShmMsgQueue::RawData &cmd) -> bool {
      auto &center = *center_ptr;
      center->OnNodeInit(socket.shm(), msg.Offset());
      return IsCmd(cmd) && center->OnCommand(socket, cmd);
   };
   auto Nothing = [](ShmSocket &socket) {};
   BHCenter::Install("#centetr.Init", OnNodeInit, Nothing, BHInitAddress(), 16);
   // now we can talk.
   auto OnCenterIdle = [center_ptr](ShmSocket &socket) {
      auto &center = *center_ptr;
      auto onInit = [&](const int64_t request) {
         return center->OnNodeInit(socket, request);
      };
      BHCenterHandleInit(onInit);
      center->OnTimer();
   };
   auto OnCenter = [=](ShmSocket &socket, MsgI &msg, BHMsgHead &head) -> bool {
      auto &center = *center_ptr;
      auto replyer = MakeReplyer(socket, head, center->id());
      auto replyer = MakeReplyer(socket, head, center);
      switch (head.type()) {
         CASE_ON_MSG_TYPE(ProcInit);
         CASE_ON_MSG_TYPE(Register);
         CASE_ON_MSG_TYPE(Heartbeat);
         CASE_ON_MSG_TYPE(Unregister);
         CASE_ON_MSG_TYPE(RegisterRPC);
         CASE_ON_MSG_TYPE(QueryTopic);
         CASE_ON_MSG_TYPE(QueryProc);
      default: return false;
      }
   };
   BHCenter::Install("#center.main", OnCenter, OnCenterIdle, BHTopicCenterAddress(), 1000);
   BHCenter::Install("#center.main", OnCenter, OnCommand, OnCenterIdle, BHTopicCenterAddress(), 1000);
   auto OnBusIdle = [=](ShmSocket &socket) {};
   auto OnBusCmd = [=](ShmSocket &socket, ShmMsgQueue::RawData &val) { return false; };
   auto OnPubSub = [=](ShmSocket &socket, MsgI &msg, BHMsgHead &head) -> bool {
      auto &center = *center_ptr;
      auto replyer = MakeReplyer(socket, head, center->id());
      auto replyer = MakeReplyer(socket, head, center);
      auto OnPublish = [&]() {
         MsgPublish pub;
         NodeCenter::Clients clients;
@@ -545,7 +817,7 @@
               if (node) {
                  // should also make sure that mq is not killed before msg expires.
                  // it would be ok if (kill_time - offline_time) is longer than expire time.
                  socket.Send(cli.mq_, msg);
                  socket.Send({cli.mq_id_, cli.mq_abs_addr_}, msg);
                  ++it;
               } else {
                  it = clients.erase(it);
@@ -561,7 +833,7 @@
      }
   };
   BHCenter::Install("#center.bus", OnPubSub, OnBusIdle, BHTopicBusAddress(), 1000);
   BHCenter::Install("#center.bus", OnPubSub, OnBusCmd, OnBusIdle, BHTopicBusAddress(), 1000);
   return true;
}
@@ -576,14 +848,9 @@
   return rec;
}
bool BHCenter::Install(const std::string &name, MsgHandler handler, IdleHandler idle, const MQId mqid, const int mq_len)
bool BHCenter::Install(const std::string &name, MsgHandler handler, RawHandler raw_handler, IdleHandler idle, const MQInfo &mq, const int mq_len)
{
   Centers()[name] = CenterInfo{name, handler, MsgIHandler(), idle, mqid, mq_len};
   return true;
}
bool BHCenter::Install(const std::string &name, MsgIHandler handler, IdleHandler idle, const MQId mqid, const int mq_len)
{
   Centers()[name] = CenterInfo{name, MsgHandler(), handler, idle, mqid, mq_len};
   Centers()[name] = CenterInfo{name, handler, raw_handler, idle, mq, mq_len};
   return true;
}
@@ -596,12 +863,13 @@
      }
   };
   auto center_ptr = std::make_shared<Synced<NodeCenter>>("#bhome_center", gc, 6s, 6s * 2);
   auto nsec = seconds(NodeTimeoutSec());
   auto center_ptr = std::make_shared<Synced<NodeCenter>>("#bhome_center", gc, nsec, nsec * 3); // *3 to allow other clients to finish sending msgs.
   AddCenter(center_ptr);
   for (auto &kv : Centers()) {
      auto &info = kv.second;
      sockets_[info.name_] = std::make_shared<ShmSocket>(shm, info.mqid_, info.mq_len_);
      sockets_[info.name_] = std::make_shared<ShmSocket>(info.mq_.offset_, shm, info.mq_.id_);
   }
}
@@ -609,11 +877,7 @@
{
   for (auto &kv : Centers()) {
      auto &info = kv.second;
      if (info.handler_) {
         sockets_[info.name_]->Start(info.handler_, info.idle_);
      } else {
         sockets_[info.name_]->Start(info.raw_handler_, info.idle_);
      }
      sockets_[info.name_]->Start(1, info.handler_, info.raw_handler_, info.idle_);
   }
   return true;