zhangzengfei
2024-12-09 366e2ff546092d9be26411fb698b3ddd8e834a11
cache/shardmap/shardmap.go
@@ -29,7 +29,7 @@
var Count = make(chan int)
type wfOp func(a []byte, b string) float32
type wfOp func(a, b []float32) float32
/**
* @param uint8, shardCnt must be pow of two
@@ -90,42 +90,44 @@
   value interface{}
}
func (s *ShardMap) Walk(wf wfOp, sourceFea []byte, baseScore float32) (targets []*protomsg.SdkCompareEach) {
func (s *ShardMap) Walk(wf wfOp, sourceFea []float32, baseScore float32) (targets []*protomsg.SdkCompareEach) {
   var wg sync.WaitGroup
   var lock sync.Mutex
   for _, si := range s.shards {
      var tempsi shardItem = *si
   for _, si := range s.shards {
      var tempsi *shardItem = si // 保持对原始 shardItem 的指针引用
      // 跳过空分片
      if len(tempsi.data) == 0 {
         continue
      }
      wg.Add(1)
      go func(st *shardItem, fw wfOp, sf []byte, baseSec float32) {
      go func(st *shardItem, fn wfOp, srcFeat []float32, baseSec float32) {
         defer wg.Done()
         st.RLock()         // 锁定读取
         defer st.RUnlock() // 确保读取完毕后解锁
         for _, feature := range st.data {
            if eif, ok := feature.(*db.FeatureCacheBase); ok {
               score := float32(0)
               score = fw(sf, eif.FaceFeature)
            // 读取操作在锁内进行,防止并发冲突
            if item, ok := feature.(*db.FeatureCacheBase); ok {
               score := fn(srcFeat, item.FaceFeature)
               if score > 0 && score >= baseScore {
                  lock.Lock()
                  lock.Lock() // 保护目标切片的写入
                  targets = append(targets, &protomsg.SdkCompareEach{
                     Id:           eif.Id,
                     Id:           item.Id,
                     CompareScore: score,
                     Tableid:      eif.TableId,
                     Tableid:      item.TableId,
                  })
                  lock.Unlock()
               }
            }
         }
      }(&tempsi, wf, sourceFea, baseScore)
      }(tempsi, wf, sourceFea, baseScore)
   }
   wg.Wait()
   return targets
}
// print all