lichao
2021-04-08 c338820e4db43ad32c20ff429a038b06bcb980f8
src/center.cpp
@@ -16,20 +16,387 @@
 * =====================================================================================
 */
#include "center.h"
#include "bh_util.h"
#include "defs.h"
#include "pubsub_center.h"
#include "reqrep_center.h"
#include "shm.h"
#include <set>
using namespace bhome_shm;
using namespace bhome_msg;
using namespace bhome::msg;
typedef BHCenter::MsgHandler Handler;
Handler Join(Handler h1, Handler h2)
namespace
{
   return [h1, h2](ShmSocket &socket, bhome_msg::MsgI &imsg, bhome::msg::BHMsg &msg) {
      return h1(socket, imsg, msg) || h2(socket, imsg, msg);
auto Now = []() { time_t t; return time(&t); };
//TODO check proc_id
class NodeCenter
{
public:
   typedef std::string ProcId;
   typedef std::string Address;
   typedef bhome::msg::ProcInfo ProcInfo;
private:
   enum {
      kStateInvalid = 0,
      kStateNormal = 1,
      kStateNoRespond = 2,
      kStateOffline = 3,
   };
   struct ProcState {
      time_t timestamp_ = 0;
      uint32_t flag_ = 0; // reserved
   };
   typedef std::unordered_map<Address, std::set<Topic>> AddressTopics;
   struct NodeInfo {
      ProcState state_;             // state
      Address addr_;                // registered_mqid.
      ProcInfo proc_;               //
      AddressTopics services_;      // address: topics
      AddressTopics subscriptions_; // address: topics
   };
   typedef std::shared_ptr<NodeInfo> Node;
   typedef std::weak_ptr<NodeInfo> WeakNode;
   struct TopicDest {
      Address mq_;
      WeakNode weak_node_;
      bool operator<(const TopicDest &a) const { return mq_ < a.mq_; }
   };
   const std::string &SrcAddr(const BHMsgHead &head) { return head.route(0).mq_id(); }
public:
   typedef std::set<TopicDest> Clients;
   NodeCenter(const std::string &id = "#Center") :
       id_(id) {}
   const std::string &id() const { return id_; } // no need to lock.
   //TODO maybe just return serialized string.
   MsgCommonReply Register(const BHMsgHead &head, MsgRegister &msg)
   {
      if (msg.proc().proc_id() != head.proc_id()) {
         return MakeReply(eInvalidInput, "invalid proc id.");
      }
      try {
         Node node(new NodeInfo);
         node->addr_ = SrcAddr(head);
         node->proc_.Swap(msg.mutable_proc());
         node->state_.timestamp_ = Now();
         node->state_.flag_ = kStateNormal;
         nodes_[node->proc_.proc_id()] = node;
         return MakeReply(eSuccess);
      } catch (...) {
         return MakeReply(eError, "register node error.");
      }
   }
   template <class OnSuccess, class OnError>
   auto HandleMsg(const BHMsgHead &head, OnSuccess onOk, OnError onErr)
   {
      auto pos = nodes_.find(head.proc_id());
      if (pos == nodes_.end()) {
         return onErr(eNotRegistered, "Node is not registered.");
      } else {
         auto node = pos->second;
         if (head.type() == kMsgTypeHeartbeat && node->addr_ != SrcAddr(head)) {
            return onErr(eAddressNotMatch, "Node address error.");
         } else if (!Valid(*node)) {
            return onErr(eNoRespond, "Node is not alive.");
         } else {
            return onOk(node);
         }
      }
   }
   template <class Reply, class Func>
   Reply HandleMsg(const BHMsgHead &head, Func const &op)
   {
      try {
         auto onErr = [](const ErrorCode ec, const std::string &str) { return MakeReply<Reply>(ec, str); };
         return HandleMsg(head, op, onErr);
         auto pos = nodes_.find(head.proc_id());
         if (pos == nodes_.end()) {
            return MakeReply<Reply>(eNotRegistered, "Node is not registered.");
         } else {
            auto node = pos->second;
            if (node->addr_ != SrcAddr(head)) {
               return MakeReply<Reply>(eAddressNotMatch, "Node address error.");
            } else if (!Valid(*node)) {
               return MakeReply<Reply>(eNoRespond, "Node is not alive.");
            } else {
               return op(node);
            }
         }
      } catch (...) {
         //TODO error log
         return MakeReply<Reply>(eError, "internal error.");
      }
   }
   template <class Func>
   inline MsgCommonReply HandleMsg(const BHMsgHead &head, Func const &op)
   {
      return HandleMsg<MsgCommonReply, Func>(head, op);
   }
   MsgCommonReply RegisterRPC(const BHMsgHead &head, MsgRegisterRPC &msg)
   {
      return HandleMsg(
          head, [&](Node node) -> MsgCommonReply {
             auto &src = SrcAddr(head);
             node->services_[src].insert(msg.topics().begin(), msg.topics().end());
             TopicDest dest = {src, node};
             for (auto &topic : msg.topics()) {
                service_map_[topic].insert(dest);
             }
             return MakeReply(eSuccess);
          });
   }
   MsgCommonReply Heartbeat(const BHMsgHead &head, const MsgHeartbeat &msg)
   {
      return HandleMsg(head, [&](Node node) {
         NodeInfo &ni = *node;
         ni.state_.timestamp_ = Now();
         auto &info = msg.proc();
         if (!info.public_info().empty()) {
            ni.proc_.set_public_info(info.public_info());
         }
         if (!info.private_info().empty()) {
            ni.proc_.set_private_info(info.private_info());
         }
         return MakeReply(eSuccess);
      });
   }
   MsgQueryTopicReply QueryTopic(const BHMsgHead &head, const MsgQueryTopic &req)
   {
      typedef MsgQueryTopicReply Reply;
      auto query = [&](Node self) -> MsgQueryTopicReply {
         auto pos = service_map_.find(req.topic());
         if (pos != service_map_.end() && !pos->second.empty()) {
            // now just find first one.
            const TopicDest &dest = *(pos->second.begin());
            Node dest_node(dest.weak_node_.lock());
            if (!dest_node) {
               service_map_.erase(pos);
               return MakeReply<Reply>(eOffline, "topic server offline.");
            } else if (!Valid(*dest_node)) {
               return MakeReply<Reply>(eNoRespond, "topic server not responding.");
            } else {
               MsgQueryTopicReply reply = MakeReply<Reply>(eSuccess);
               reply.mutable_address()->set_mq_id(dest.mq_);
               return reply;
            }
         } else {
            return MakeReply<Reply>(eNotFound, "topic server not found.");
         }
      };
      return HandleMsg<Reply>(head, query);
   }
   MsgCommonReply Subscribe(const BHMsgHead &head, const MsgSubscribe &msg)
   {
      return HandleMsg(head, [&](Node node) {
         auto &src = SrcAddr(head);
         node->subscriptions_[src].insert(msg.topics().begin(), msg.topics().end());
         TopicDest dest = {src, node};
         for (auto &topic : msg.topics()) {
            subscribe_map_[topic].insert(dest);
         }
         return MakeReply(eSuccess);
      });
   }
   MsgCommonReply Unsubscribe(const BHMsgHead &head, const MsgUnsubscribe &msg)
   {
      return HandleMsg(head, [&](Node node) {
         auto &src = SrcAddr(head);
         auto pos = node->subscriptions_.find(src);
         auto RemoveSubTopicDestRecord = [this](const Topic &topic, const TopicDest &dest) {
            auto pos = subscribe_map_.find(topic);
            if (pos != subscribe_map_.end() &&
                pos->second.erase(dest) != 0 &&
                pos->second.empty()) {
               subscribe_map_.erase(pos);
            }
         };
         if (pos != node->subscriptions_.end()) {
            const TopicDest &dest = {src, node};
            // clear node sub records;
            for (auto &topic : msg.topics()) {
               pos->second.erase(topic);
               RemoveSubTopicDestRecord(topic, dest);
            }
            if (pos->second.empty()) {
               node->subscriptions_.erase(pos);
            }
         }
         return MakeReply(eSuccess);
      });
   }
   Clients DoFindClients(const std::string &topic)
   {
      Clients dests;
      auto Find1 = [&](const std::string &t) {
         auto pos = subscribe_map_.find(topic);
         if (pos != subscribe_map_.end()) {
            auto &clients = pos->second;
            for (auto &cli : clients) {
               if (Valid(cli.weak_node_)) {
                  dests.insert(cli);
               }
            }
         }
      };
      Find1(topic);
      size_t pos = 0;
      while (true) {
         pos = topic.find(kTopicSep, pos);
         if (pos == topic.npos || ++pos == topic.size()) {
            // Find1(std::string()); // sub all.
            break;
         } else {
            Find1(topic.substr(0, pos));
         }
      }
      return dests;
   }
   bool FindClients(const BHMsgHead &head, const MsgPublish &msg, Clients &out, MsgCommonReply &reply)
   {
      bool ret = false;
      HandleMsg(head, [&](Node node) {
         DoFindClients(msg.topic()).swap(out);
         ret = true;
         return MakeReply(eSuccess);
      }).Swap(&reply);
      return ret;
   }
private:
   bool Valid(const NodeInfo &node)
   {
      return node.state_.flag_ == kStateNormal;
   }
   bool Valid(const WeakNode &weak)
   {
      auto node = weak.lock();
      return node && Valid(*node);
   }
   void CheckAllNodes(); //TODO, call it in timer.
   std::string id_;      // center proc id;
   std::unordered_map<Topic, Clients> service_map_;
   std::unordered_map<Topic, Clients> subscribe_map_;
   std::unordered_map<ProcId, Node> nodes_;
};
template <class Body, class OnMsg, class Replyer>
inline void Dispatch(MsgI &msg, BHMsgHead &head, OnMsg const &onmsg, Replyer const &replyer)
{
   if (head.route_size() != 1) { return; }
   Body body;
   if (msg.ParseBody(body)) {
      replyer(onmsg(body));
   }
}
Handler Combine(const Handler &h1, const Handler &h2)
{
   return [h1, h2](ShmSocket &socket, bhome_msg::MsgI &msg, bhome::msg::BHMsgHead &head) {
      return h1(socket, msg, head) || h2(socket, msg, head);
   };
}
template <class... H>
Handler Combine(const Handler &h0, const Handler &h1, const Handler &h2, const H &...rest)
{
   return Combine(Combine(h0, h1), h2, rest...);
}
#define CASE_ON_MSG_TYPE(MsgTag)                                                         \
   case kMsgType##MsgTag:                                                               \
      Dispatch<Msg##MsgTag>(                                                           \
          msg, head, [&](auto &body) { return center->MsgTag(head, body); }, replyer); \
      return true;
bool InstallCenter()
{
   auto center_ptr = std::make_shared<Synced<NodeCenter>>();
   auto MakeReplyer = [](ShmSocket &socket, BHMsgHead &head, const std::string &proc_id) {
      return [&](auto &&rep_body) {
         auto reply_head(InitMsgHead(GetType(rep_body), proc_id, head.msg_id()));
         bool r = socket.Send(head.route(0).mq_id().data(), reply_head, rep_body, 10);
         if (!r) {
            printf("send reply failed.\n");
         }
         //TODO resend failed.
      };
   };
   auto OnCenter = [=](ShmSocket &socket, MsgI &msg, BHMsgHead &head) -> bool {
      auto &center = *center_ptr;
      auto replyer = MakeReplyer(socket, head, center->id());
      switch (head.type()) {
         CASE_ON_MSG_TYPE(Register);
         CASE_ON_MSG_TYPE(Heartbeat);
         CASE_ON_MSG_TYPE(RegisterRPC);
         CASE_ON_MSG_TYPE(QueryTopic);
      default: return false;
      }
   };
   auto OnPubSub = [=](ShmSocket &socket, MsgI &msg, BHMsgHead &head) -> bool {
      auto &center = *center_ptr;
      auto replyer = MakeReplyer(socket, head, center->id());
      auto OnPublish = [&]() {
         MsgPublish pub;
         NodeCenter::Clients clients;
         MsgCommonReply reply;
         MsgI pubmsg;
         if (head.route_size() != 1 || !msg.ParseBody(pub)) {
            return;
         } else if (!center->FindClients(head, pub, clients, reply)) {
            // send error reply.
            MakeReplyer(socket, head, center->id())(reply);
         } else if (pubmsg.MakeRC(socket.shm(), msg)) {
            DEFER1(pubmsg.Release(socket.shm()));
            for (auto &cli : clients) {
               auto node = cli.weak_node_.lock();
               if (node) {
                  socket.Send(cli.mq_.data(), pubmsg, 10);
               }
            }
         }
      };
      switch (head.type()) {
         CASE_ON_MSG_TYPE(Subscribe);
         CASE_ON_MSG_TYPE(Unsubscribe);
      case kMsgTypePublish: OnPublish(); return true;
      default: return false;
      }
   };
   BHCenter::Install("#center.reg", OnCenter, BHTopicCenterAddress(), 1000);
   BHCenter::Install("#center.bus", OnPubSub, BHTopicBusAddress(), 1000);
   return true;
}
#undef CASE_ON_MSG_TYPE
} // namespace
SharedMemory &BHomeShm()
{
@@ -42,17 +409,24 @@
   static CenterRecords rec;
   return rec;
}
bool BHCenter::Install(const std::string &name, MsgHandler handler, const std::string &mqid, const int mq_len)
{
   CenterRecords()[name] = CenterInfo{name, handler, mqid, mq_len};
   Centers()[name] = CenterInfo{name, handler, mqid, mq_len};
   return true;
}
bool BHCenter::Install(const std::string &name, MsgHandler handler, const MQId &mqid, const int mq_len)
{
   return Install(name, handler, std::string((const char *) &mqid, sizeof(mqid)), mq_len);
}
BHCenter::BHCenter(Socket::Shm &shm)
{
   sockets_["center"] = std::make_shared<ShmSocket>(shm, &BHTopicCenterAddress(), 1000);
   sockets_["bus"] = std::make_shared<ShmSocket>(shm, &BHTopicBusAddress(), 1000);
   InstallCenter();
   for (auto &kv : Centers()) {
      sockets_[kv.first] = std::make_shared<ShmSocket>(shm, kv.second.mqid_.data(), kv.second.mq_len_);
      auto &info = kv.second;
      sockets_[info.name_] = std::make_shared<ShmSocket>(shm, *(MQId *) info.mqid_.data(), info.mq_len_);
   }
}
@@ -61,16 +435,12 @@
bool BHCenter::Start()
{
   auto onCenter = MakeReqRepCenter();
   auto onBus = MakeBusCenter();
   sockets_["center"]->Start(onCenter);
   sockets_["bus"]->Start(onBus);
   for (auto &kv : Centers()) {
      sockets_[kv.first]->Start(kv.second.handler_);
      auto &info = kv.second;
      sockets_[info.name_]->Start(info.handler_);
   }
   return true;
   // socket_.Start(Join(onCenter, onBus));
}
bool BHCenter::Stop()