package main
|
|
/*
|
#cgo CFLAGS: -I${SRCDIR}/sdk/include -w -g
|
#cgo CXXFLAGS: -I${SRCDIR}/sdk/include -w -g -std=c++11
|
#cgo LDFLAGS: -L/usr/local/cuda-8.0/lib64 -L${SRCDIR}/sdk/lib
|
#cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/sdk/lib
|
#cgo LDFLAGS: -ldarknet -lcudart -lcublas -lcurand -lcudnn -lrt -ldl -lpthread
|
#include <stdlib.h>
|
#include "cyolo.h"
|
*/
|
import "C"
|
import (
|
"fmt"
|
"unsafe"
|
|
"basic.com/pubsub/protomsg.git"
|
"github.com/gogo/protobuf/proto"
|
)
|
|
// CPOINT is a 2-D point with int32 coordinates.
//
// NOTE(review): presumably mirrors a point struct declared in cyolo.h; it is
// not referenced by the code in this file section. Keep the field order and
// types in sync with the C definition if it is passed across the cgo boundary.
type CPOINT struct {
	X int32
	Y int32
}
|
|
// CRECT is an axis-aligned rectangle described by its edge coordinates.
// countInterAreaOfTwoRect treats its width as Right-Left and its height as
// Bottom-Top, so Left/Top are the near edges and Right/Bottom the far edges.
//
// CRECT is embedded in CObjInfo, whose values are copied directly out of C
// memory. The field order and int32 types therefore define the memory layout.
// NOTE(review): presumably mirrors a rect struct in cyolo.h; confirm before
// changing anything here.
type CRECT struct {
	Left   int32
	Top    int32
	Right  int32
	Bottom int32
}
|
|
// CIMAGE describes a raw image buffer: a pointer to the pixel bytes plus the
// image width, height and channel count.
//
// Pad_cgo_0 is explicit trailing padding, in the style of cgo -godefs output.
// On 64-bit targets it presumably rounds the struct up to 24 bytes to match
// the C struct's alignment.
// NOTE(review): presumably mirrors an image struct in cyolo.h. Field order,
// types and the padding must match the C layout. It is not used by the code
// in this file section.
type CIMAGE struct {
	Data      *uint8
	Width     int32
	Height    int32
	Channel   int32
	Pad_cgo_0 [4]byte
}
|
|
// CObjInfo is a single detection result returned by the native detector.
// It holds a bounding box (RcObj), a class index (Typ; YoloObjName maps it to
// a name) and a score (Prob, presumably the detection confidence).
//
// CYoloObjInfoArrayToGoArray reinterprets the array filled in by C.detect as
// consecutive CObjInfo values. The field order and types therefore define the
// layout and must match the C struct.
type CObjInfo struct {
	RcObj CRECT
	Typ   int32
	Prob  float32
}
|
|
// CObjTrackInfo pairs a detection with the tracking ID assigned to it by
// YoloDetectTrack2. An object keeps its ID for as long as it keeps matching a
// detection in consecutive frames. Newly appearing objects receive the next
// value of a per-stream counter.
type CObjTrackInfo struct {
	ObjInfo CObjInfo
	ID      uint64
}
|
|
// trackInfo holds the tracking state of one stream, keyed by the id argument
// of Run/Run2.
type trackInfo struct {
	// lastTrackObjs are the tracked objects produced for the previous frame;
	// they are matched against the next frame's detections.
	lastTrackObjs []CObjTrackInfo

	// lastTrackID is the next ID to hand out to a newly appearing object.
	lastTrackID uint64
}
|
|
// YoloHandle wraps the native detector handle returned by C.init, together
// with per-stream tracking state.
type YoloHandle struct {
	// handle is the opaque pointer returned by C.init. It is passed to
	// C.detect and C.obj_name_by_type, and released by Free.
	handle unsafe.Pointer

	// tracker maps a stream id (the id argument of Run/Run2) to its state.
	// NOTE(review): this is a plain map with no locking, so concurrent Run/Run2
	// calls on the same handle are a data race.
	tracker map[string]*trackInfo
}
|
|
// RatioInterTrack is the tracking overlap threshold. It is the minimum overlap,
// as a percentage of the smaller box's area (see countInterAreaOfTwoRect), for
// a detection to be treated as the continuation of a previously tracked object.
const RatioInterTrack = 50 // overlap threshold (percent) used for tracking decisions
|
|
// NewSDK init yolo sdk
|
func NewSDK(fc, fw, fn string, gi int) interface{} {
|
|
c := C.CString(fc)
|
defer C.free(unsafe.Pointer(c))
|
w := C.CString(fw)
|
defer C.free(unsafe.Pointer(w))
|
n := C.CString(fn)
|
defer C.free(unsafe.Pointer(n))
|
|
g := C.int(gi)
|
|
p := C.init(c, w, n, g)
|
return &YoloHandle{
|
handle: p,
|
tracker: make(map[string]*trackInfo),
|
}
|
}
|
|
// Free free
|
func Free(i interface{}) {
|
y := i.(*YoloHandle)
|
if y != nil {
|
if y.handle != nil {
|
C.release(y.handle)
|
}
|
}
|
}
|
|
// CYoloObjInfoArrayToGoArray convert cObjInfo array to go
|
func CYoloObjInfoArrayToGoArray(cArray unsafe.Pointer, count int) (goArray []CObjInfo) {
|
p := uintptr(cArray)
|
|
for i := 0; i < count; i++ {
|
j := *(*CObjInfo)(unsafe.Pointer(p))
|
goArray = append(goArray, j)
|
p += unsafe.Sizeof(j)
|
}
|
return
|
}
|
|
// YoloDetect yolo detect
|
func YoloDetect(y *YoloHandle, data []byte, w, h, c int, thrsh float32, umns int) []CObjInfo {
|
|
var count C.int
|
var cobjinfo unsafe.Pointer
|
|
ret := C.detect(y.handle,
|
unsafe.Pointer(&data[0]), C.int(w), C.int(h), C.int(c),
|
C.float(thrsh), C.int(umns),
|
&cobjinfo, &count)
|
|
if ret == 0 {
|
return CYoloObjInfoArrayToGoArray(unsafe.Pointer(cobjinfo), int(count))
|
}
|
return nil
|
}
|
|
// YoloObjName obj name by type
|
func YoloObjName(i interface{}, typ int) string {
|
y := i.(*YoloHandle)
|
p := C.obj_name_by_type(y.handle, C.int(typ))
|
|
return C.GoString(p)
|
}
|
|
// max returns the larger of two int32 values.
// Within this package it shadows the Go 1.21+ built-in max.
func max(a, b int32) int32 {
	if a >= b {
		return a
	}
	return b
}
|
|
// min returns the smaller of two int32 values.
// Within this package it shadows the Go 1.21+ built-in min.
func min(a, b int32) int32 {
	if b <= a {
		return b
	}
	return a
}
|
|
func countInterAreaOfTwoRect(rect1 CRECT, rect2 CRECT) int32 {
|
xMin := min(rect1.Left, rect2.Left)
|
yMin := min(rect1.Top, rect2.Top)
|
xMax := max(rect1.Right, rect2.Right)
|
yMax := max(rect1.Bottom, rect2.Bottom)
|
|
wRect1 := rect1.Right - rect1.Left
|
hRect1 := rect1.Bottom - rect1.Top
|
|
wRect2 := rect2.Right - rect2.Left
|
hRect2 := rect2.Bottom - rect2.Top
|
|
wInter := wRect1 + wRect2 - (xMax - xMin)
|
hInter := hRect1 + hRect2 - (yMax - yMin)
|
|
if (wInter <= 0) || (hInter <= 0) {
|
return 0
|
}
|
|
areaInter := wInter * hInter
|
areaRect1 := wRect1 * hRect1
|
areaRect2 := wRect2 * hRect2
|
ratio := areaInter * 100 / min(areaRect1, areaRect2)
|
|
return ratio
|
}
|
|
// YoloDetectTrack2 yolo detect (只识别人)
|
func YoloDetectTrack2(y *YoloHandle, LastYoloObjs []CObjTrackInfo, LastTrackID *uint64, data []byte, w, h, c int, thrsh float32, umns int) (allObjs []CObjTrackInfo, newObjs []CObjTrackInfo) {
|
|
var tmp CObjTrackInfo
|
//LastYoloObjs
|
detectObjs := YoloDetect(y, data, w, h, c, thrsh, umns)
|
|
for _, vLast := range LastYoloObjs {
|
for i := 0; i < len(detectObjs); i++ {
|
//fmt.Println("vNew.Typ:", vNew.Typ)
|
if vLast.ObjInfo.Typ == detectObjs[i].Typ { //同一类别,比如都是人体
|
ratio := countInterAreaOfTwoRect(vLast.ObjInfo.RcObj, detectObjs[i].RcObj)
|
if ratio >= RatioInterTrack {
|
//update LastYoloObjs
|
vLast.ObjInfo.RcObj = detectObjs[i].RcObj
|
vLast.ObjInfo.Prob = detectObjs[i].Prob
|
|
allObjs = append(allObjs, vLast)
|
detectObjs = append(detectObjs[:i], detectObjs[i+1:]...) //从检测目标里删除已经查到的跟踪目标
|
i--
|
break //上一帧跟踪的目标已经找到,无需往下处理其他检测目标
|
}
|
}
|
}
|
}
|
|
//处理新出现的目标
|
id := *LastTrackID
|
if len(detectObjs) > 0 {
|
for _, vAdd := range detectObjs {
|
tmp.ObjInfo = vAdd
|
tmp.ID = id
|
id++
|
|
allObjs = append(allObjs, tmp)
|
newObjs = append(newObjs, tmp)
|
}
|
}
|
*LastTrackID = id
|
return allObjs, newObjs
|
}
|
|
// Run yolo detect (只识别人)
|
func Run(i interface{}, id string, data []byte, w, h, c int, thrsh float32, umns int) ([]byte, int) {
|
if data == nil || w <= 0 || h <= 0 {
|
return nil, 0
|
}
|
y := i.(*YoloHandle)
|
|
channel := c
|
if channel == 0 {
|
channel = 3
|
}
|
|
v, ok := y.tracker[id]
|
if !ok {
|
i := &trackInfo{}
|
y.tracker[id] = i
|
v = i
|
}
|
whole, _ := YoloDetectTrack2(y, v.lastTrackObjs, &v.lastTrackID, data, w, h, channel, thrsh, umns)
|
y.tracker[id].lastTrackObjs = whole
|
y.tracker[id].lastTrackID = v.lastTrackID
|
|
var dWhole []byte
|
var err error
|
if len(whole) > 0 {
|
|
infos := convert2ProtoYoloTrack(whole, 1.0, 1.0)
|
p := protomsg.ParamYoloObj{Infos: infos}
|
|
dWhole, err = proto.Marshal(&p)
|
if err != nil {
|
fmt.Println("ydetect track marshal proto yolo obj error", err)
|
dWhole = nil
|
}
|
}
|
|
return dWhole, len(whole)
|
}
|
|
func convert2ProtoYoloTrack(obj []CObjTrackInfo, fx, fy float64) []*protomsg.ObjInfo {
|
ret := []*protomsg.ObjInfo{}
|
|
for _, v := range obj {
|
if fx < 1.0 || fy < 1.0 {
|
v.ObjInfo.RcObj.Left = (int32)((float64)(v.ObjInfo.RcObj.Left) / fx)
|
v.ObjInfo.RcObj.Right = (int32)((float64)(v.ObjInfo.RcObj.Right) / fx)
|
v.ObjInfo.RcObj.Top = (int32)((float64)(v.ObjInfo.RcObj.Top) / fy)
|
v.ObjInfo.RcObj.Bottom = (int32)((float64)(v.ObjInfo.RcObj.Bottom) / fy)
|
}
|
|
rect := protomsg.Rect{
|
Left: v.ObjInfo.RcObj.Left,
|
Right: v.ObjInfo.RcObj.Right,
|
Top: v.ObjInfo.RcObj.Top,
|
Bottom: v.ObjInfo.RcObj.Bottom,
|
}
|
obj := protomsg.ObjInfo{
|
RcObj: &rect,
|
Typ: v.ObjInfo.Typ,
|
Prob: v.ObjInfo.Prob,
|
ObjID: v.ID,
|
}
|
|
ret = append(ret, &obj)
|
}
|
return ret
|
}
|
|
// Run2 yolo detect (只识别人)
|
func Run2(i interface{}, id string, data []byte, w, h, c int, thrsh float32, umns int) ([]byte, int, []byte, int) {
|
if data == nil || w <= 0 || h <= 0 {
|
return nil, 0, nil, 0
|
}
|
y := i.(*YoloHandle)
|
|
channel := c
|
if channel == 0 {
|
channel = 3
|
}
|
|
v, ok := y.tracker[id]
|
if !ok {
|
i := &trackInfo{}
|
y.tracker[id] = i
|
v = i
|
}
|
whole, recent := YoloDetectTrack2(y, v.lastTrackObjs, &v.lastTrackID, data, w, h, channel, thrsh, umns)
|
y.tracker[id].lastTrackObjs = whole
|
y.tracker[id].lastTrackID = v.lastTrackID
|
|
var dWhole, dRecent []byte
|
var err error
|
if len(whole) > 0 {
|
|
infos := convert2ProtoYoloTrack(whole, 1.0, 1.0)
|
p := protomsg.ParamYoloObj{Infos: infos}
|
|
dWhole, err = proto.Marshal(&p)
|
if err != nil {
|
fmt.Println("ydetect track marshal proto yolo obj error", err)
|
dWhole = nil
|
}
|
}
|
if len(recent) > 0 {
|
dRecent = nil
|
}
|
return dWhole, len(whole), dRecent, len(recent)
|
}
|