From bdcc9624341ee34298be74a706b09f12f8306456 Mon Sep 17 00:00:00 2001
From: zhangzengfei <zhangzengfei@smartai.com>
Date: 星期四, 18 四月 2024 23:16:56 +0800
Subject: [PATCH] 优化缓存的数据, 取消多次的base64计算和float32转换

---
 db/base.go                 |    2 
 compare/compare.go         |   16 +++-----
 cache/shardmap/shardmap.go |   14 +++---
 compare/faceSdk.go         |   31 +++++++++++++--
 db/person.go               |   16 +++++++-
 5 files changed, 54 insertions(+), 25 deletions(-)

diff --git a/cache/shardmap/shardmap.go b/cache/shardmap/shardmap.go
index 255f0c4..daf5f85 100644
--- a/cache/shardmap/shardmap.go
+++ b/cache/shardmap/shardmap.go
@@ -29,7 +29,7 @@
 
 var Count = make(chan int)
 
-type wfOp func(a []byte, b string) float32
+type wfOp func(a, b []float32) float32
 
 /**
 * @param uint8, shardCnt must be pow of two
@@ -90,7 +90,7 @@
 	value interface{}
 }
 
-func (s *ShardMap) Walk(wf wfOp, sourceFea []byte, baseScore float32) (targets []*protomsg.SdkCompareEach) {
+func (s *ShardMap) Walk(wf wfOp, sourceFea []float32, baseScore float32) (targets []*protomsg.SdkCompareEach) {
 	var wg sync.WaitGroup
 	var lock sync.Mutex
 	for _, si := range s.shards {
@@ -102,18 +102,18 @@
 
 		wg.Add(1)
 
-		go func(st *shardItem, fw wfOp, sf []byte, baseSec float32) {
+		go func(st *shardItem, fn wfOp, srcFeat []float32, baseSec float32) {
 			defer wg.Done()
 			for _, feature := range st.data {
-				if eif, ok := feature.(*db.FeatureCacheBase); ok {
+				if item, ok := feature.(*db.FeatureCacheBase); ok {
 					score := float32(0)
-					score = fw(sf, eif.FaceFeature)
+					score = fn(srcFeat, item.FaceFeature)
 					if score > 0 && score >= baseScore {
 						lock.Lock()
 						targets = append(targets, &protomsg.SdkCompareEach{
-							Id:           eif.Id,
+							Id:           item.Id,
 							CompareScore: score,
-							Tableid:      eif.TableId,
+							Tableid:      item.TableId,
 						})
 						lock.Unlock()
 					}
diff --git a/compare/compare.go b/compare/compare.go
index 5b86ce6..2fc28b5 100644
--- a/compare/compare.go
+++ b/compare/compare.go
@@ -1,7 +1,6 @@
 package compare
 
 import (
-	"encoding/base64"
 	"fmt"
 	"strconv"
 
@@ -18,6 +17,8 @@
 	if args.FaceFeature == nil {
 		return nil
 	}
+
+	floatFeat := ByteSlice2float32Slice(args.FaceFeature)
 
 	//鎸囧畾鏈�浣庡垎
 	baseScore := thresholdLimit
@@ -36,7 +37,7 @@
 	if args.TreeNodes != nil && len(args.TreeNodes) > 0 {
 		for _, id := range args.TreeNodes {
 			if _, ok := cache.CacheMap.Area[id]; ok {
-				targets := cache.CacheMap.Area[id].Walk(DoSdkCompare, args.FaceFeature, baseScore)
+				targets := cache.CacheMap.Area[id].Walk(DoSdkCompare, floatFeat, baseScore)
 				if len(targets) > 0 {
 					scResult.CompareResult = append(scResult.CompareResult, targets...)
 				}
@@ -60,7 +61,7 @@
 			continue
 		}
 
-		targets := val.Walk(DoSdkCompare, args.FaceFeature, baseScore)
+		targets := val.Walk(DoSdkCompare, floatFeat, baseScore)
 		if len(targets) > 0 {
 			scResult.CompareResult = append(scResult.CompareResult, targets...)
 			// todo 娣诲姞灏忓尯澶栫殑鍏宠仈鍏崇郴, 涓嬫浼樺厛姣斿
@@ -78,13 +79,8 @@
 	return buf
 }
 
-func DoSdkCompare(ci []byte, co string) float32 {
-	co_d, err := base64.StdEncoding.DecodeString(co)
-	if err != nil {
-		logger.Error("DoSdkCompare err:", err)
-		return -1
-	}
-	sec := DecCompare(ci, co_d)
+func DoSdkCompare(ci, co []float32) float32 {
+	sec := DirectCompare(ci, co)
 	//logger.Debug("姣斿寰楀垎涓猴細", sec)
 
 	sec = ParseScore(sec)
diff --git a/compare/faceSdk.go b/compare/faceSdk.go
index 6ffcf39..49fa258 100644
--- a/compare/faceSdk.go
+++ b/compare/faceSdk.go
@@ -62,10 +62,31 @@
 
 //	    return  fscore;
 //	}
-func DecCompare(feat1 []byte, feat2 []byte) float32 {
-	ffeat1 := byteSlice2float32Slice(feat1)
-	ffeat2 := byteSlice2float32Slice(feat2)
-	if len(ffeat1) != len(ffeat2) {
+func DirectCompare(feat1 []float32, feat2 []float32) float32 {
+	if len(feat1) != len(feat2) {
+		return 0
+	}
+
+	var score float32
+	for i := 0; i < 1536; i++ {
+		score += feat1[i] * feat2[i]
+	}
+	score += 0.05
+	if score > 0.9999 {
+		score = 0.9999
+	}
+	if score < 0.0001 {
+		score = 0.0001
+	}
+
+	//fmt.Println("score:", score)
+	return score
+}
+
+func DecCompare(feat1, feat2 []byte) float32 {
+	ffeat1 := ByteSlice2float32Slice(feat1)
+	ffeat2 := ByteSlice2float32Slice(feat2)
+	if len(feat1) != len(feat2) {
 		return 0
 	}
 	//fmt.Println("len:", len(ffeat1), len(feat2))
@@ -86,7 +107,7 @@
 	return score
 }
 
-func byteSlice2float32Slice(src []byte) []float32 {
+func ByteSlice2float32Slice(src []byte) []float32 {
 	if len(src) == 0 {
 		return nil
 	}
diff --git a/db/base.go b/db/base.go
index 4f8c5e8..ed77248 100644
--- a/db/base.go
+++ b/db/base.go
@@ -13,6 +13,6 @@
 	Id          string
 	AreaId      string
 	TableId     string
-	FaceFeature string
+	FaceFeature []float32
 	Enable      int32
 }
diff --git a/db/person.go b/db/person.go
index 0a567e5..3224daa 100644
--- a/db/person.go
+++ b/db/person.go
@@ -1,6 +1,8 @@
 package db
 
 import (
+	"encoding/base64"
+	"sdkCompare/compare"
 	"strconv"
 )
 
@@ -55,11 +57,16 @@
 
 	for _, p := range persons {
 		if p.FaceFeature != "" {
+			byteFeat, err := base64.StdEncoding.DecodeString(p.FaceFeature)
+			if err != nil {
+				continue
+			}
+
 			arr = append(arr, &FeatureCacheBase{
 				Id:          p.Id,
 				TableId:     p.TableId,
 				AreaId:      p.AreaID,
-				FaceFeature: p.FaceFeature,
+				FaceFeature: compare.ByteSlice2float32Slice(byteFeat),
 				Enable:      int32(p.Enable),
 			})
 		}
@@ -74,12 +81,17 @@
 	if err != nil {
 		return nil, err
 	}
+
 	if p.FaceFeature != "" {
+		byteFeat, err := base64.StdEncoding.DecodeString(p.FaceFeature)
+		if err != nil {
+			return nil, err
+		}
 		info = &FeatureCacheBase{
 			Id:          p.Id,
 			TableId:     p.TableId,
 			AreaId:      p.AreaID,
-			FaceFeature: p.FaceFeature,
+			FaceFeature: compare.ByteSlice2float32Slice(byteFeat),
 			Enable:      int32(p.Enable),
 		}
 	}

--
Gitblit v1.8.0