From d3e7f93e69cb24c766292d8780e745caf24d42a8 Mon Sep 17 00:00:00 2001
From: lichao <lichao@aiotlink.com>
Date: 星期四, 25 三月 2021 18:30:51 +0800
Subject: [PATCH] add ref count.
---
src/msg.h | 16 ++-----
utest/utest.cpp | 43 ++++++++++++++-------
src/shm_queue.cpp | 7 ++-
3 files changed, 38 insertions(+), 28 deletions(-)
diff --git a/src/msg.h b/src/msg.h
index ec17aaf..44c961f 100644
--- a/src/msg.h
+++ b/src/msg.h
@@ -59,7 +59,7 @@
int Get() { Guard lk(mutex_); return num_; }
private:
Mutex mutex_;
- int num_ = 0;
+ int num_ = 1;
};
class Msg {
@@ -67,13 +67,6 @@
offset_ptr<void> ptr_;
offset_ptr<RefCount> count_;
public:
- class CountGuard : private boost::noncopyable {
- Msg &msg_;
- public:
- CountGuard(Msg &msg) : msg_(msg) { msg_.AddRef(); }
- ~CountGuard() { msg_.RemoveRef(); }
- };
-
Msg(void *p=0, RefCount *c=0):ptr_(p), count_(c) {}
void swap(Msg &a) { std::swap(ptr_, a.ptr_); std::swap(count_, a.count_); }
@@ -84,9 +77,10 @@
// Msg & operator = (Msg &&a) { Msg(std::move(a)).swap(*this); }
template <class T = void> T *get() { return static_cast<T*>(ptr_.get()); }
- int AddRef() const { return count_ ? count_->Inc() : 0; }
- int RemoveRef() const{ return count_ ? count_->Dec() : 0; }
- int Count() const { return count_ ? count_->Get() : 0; }
+ int AddRef() const { return IsCounted() ? count_->Inc() : 1; }
+ int RemoveRef() const{ return IsCounted() ? count_->Dec() : 0; }
+ int Count() const{ return IsCounted() ? count_->Get() : 1; }
+ bool IsCounted() const { return static_cast<bool>(count_); }
bool Build(SharedMemory &shm, const MQId &src_id, const void *p, const size_t size, const bool refcount);
void FreeFrom(SharedMemory &shm);
};
diff --git a/src/shm_queue.cpp b/src/shm_queue.cpp
index f770afc..77add97 100644
--- a/src/shm_queue.cpp
+++ b/src/shm_queue.cpp
@@ -73,7 +73,9 @@
if (Send(remote_id, msg, timeout_ms)) {
return true;
} else {
- msg.FreeFrom(shm());
+ if (msg.RemoveRef() == 0) { // works for both refcounted and not counted.
+ msg.FreeFrom(shm());
+ }
}
}
return false;
@@ -83,9 +85,10 @@
{
Msg msg;
if (Read(msg, timeout_ms)) {
+ DEFER1(if (msg.RemoveRef() == 0) { msg.FreeFrom(shm()); });
+
auto ptr = msg.get<char>();
if (ptr) {
- DEFER1(shm().Dealloc(ptr););
MsgMetaV1 meta;
meta.Parse(ptr);
source_id = meta.src_id_;
diff --git a/utest/utest.cpp b/utest/utest.cpp
index 61e6437..6994cbd 100644
--- a/utest/utest.cpp
+++ b/utest/utest.cpp
@@ -145,21 +145,15 @@
SharedMemory shm(shm_name, 1024*1024);
Msg m0(shm.Alloc(1000), shm.New<RefCount>());
- BOOST_CHECK_EQUAL(m0.AddRef(), 1);
+ BOOST_CHECK(m0.IsCounted());
+ BOOST_CHECK_EQUAL(m0.Count(), 1);
Msg m1 = m0;
+ BOOST_CHECK(m1.IsCounted());
BOOST_CHECK_EQUAL(m1.AddRef(), 2);
BOOST_CHECK_EQUAL(m0.AddRef(), 3);
BOOST_CHECK_EQUAL(m0.RemoveRef(), 2);
BOOST_CHECK_EQUAL(m0.RemoveRef(), 1);
BOOST_CHECK_EQUAL(m1.RemoveRef(), 0);
- {
- Msg::CountGuard guard(m0);
- BOOST_CHECK_EQUAL(m1.AddRef(), 2);
- {
- Msg::CountGuard guard(m0);
- BOOST_CHECK_EQUAL(m1.RemoveRef(), 2);
- }
- }
BOOST_CHECK_EQUAL(m1.Count(), 0);
}
@@ -201,16 +195,23 @@
const size_t msg_length = 1000;
std::string msg_content(msg_length, 'a');
msg_content[20] = '\0';
+ Msg request;
+ request.Build(shm, cli.Id(), msg_content.data(), msg_content.size(), true);
+ Msg reply(request);
+
std::atomic<bool> stop(false);
std::atomic<uint64_t> count(0);
using namespace boost::posix_time;
auto Now = []() { return second_clock::universal_time(); };
- std::atomic<ptime> last_time(Now());
+ std::atomic<ptime> last_time(Now() - seconds(1));
std::atomic<uint64_t> last_count(0);
auto Client = [&](int tid, int nmsg){
for (int i = 0; i < nmsg; ++i) {
- if (!cli.Send(srv.Id(), msg_content.data(), msg_content.size(), 1000)) {
+ auto Send = [&]() { return cli.Send(srv.Id(), msg_content.data(), msg_content.size(), 1000); };
+ auto SendRefCounted = [&]() { return cli.Send(srv.Id(), request, 1000); };
+
+ if (!Send()) {
printf("********** client send error.\n");
continue;
}
@@ -228,7 +229,8 @@
auto cur = Now();
if (last_time.exchange(cur) != cur) {
std::cout << "time: " << Now();
- printf(", total msg:%10ld, speed:%8ld/s, used mem:%8ld\n", count.load(), count - last_count.exchange(count), init_avail - Avail());
+ printf(", total msg:%10ld, speed:%8ld/s, used mem:%8ld, refcount:%d\n",
+ count.load(), count - last_count.exchange(count), init_avail - Avail(), request.Count());
last_time = cur;
}
@@ -243,7 +245,10 @@
while (!stop) {
if (srv.Recv(src_id, data, size, 100)) {
DEFER1(free(data));
- if (srv.Send(src_id, data, size, 100)) {
+ auto Send = [&](){ return srv.Send(src_id, data, size, 100); };
+ auto SendRefCounted = [&](){ return srv.Send(src_id, reply, 100); };
+
+ if (SendRefCounted()) {
if (size != msg_content.size()) {
BOOST_TEST(false, "server msg size error");
}
@@ -257,14 +262,22 @@
ThreadManager clients, servers;
for (int i = 0; i < qlen; ++i) { servers.Launch(Server); }
- int ncli = 100;
- uint64_t nmsg = 1000*10;
+ int ncli = 100*1;
+ uint64_t nmsg = 100*100;
printf("client threads: %d, msgs : %ld, total msg: %ld\n", ncli, nmsg, ncli * nmsg);
for (int i = 0; i < ncli; ++i) { clients.Launch(Client, i, nmsg); }
clients.WaitAll();
printf("request ok: %ld\n", count.load());
stop = true;
servers.WaitAll();
+ BOOST_CHECK(request.IsCounted());
+ BOOST_CHECK_EQUAL(request.Count(), 1);
+ BOOST_CHECK(reply.IsCounted());
+ BOOST_CHECK_EQUAL(reply.Count(), 1);
+ if (request.RemoveRef() == 0) {
+ BOOST_CHECK_EQUAL(reply.Count(), 0);
+ request.FreeFrom(shm);
+ }
}
int test_main(int argc, char *argv[])
--
Gitblit v1.8.0