From d3e7f93e69cb24c766292d8780e745caf24d42a8 Mon Sep 17 00:00:00 2001 From: lichao <lichao@aiotlink.com> Date: 星期四, 25 三月 2021 18:30:51 +0800 Subject: [PATCH] add ref count. --- src/msg.h | 16 ++----- utest/utest.cpp | 43 ++++++++++++++------- src/shm_queue.cpp | 7 ++- 3 files changed, 38 insertions(+), 28 deletions(-) diff --git a/src/msg.h b/src/msg.h index ec17aaf..44c961f 100644 --- a/src/msg.h +++ b/src/msg.h @@ -59,7 +59,7 @@ int Get() { Guard lk(mutex_); return num_; } private: Mutex mutex_; - int num_ = 0; + int num_ = 1; }; class Msg { @@ -67,13 +67,6 @@ offset_ptr<void> ptr_; offset_ptr<RefCount> count_; public: - class CountGuard : private boost::noncopyable { - Msg &msg_; - public: - CountGuard(Msg &msg) : msg_(msg) { msg_.AddRef(); } - ~CountGuard() { msg_.RemoveRef(); } - }; - Msg(void *p=0, RefCount *c=0):ptr_(p), count_(c) {} void swap(Msg &a) { std::swap(ptr_, a.ptr_); std::swap(count_, a.count_); } @@ -84,9 +77,10 @@ // Msg & operator = (Msg &&a) { Msg(std::move(a)).swap(*this); } template <class T = void> T *get() { return static_cast<T*>(ptr_.get()); } - int AddRef() const { return count_ ? count_->Inc() : 0; } - int RemoveRef() const{ return count_ ? count_->Dec() : 0; } - int Count() const { return count_ ? count_->Get() : 0; } + int AddRef() const { return IsCounted() ? count_->Inc() : 1; } + int RemoveRef() const{ return IsCounted() ? count_->Dec() : 0; } + int Count() const{ return IsCounted() ? count_->Get() : 1; } + bool IsCounted() const { return static_cast<bool>(count_); } bool Build(SharedMemory &shm, const MQId &src_id, const void *p, const size_t size, const bool refcount); void FreeFrom(SharedMemory &shm); }; diff --git a/src/shm_queue.cpp b/src/shm_queue.cpp index f770afc..77add97 100644 --- a/src/shm_queue.cpp +++ b/src/shm_queue.cpp @@ -73,7 +73,9 @@ if (Send(remote_id, msg, timeout_ms)) { return true; } else { - msg.FreeFrom(shm()); + if (msg.RemoveRef() == 0) { // works for both refcounted and not counted. + msg.FreeFrom(shm()); + } } } return false; @@ -83,9 +85,10 @@ { Msg msg; if (Read(msg, timeout_ms)) { + DEFER1(if (msg.RemoveRef() == 0) { msg.FreeFrom(shm()); }); + auto ptr = msg.get<char>(); if (ptr) { - DEFER1(shm().Dealloc(ptr);); MsgMetaV1 meta; meta.Parse(ptr); source_id = meta.src_id_; diff --git a/utest/utest.cpp b/utest/utest.cpp index 61e6437..6994cbd 100644 --- a/utest/utest.cpp +++ b/utest/utest.cpp @@ -145,21 +145,15 @@ SharedMemory shm(shm_name, 1024*1024); Msg m0(shm.Alloc(1000), shm.New<RefCount>()); - BOOST_CHECK_EQUAL(m0.AddRef(), 1); + BOOST_CHECK(m0.IsCounted()); + BOOST_CHECK_EQUAL(m0.Count(), 1); Msg m1 = m0; + BOOST_CHECK(m1.IsCounted()); BOOST_CHECK_EQUAL(m1.AddRef(), 2); BOOST_CHECK_EQUAL(m0.AddRef(), 3); BOOST_CHECK_EQUAL(m0.RemoveRef(), 2); BOOST_CHECK_EQUAL(m0.RemoveRef(), 1); BOOST_CHECK_EQUAL(m1.RemoveRef(), 0); - { - Msg::CountGuard guard(m0); - BOOST_CHECK_EQUAL(m1.AddRef(), 2); - { - Msg::CountGuard guard(m0); - BOOST_CHECK_EQUAL(m1.RemoveRef(), 2); - } - } BOOST_CHECK_EQUAL(m1.Count(), 0); } @@ -201,16 +195,23 @@ const size_t msg_length = 1000; std::string msg_content(msg_length, 'a'); msg_content[20] = '\0'; + Msg request; + request.Build(shm, cli.Id(), msg_content.data(), msg_content.size(), true); + Msg reply(request); + std::atomic<bool> stop(false); std::atomic<uint64_t> count(0); using namespace boost::posix_time; auto Now = []() { return second_clock::universal_time(); }; - std::atomic<ptime> last_time(Now()); + std::atomic<ptime> last_time(Now() - seconds(1)); std::atomic<uint64_t> last_count(0); auto Client = [&](int tid, int nmsg){ for (int i = 0; i < nmsg; ++i) { - if (!cli.Send(srv.Id(), msg_content.data(), msg_content.size(), 1000)) { + auto Send = [&]() { return cli.Send(srv.Id(), msg_content.data(), msg_content.size(), 1000); }; + auto SendRefCounted = [&]() { return cli.Send(srv.Id(), request, 1000); }; + + if (!Send()) { printf("********** client send error.\n"); continue; } @@ -228,7 +229,8 @@ auto cur = Now(); if (last_time.exchange(cur) != cur) { std::cout << "time: " << Now(); - printf(", total msg:%10ld, speed:%8ld/s, used mem:%8ld\n", count.load(), count - last_count.exchange(count), init_avail - Avail()); + printf(", total msg:%10ld, speed:%8ld/s, used mem:%8ld, refcount:%d\n", + count.load(), count - last_count.exchange(count), init_avail - Avail(), request.Count()); last_time = cur; } @@ -243,7 +245,10 @@ while (!stop) { if (srv.Recv(src_id, data, size, 100)) { DEFER1(free(data)); - if (srv.Send(src_id, data, size, 100)) { + auto Send = [&](){ return srv.Send(src_id, data, size, 100); }; + auto SendRefCounted = [&](){ return srv.Send(src_id, reply, 100); }; + + if (SendRefCounted()) { if (size != msg_content.size()) { BOOST_TEST(false, "server msg size error"); } @@ -257,14 +262,22 @@ ThreadManager clients, servers; for (int i = 0; i < qlen; ++i) { servers.Launch(Server); } - int ncli = 100; - uint64_t nmsg = 1000*10; + int ncli = 100*1; + uint64_t nmsg = 100*100; printf("client threads: %d, msgs : %ld, total msg: %ld\n", ncli, nmsg, ncli * nmsg); for (int i = 0; i < ncli; ++i) { clients.Launch(Client, i, nmsg); } clients.WaitAll(); printf("request ok: %ld\n", count.load()); stop = true; servers.WaitAll(); + BOOST_CHECK(request.IsCounted()); + BOOST_CHECK_EQUAL(request.Count(), 1); + BOOST_CHECK(reply.IsCounted()); + BOOST_CHECK_EQUAL(reply.Count(), 1); + if (request.RemoveRef() == 0) { + BOOST_CHECK_EQUAL(reply.Count(), 0); + request.FreeFrom(shm); + } } int test_main(int argc, char *argv[]) -- Gitblit v1.8.0