lichao
2021-05-20 101b5cf85397ef9350aaedd12cfcf9fd3d07a565
box/center.cpp
@@ -16,12 +16,8 @@
 * =====================================================================================
 */
#include "center.h"
#include "bh_util.h"
#include "defs.h"
#include "log.h"
#include "shm.h"
#include "node_center.h"
#include <chrono>
#include <unordered_map>
using namespace std::chrono;
using namespace std::chrono_literals;
@@ -33,689 +29,7 @@
namespace
{
typedef std::string ProcId;
typedef size_t ProcIndex; // max local procs.
const int kMaxProcs = 65536;
// record all procs ever registered, always grow, never remove.
// mainly for node to request msg allocation.
// use index instead of MQId to save some bits.
class ProcRecords
{
public:
   struct ProcRec {
      ProcId proc_;
      MQId ssn_ = 0;
   };
   ProcRecords() { procs_.reserve(kMaxProcs); }
   ProcIndex Put(const ProcId &proc_id, const MQId ssn)
   {
      if (procs_.size() >= kMaxProcs) {
         return -1;
      }
      auto pos_isnew = proc_index_.emplace(proc_id, procs_.size());
      int index = pos_isnew.first->second;
      if (pos_isnew.second) {
         procs_.emplace_back(ProcRec{proc_id, ssn});
      } else { // update ssn
         procs_[index].ssn_ = ssn;
      }
      return index;
   }
   const ProcRec &Get(const ProcIndex index) const
   {
      static ProcRec empty_rec;
      return (index < procs_.size()) ? procs_[index] : empty_rec;
   }
private:
   std::unordered_map<ProcId, size_t> proc_index_;
   std::vector<ProcRec> procs_;
};
class MsgRecords
{
   typedef int64_t MsgId;
   typedef int64_t Offset;
public:
   void RecordMsg(const MsgI &msg) { msgs_.emplace(msg.id(), msg.Offset()); }
   void FreeMsg(MsgId id)
   {
      auto pos = msgs_.find(id);
      if (pos != msgs_.end()) {
         ShmMsg(pos->second).Free();
         msgs_.erase(pos);
      } else {
         LOG_TRACE() << "ignore late free request.";
      }
   }
   void AutoRemove()
   {
      auto now = NowSec();
      if (now < time_to_clean_) {
         return;
      }
      // LOG_FUNCTION;
      time_to_clean_ = now + 1;
      int64_t limit = std::max(10000ul, msgs_.size() / 10);
      int64_t n = 0;
      auto it = msgs_.begin();
      while (it != msgs_.end() && --limit > 0) {
         ShmMsg msg(it->second);
         auto Free = [&]() {
            msg.Free();
            it = msgs_.erase(it);
            ++n;
         };
         int n = now - msg.timestamp();
         if (n < 10) {
            ++it;
         } else if (msg.Count() == 0) {
            Free();
         } else if (n > 60) {
            Free();
         }
      }
      if (n > 0) {
         LOG_DEBUG() << "~~~~~~~~~~~~~~~~ auto release msgs: " << n;
      }
   }
   size_t size() const { return msgs_.size(); }
   void DebugPrint() const
   {
      LOG_DEBUG() << "msgs : " << size();
      int i = 0;
      int total_count = 0;
      for (auto &kv : msgs_) {
         MsgI msg(kv.second);
         total_count += msg.Count();
         LOG_TRACE() << "  " << i++ << ": msg id: " << kv.first << ", offset: " << kv.second << ", count: " << msg.Count() << ", size: " << msg.Size();
      }
      LOG_DEBUG() << "total count: " << total_count;
   }
private:
   std::unordered_map<MsgId, Offset> msgs_;
   int64_t time_to_clean_ = 0;
};
//TODO check proc_id
class NodeCenter
{
public:
   typedef MQId Address;
   typedef bhome_msg::ProcInfo ProcInfo;
   typedef std::function<void(Address const)> Cleaner;
private:
   enum {
      kStateInvalid,
      kStateNormal,
      kStateOffline,
      kStateKillme,
   };
   struct ProcState {
      int64_t timestamp_ = 0;
      uint32_t flag_ = 0; // reserved
      void PutOffline(const int64_t offline_time)
      {
         timestamp_ = NowSec() - offline_time;
         flag_ = kStateOffline;
      }
      void UpdateState(const int64_t now, const int64_t offline_time, const int64_t kill_time)
      {
         auto diff = now - timestamp_;
         LOG_DEBUG() << "state " << this << " diff: " << diff;
         if (diff < offline_time) {
            flag_ = kStateNormal;
         } else if (diff < kill_time) {
            flag_ = kStateOffline;
         } else {
            flag_ = kStateKillme;
         }
      }
   };
   typedef std::unordered_map<Address, std::set<Topic>> AddressTopics;
   struct NodeInfo {
      ProcState state_;               // state
      std::map<MQId, int64_t> addrs_; // registered mqs
      ProcInfo proc_;                 //
      AddressTopics services_;        // address: topics
      AddressTopics subscriptions_;   // address: topics
   };
   typedef std::shared_ptr<NodeInfo> Node;
   typedef std::weak_ptr<NodeInfo> WeakNode;
   struct TopicDest {
      MQId mq_id_;
      int64_t mq_abs_addr_;
      WeakNode weak_node_;
      bool operator<(const TopicDest &a) const { return mq_id_ < a.mq_id_; }
   };
   inline MQId SrcAddr(const BHMsgHead &head) { return head.route(0).mq_id(); }
   inline int64_t SrcAbsAddr(const BHMsgHead &head) { return head.route(0).abs_addr(); }
   inline bool MatchAddr(std::map<Address, int64_t> const &addrs, const Address &addr) { return addrs.find(addr) != addrs.end(); }
   NodeCenter(const std::string &id, const Cleaner &cleaner, const int64_t offline_time, const int64_t kill_time) :
       id_(id), cleaner_(cleaner), offline_time_(offline_time), kill_time_(kill_time), last_check_time_(0) {}
public:
   typedef std::set<TopicDest> Clients;
   NodeCenter(const std::string &id, const Cleaner &cleaner, const steady_clock::duration offline_time, const steady_clock::duration kill_time) :
       NodeCenter(id, cleaner, duration_cast<seconds>(offline_time).count(), duration_cast<seconds>(kill_time).count()) {}
   // center name, no relative to shm.
   const std::string &id() const { return id_; }
   int64_t OnNodeInit(ShmSocket &socket, const int64_t val)
   {
      LOG_FUNCTION;
      SharedMemory &shm = socket.shm();
      MQId ssn = (val >> 4) & MaskBits(56);
      int reply = EncodeCmd(eCmdNodeInitReply);
      if (nodes_.find(ssn) != nodes_.end()) {
         return reply; // ignore if exists.
      }
      auto UpdateRegInfo = [&](Node &node) {
         node->state_.timestamp_ = NowSec() - offline_time_;
         node->state_.UpdateState(NowSec(), offline_time_, kill_time_);
         // create sockets.
         try {
            ShmSocket tmp(shm, true, ssn, 16);
            node->addrs_.emplace(ssn, tmp.AbsAddr());
            return true;
         } catch (...) {
            return false;
         }
      };
      auto PrepareProcInit = [&](Node &node) {
         bool r = false;
         ShmMsg init_msg;
         DEFER1(init_msg.Release());
         MsgProcInit body;
         auto head = InitMsgHead(GetType(body), id(), ssn);
         return init_msg.Make(GetAllocSize(CalcAllocIndex(900))) &&
                init_msg.Fill(ShmMsg::Serialize(head, body)) &&
                SendAllocMsg(socket, {ssn, node->addrs_[ssn]}, init_msg);
      };
      Node node(new NodeInfo);
      if (UpdateRegInfo(node) && PrepareProcInit(node)) {
         reply |= (node->addrs_[ssn] << 4);
         nodes_[ssn] = node;
         LOG_INFO() << "new node ssn (" << ssn << ") init";
      } else {
         ShmSocket::Remove(shm, ssn);
      }
      return reply;
   }
   void RecordMsg(const MsgI &msg)
   {
      msg.reset_managed(true);
      msgs_.RecordMsg(msg);
   }
   bool SendAllocReply(ShmSocket &socket, const MQInfo &dest, const int64_t reply, const MsgI &msg)
   {
      RecordMsg(msg);
      auto onExpireFree = [this, msg](const SendQ::Data &) { msgs_.FreeMsg(msg.id()); };
      return socket.Send(dest, reply, onExpireFree);
   }
   bool SendAllocMsg(ShmSocket &socket, const MQInfo &dest, const MsgI &msg)
   {
      RecordMsg(msg);
      return socket.Send(dest, msg);
   }
   void OnAlloc(ShmSocket &socket, const int64_t val)
   {
      // LOG_FUNCTION;
      // 8bit size, 4bit socket index, 16bit proc index, 28bit id, ,4bit cmd+flag
      int64_t msg_id = (val >> 4) & MaskBits(28);
      int proc_index = (val >> 32) & MaskBits(16);
      int socket_index = ((val) >> 48) & MaskBits(4);
      auto proc_rec(procs_.Get(proc_index));
      if (proc_rec.proc_.empty()) {
         return;
      }
      MQInfo dest = {proc_rec.ssn_ + socket_index, 0};
      auto FindMq = [&]() {
         auto pos = nodes_.find(proc_rec.ssn_);
         if (pos != nodes_.end()) {
            for (auto &&mq : pos->second->addrs_) {
               if (mq.first == dest.id_) {
                  dest.offset_ = mq.second;
                  return true;
               }
            }
         }
         return false;
      };
      if (!FindMq()) { return; }
      auto size = GetAllocSize((val >> 52) & MaskBits(8));
      MsgI new_msg;
      if (new_msg.Make(size)) {
         // 31bit proc index, 28bit id, ,4bit cmd+flag
         int64_t reply = (new_msg.Offset() << 32) | (msg_id << 4) | EncodeCmd(eCmdAllocReply0);
         SendAllocReply(socket, dest, reply, new_msg);
      } else {
         int64_t reply = (msg_id << 4) | EncodeCmd(eCmdAllocReply0); // send empty, ack failure.
         socket.Send(dest, reply);
      }
   }
   void OnFree(ShmSocket &socket, const int64_t val)
   {
      int64_t msg_id = (val >> 4) & MaskBits(31);
      msgs_.FreeMsg(msg_id);
   }
   bool OnCommand(ShmSocket &socket, const int64_t val)
   {
      assert(IsCmd(val));
      int cmd = DecodeCmd(val);
      switch (cmd) {
      case eCmdAllocRequest0: OnAlloc(socket, val); break;
      case eCmdFree: OnFree(socket, val); break;
      default: return false;
      }
      return true;
   }
   MsgProcInitReply ProcInit(const BHMsgHead &head, MsgProcInit &msg)
   {
      LOG_DEBUG() << "center got proc init.";
      auto pos = nodes_.find(head.ssn_id());
      if (pos == nodes_.end()) {
         return MakeReply<MsgProcInitReply>(eNotFound, "Node Not Initialised");
      }
      auto index = procs_.Put(head.proc_id(), head.ssn_id());
      auto reply(MakeReply<MsgProcInitReply>(eSuccess));
      reply.set_proc_index(index);
      auto &node = pos->second;
      try {
         for (int i = 0; i < msg.extra_mq_num(); ++i) {
            ShmSocket tmp(BHomeShm(), true, head.ssn_id() + i + 1, 16);
            node->addrs_.emplace(tmp.id(), tmp.AbsAddr());
            auto addr = reply.add_extra_mqs();
            addr->set_mq_id(tmp.id());
            addr->set_abs_addr(tmp.AbsAddr());
         }
         return reply;
      } catch (...) {
         LOG_ERROR() << "proc init create mq error";
         return MakeReply<MsgProcInitReply>(eError, "Create mq failed.");
      }
   }
   MsgCommonReply Register(const BHMsgHead &head, MsgRegister &msg)
   {
      if (msg.proc().proc_id() != head.proc_id()) {
         return MakeReply(eInvalidInput, "invalid proc id.");
      }
      try {
         MQId ssn = head.ssn_id();
         // when node restart, ssn will change,
         // and old node will be removed after timeout.
         auto UpdateRegInfo = [&](Node &node) {
            node->proc_.Swap(msg.mutable_proc());
            node->state_.timestamp_ = head.timestamp();
            node->state_.UpdateState(NowSec(), offline_time_, kill_time_);
         };
         auto pos = nodes_.find(ssn);
         if (pos == nodes_.end()) {
            return MakeReply(eInvalidInput, "invalid session.");
         }
         // update proc info
         Node &node = pos->second;
         UpdateRegInfo(node);
         LOG_DEBUG() << "node (" << head.proc_id() << ") ssn (" << ssn << ")";
         auto old = online_node_addr_map_.find(head.proc_id());
         if (old != online_node_addr_map_.end()) { // old session
            auto &old_ssn = old->second;
            if (old_ssn != ssn) {
               nodes_[old_ssn]->state_.PutOffline(offline_time_);
               LOG_DEBUG() << "put node (" << nodes_[old_ssn]->proc_.proc_id() << ") ssn (" << old->second << ") offline";
               old_ssn = ssn;
            }
         } else {
            online_node_addr_map_.emplace(head.proc_id(), ssn);
         }
         return MakeReply(eSuccess);
      } catch (...) {
         return MakeReply(eError, "register node error.");
      }
   }
   template <class Reply, class Func>
   Reply HandleMsg(const BHMsgHead &head, Func const &op)
   {
      try {
         auto pos = nodes_.find(head.ssn_id());
         if (pos == nodes_.end()) {
            return MakeReply<Reply>(eNotRegistered, "Node is not registered.");
         } else {
            auto &node = pos->second;
            if (!MatchAddr(node->addrs_, SrcAddr(head))) {
               return MakeReply<Reply>(eAddressNotMatch, "Node address error.");
            } else if (head.type() == kMsgTypeHeartbeat && CanHeartbeat(*node)) {
               return op(node);
            } else if (!Valid(*node)) {
               return MakeReply<Reply>(eNoRespond, "Node is not alive.");
            } else {
               return op(node);
            }
         }
      } catch (...) {
         //TODO error log
         return MakeReply<Reply>(eError, "internal error.");
      }
   }
   template <class Func>
   inline MsgCommonReply HandleMsg(const BHMsgHead &head, Func const &op)
   {
      return HandleMsg<MsgCommonReply, Func>(head, op);
   }
   MsgCommonReply Unregister(const BHMsgHead &head, MsgUnregister &msg)
   {
      return HandleMsg(
          head, [&](Node node) -> MsgCommonReply {
             NodeInfo &ni = *node;
             ni.state_.PutOffline(offline_time_);
             return MakeReply(eSuccess);
          });
   }
   MsgCommonReply RegisterRPC(const BHMsgHead &head, MsgRegisterRPC &msg)
   {
      return HandleMsg(
          head, [&](Node node) -> MsgCommonReply {
             auto src = SrcAddr(head);
             auto &topics = msg.topics().topic_list();
             node->services_[src].insert(topics.begin(), topics.end());
             TopicDest dest = {src, SrcAbsAddr(head), node};
             for (auto &topic : topics) {
                service_map_[topic].insert(dest);
             }
             LOG_DEBUG() << "node " << node->proc_.proc_id() << " ssn " << node->addrs_.begin()->first << " serve " << topics.size() << " topics:\n";
             for (auto &topic : topics) {
                LOG_DEBUG() << "\t" << topic;
             }
             return MakeReply(eSuccess);
          });
   }
   MsgCommonReply Heartbeat(const BHMsgHead &head, const MsgHeartbeat &msg)
   {
      return HandleMsg(head, [&](Node node) {
         NodeInfo &ni = *node;
         ni.state_.timestamp_ = head.timestamp();
         ni.state_.UpdateState(NowSec(), offline_time_, kill_time_);
         auto &info = msg.proc();
         if (!info.public_info().empty()) {
            ni.proc_.set_public_info(info.public_info());
         }
         if (!info.private_info().empty()) {
            ni.proc_.set_private_info(info.private_info());
         }
         return MakeReply(eSuccess);
      });
   }
   MsgQueryProcReply QueryProc(const BHMsgHead &head, const MsgQueryProc &req)
   {
      typedef MsgQueryProcReply Reply;
      auto query = [&](Node self) -> Reply {
         auto Add1 = [](Reply &reply, Node node) {
            auto info = reply.add_proc_list();
            *info->mutable_proc() = node->proc_;
            info->set_online(node->state_.flag_ == kStateNormal);
            for (auto &addr_topics : node->services_) {
               for (auto &topic : addr_topics.second) {
                  info->mutable_topics()->add_topic_list(topic);
               }
            }
         };
         if (!req.proc_id().empty()) {
            auto pos = online_node_addr_map_.find(req.proc_id());
            if (pos == online_node_addr_map_.end()) {
               return MakeReply<Reply>(eNotFound, "proc not found.");
            } else {
               auto node_pos = nodes_.find(pos->second);
               if (node_pos == nodes_.end()) {
                  return MakeReply<Reply>(eNotFound, "proc node not found.");
               } else {
                  auto reply = MakeReply<Reply>(eSuccess);
                  Add1(reply, node_pos->second);
                  return reply;
               }
            }
         } else {
            Reply reply(MakeReply<Reply>(eSuccess));
            for (auto &kv : nodes_) {
               Add1(reply, kv.second);
            }
            return reply;
         }
      };
      return HandleMsg<Reply>(head, query);
   }
   MsgQueryTopicReply QueryTopic(const BHMsgHead &head, const MsgQueryTopic &req)
   {
      typedef MsgQueryTopicReply Reply;
      auto query = [&](Node self) -> Reply {
         auto pos = service_map_.find(req.topic());
         if (pos != service_map_.end() && !pos->second.empty()) {
            auto &clients = pos->second;
            Reply reply = MakeReply<Reply>(eSuccess);
            for (auto &dest : clients) {
               Node dest_node(dest.weak_node_.lock());
               if (dest_node && Valid(*dest_node)) {
                  auto node_addr = reply.add_node_address();
                  node_addr->set_proc_id(dest_node->proc_.proc_id());
                  node_addr->mutable_addr()->set_mq_id(dest.mq_id_);
                  node_addr->mutable_addr()->set_abs_addr(dest.mq_abs_addr_);
               }
            }
            return reply;
         } else {
            return MakeReply<Reply>(eNotFound, "topic server not found.");
         }
      };
      return HandleMsg<Reply>(head, query);
   }
   MsgCommonReply Subscribe(const BHMsgHead &head, const MsgSubscribe &msg)
   {
      return HandleMsg(head, [&](Node node) {
         auto src = SrcAddr(head);
         auto &topics = msg.topics().topic_list();
         node->subscriptions_[src].insert(topics.begin(), topics.end());
         TopicDest dest = {src, SrcAbsAddr(head), node};
         for (auto &topic : topics) {
            subscribe_map_[topic].insert(dest);
         }
         return MakeReply(eSuccess);
      });
   }
   MsgCommonReply Unsubscribe(const BHMsgHead &head, const MsgUnsubscribe &msg)
   {
      return HandleMsg(head, [&](Node node) {
         auto src = SrcAddr(head);
         auto pos = node->subscriptions_.find(src);
         auto RemoveSubTopicDestRecord = [this](const Topic &topic, const TopicDest &dest) {
            auto pos = subscribe_map_.find(topic);
            if (pos != subscribe_map_.end() &&
                pos->second.erase(dest) != 0 &&
                pos->second.empty()) {
               subscribe_map_.erase(pos);
            }
         };
         if (pos != node->subscriptions_.end()) {
            const TopicDest &dest = {src, SrcAbsAddr(head), node};
            auto &topics = msg.topics().topic_list();
            // clear node sub records;
            for (auto &topic : topics) {
               pos->second.erase(topic);
               RemoveSubTopicDestRecord(topic, dest);
            }
            if (pos->second.empty()) {
               node->subscriptions_.erase(pos);
            }
         }
         return MakeReply(eSuccess);
      });
   }
   Clients DoFindClients(const std::string &topic)
   {
      Clients dests;
      auto Find1 = [&](const std::string &t) {
         auto pos = subscribe_map_.find(topic);
         if (pos != subscribe_map_.end()) {
            auto &clients = pos->second;
            for (auto &cli : clients) {
               if (Valid(cli.weak_node_)) {
                  dests.insert(cli);
               }
            }
         }
      };
      Find1(topic);
      size_t pos = 0;
      while (true) {
         pos = topic.find(kTopicSep, pos);
         if (pos == topic.npos || ++pos == topic.size()) {
            // Find1(std::string()); // sub all.
            break;
         } else {
            Find1(topic.substr(0, pos));
         }
      }
      return dests;
   }
   bool FindClients(const BHMsgHead &head, const MsgPublish &msg, Clients &out, MsgCommonReply &reply)
   {
      bool ret = false;
      HandleMsg(head, [&](Node node) {
         DoFindClients(msg.topic()).swap(out);
         ret = true;
         return MakeReply(eSuccess);
      }).Swap(&reply);
      return ret;
   }
   void OnTimer()
   {
      CheckNodes();
      msgs_.AutoRemove();
   }
private:
   void CheckNodes()
   {
      auto now = NowSec();
      if (now <= last_check_time_) { return; }
      last_check_time_ = now;
      auto it = nodes_.begin();
      while (it != nodes_.end()) {
         auto &cli = *it->second;
         cli.state_.UpdateState(now, offline_time_, kill_time_);
         if (cli.state_.flag_ == kStateKillme) {
            RemoveNode(it->second);
            it = nodes_.erase(it);
         } else {
            ++it;
         }
      }
      msgs_.DebugPrint();
   }
   bool CanHeartbeat(const NodeInfo &node)
   {
      return Valid(node) || node.state_.flag_ == kStateOffline;
   }
   bool Valid(const NodeInfo &node)
   {
      return node.state_.flag_ == kStateNormal;
   }
   bool Valid(const WeakNode &weak)
   {
      auto node = weak.lock();
      return node && Valid(*node);
   }
   void RemoveNode(Node &node)
   {
      auto EraseMapRec = [&node](auto &rec_map, auto &node_rec) {
         for (auto &addr_topics : node_rec) {
            TopicDest dest{addr_topics.first, 0, node}; // abs_addr is not used.
            for (auto &topic : addr_topics.second) {
               auto pos = rec_map.find(topic);
               if (pos != rec_map.end()) {
                  pos->second.erase(dest);
                  if (pos->second.empty()) {
                     rec_map.erase(pos);
                  }
               }
            }
         }
      };
      EraseMapRec(service_map_, node->services_);
      EraseMapRec(subscribe_map_, node->subscriptions_);
      // remove online record.
      auto pos = online_node_addr_map_.find(node->proc_.proc_id());
      if (pos != online_node_addr_map_.end()) {
         if (node->addrs_.find(pos->second) != node->addrs_.end()) {
            online_node_addr_map_.erase(pos);
         }
      }
      for (auto &addr : node->addrs_) {
         cleaner_(addr.first);
      }
      node->addrs_.clear();
   }
   std::string id_; // center proc id;
   std::unordered_map<Topic, Clients> service_map_;
   std::unordered_map<Topic, Clients> subscribe_map_;
   std::unordered_map<Address, Node> nodes_;
   std::unordered_map<ProcId, Address> online_node_addr_map_;
   ProcRecords procs_; // To get a short index for msg alloc.
   MsgRecords msgs_;   // record all msgs alloced.
   Cleaner cleaner_; // remove mqs.
   int64_t offline_time_;
   int64_t kill_time_;
   int64_t last_check_time_;
};
template <class Body, class OnMsg, class Replyer>
inline void Dispatch(MsgI &msg, BHMsgHead &head, OnMsg const &onmsg, Replyer const &replyer)
@@ -863,7 +177,7 @@
      }
   };
   auto nsec = seconds(NodeTimeoutSec());
   auto nsec = NodeTimeoutSec();
   auto center_ptr = std::make_shared<Synced<NodeCenter>>("#bhome_center", gc, nsec, nsec * 3); // *3 to allow other clients to finish sending msgs.
   AddCenter(center_ptr);