import time
|
from concurrent.futures import ThreadPoolExecutor
|
import threading
|
|
import torch
|
from PIL import Image
|
from pymilvus import connections, Collection
|
from datetime import datetime
|
import os
|
import requests
|
import asyncio
|
import logging
|
import re
|
from logging.handlers import RotatingFileHandler
|
|
from transformers import AutoModelForVision2Seq, AutoProcessor
|
|
|
class qwen_thread:
    """Thread-pool front end that captions images and checks them against rules.

    Each submitted record is handled on a worker thread: the image is
    described by a local ``AutoModelForVision2Seq`` model, the description is
    checked against the record's rule text through a remote ``/chat``
    endpoint, the record is upserted into the Milvus collection
    ``smartobject`` and, when a description was produced, forwarded to the
    RAG service's ``/insert_json_data`` endpoint.
    """

    def __init__(self, max_workers, config, model_path, device="cuda:1"):
        """Connect to Milvus, load the model pool and set up logging.

        Args:
            max_workers: Number of worker threads; one model copy is loaded
                per worker.
            config: Mapping-like object exposing ``get``. Keys read by this
                class: ``milvusurl``, ``milvusport``, ``ragurl``, ``ragmode``
                and ``max_tokens``.
            model_path: Path or identifier handed to ``from_pretrained`` for
                both the model and the processor.
            device: Device every model copy is placed on. Defaults to
                ``"cuda:1"``, the value that used to be hard-coded.
        """
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        # Caps the number of submitted-but-unfinished tasks at max_workers.
        self.semaphore = threading.Semaphore(max_workers)
        self.max_workers = max_workers
        self.device = device

        # Connect to Milvus and load the target collection.
        connections.connect("default", host=config.get("milvusurl"), port=config.get("milvusport"))
        self.collection = Collection(name="smartobject")
        self.collection.load()

        self.config = config

        # One lock per model copy; lock_pool[i] guards model_pool[i].
        self.lock_pool = [threading.Lock() for _ in range(max_workers)]
        self.model_pool = []
        for _ in range(max_workers):
            model = AutoModelForVision2Seq.from_pretrained(
                model_path,
                device_map=device,
                trust_remote_code=True,
                torch_dtype=torch.float16,
            ).eval()
            self.model_pool.append(model)

        # A single processor is shared by all workers. The original author's
        # comment states it is thread-safe; that is not verified here.
        self.processor = AutoProcessor.from_pretrained(model_path, use_fast=True)

        self.logger = self._build_logger()

    def _build_logger(self):
        """Create the per-instance logger writing to ``logs/thread_log.log``."""
        logger = logging.getLogger(f"{self.__class__}_{id(self)}")
        logger.setLevel(logging.INFO)
        # Avoid attaching a second handler if this logger name is reused.
        if not logger.handlers:
            # RotatingFileHandler does not create missing directories.
            os.makedirs("logs", exist_ok=True)
            handler = RotatingFileHandler(
                filename=os.path.join("logs", 'thread_log.log'),
                maxBytes=10 * 1024 * 1024,
                backupCount=3,
                encoding='utf-8',
            )
            handler.setFormatter(logging.Formatter(
                '%(asctime)s - %(filename)s:%(lineno)d - %(funcName)s() - %(levelname)s: %(message)s'
            ))
            logger.addHandler(handler)
        return logger

    def submit(self, res_a):
        """Queue one record for processing and return its future.

        Blocks the caller while ``max_workers`` tasks are already in flight.

        Args:
            res_a: Record dict passed on to ``tark_do``.

        Returns:
            The ``concurrent.futures.Future`` of the queued task.

        Raises:
            RuntimeError: Propagated from the executor, e.g. when it has
                already been shut down.
        """
        # Try a non-blocking acquire first so a full pool can be logged.
        if not self.semaphore.acquire(blocking=False):
            self.logger.info(f"线程池已满,等待空闲线程... (当前活跃: {self.max_workers - self.semaphore._value}/{self.max_workers})")
            # Block until a slot frees up.
            self.semaphore.acquire(blocking=True)

        try:
            future = self.executor.submit(self._wrap_task, res_a)
        except Exception:
            # No future exists, so the done-callback will never run; give
            # the permit back here or the slot is lost for good.
            self.semaphore.release()
            raise
        future.add_done_callback(self._release_semaphore)
        return future

    def _wrap_task(self, res):
        """Run ``tark_do`` for one record, logging duration and failures."""
        try:
            started = datetime.now()
            image_id = self.tark_do(res, self.config.get("ragurl"), self.config.get("ragmode"), self.config.get("max_tokens"))
            self.logger.info(f"处理: { res['id']}完毕{image_id}:{datetime.now() - started}")
            return image_id
        except Exception as e:
            self.logger.exception(f"任务 { res['id']} 处理出错: {e}")
            raise

    def tark_do(self, res, ragurl, rag_mode, max_tokens):
        """Describe the record's image, match rules, and persist the result.

        Args:
            res: Record dict; must contain every key copied into the Milvus
                row below (``id``, ``image_desc_path``, ``waning_value``, ...).
            ragurl: Base URL of the RAG service.
            rag_mode: Value sent as ``llm_name`` to the chat endpoint.
            max_tokens: Value sent as ``max_tokens`` to the chat endpoint.

        Returns:
            The primary keys reported by the Milvus upsert, or ``0`` if any
            step raised.
        """
        try:
            image_des = self.image_desc(f"{res['image_desc_path']}")
            self.logger.info(image_des)
            if image_des:
                # NOTE(review): image_rule_chat returns None when the chat
                # call fails, and that None is stored as is_waning below --
                # confirm the collection schema accepts it.
                is_waning = self.image_rule_chat(image_des, res['waning_value'], ragurl, rag_mode, max_tokens)
                is_desc = 2
            else:
                is_waning = 0
                is_desc = 3

            row = {
                "id": res['id'],
                "event_level_id": res['event_level_id'],
                "event_level_name": res['event_level_name'],
                "rule_id": res["rule_id"],
                "video_point_id": res['video_point_id'],
                "video_point_name": res['video_point_name'],
                "is_waning": is_waning,
                "is_desc": is_desc,
                "zh_desc_class": image_des,
                "bounding_box": res['bounding_box'],
                "task_id": res['task_id'],
                "task_name": res['task_name'],
                "detect_id": res['detect_id'],
                "detect_time": res['detect_time'],
                "detect_num": res['detect_num'],
                "waning_value": res['waning_value'],
                "image_path": res['image_path'],
                "image_desc_path": res['image_desc_path'],
                "video_path": res['video_path'],
                "text_vector": res['text_vector'],
            }
            # Persist to Milvus.
            image_id = self.collection.upsert(row).primary_keys

            if is_desc == 2:
                # NOTE(review): "video_path" is filled from video_point_name
                # while the real video path goes to "rtsp_address" -- kept
                # as in the original; confirm this mapping is intended.
                rag_row = {
                    "id": str(image_id[0]),
                    "video_point_id": res['video_point_id'],
                    "video_path": res["video_point_name"],
                    "zh_desc_class": image_des,
                    "detect_time": res['detect_time'],
                    "image_path": f"{res['image_path']}",
                    "task_name": res["task_name"],
                    "event_level_name": res["event_level_name"],
                    "rtsp_address": f"{res['video_path']}",
                }
                # Forward to the RAG service (best effort).
                asyncio.run(self.insert_json_data(ragurl, rag_row))
            return image_id
        except Exception as e:
            self.logger.exception(f"线程:执行模型解析时出错:任务:{e}")
            return 0

    def image_desc(self, image_path):
        """Generate a text description of the image at ``image_path``.

        Args:
            image_path: Path of the image file to describe.

        Returns:
            The stripped generated text, or ``None`` if any step failed.
        """
        lock = None
        try:
            model, lock = self._acquire_model()

            # Load inside a context manager so the file handle is closed
            # right after the pixel data has been converted.
            with Image.open(image_path) as source:
                image = source.convert("RGB")
            image = image.resize((600, 600), Image.Resampling.LANCZOS)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": "请详细描述图片中的目标信息及特征。返回格式为整段文字描述"},
                    ],
                }
            ]
            # Build the prompt and tensors for inference.
            text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
            inputs = self.processor(
                text=[text],
                images=[image],
                padding=True,
                return_tensors="pt",
            )
            # Follow the model's device instead of repeating a literal.
            inputs = inputs.to(model.device)

            started = datetime.now()
            outputs = model.generate(
                **inputs,
                max_new_tokens=300,
                do_sample=True,
                temperature=0.7,
                renormalize_logits=True,
            )
            self.logger.info(f"处理完毕:{datetime.now() - started}")

            # Drop the prompt tokens; keep only the newly generated ones.
            generated_ids = outputs[:, inputs.input_ids.shape[1]:]
            image_text = self.processor.batch_decode(
                generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            return image_text[0].strip()
        except Exception as e:
            self.logger.exception(f"线程:执行图片描述时出错:{e}")
            return None
        finally:
            # Hand the model back only if one was actually acquired.
            if lock is not None:
                lock.release()
            torch.cuda.empty_cache()

    def get_rule(self, ragurl):
        """Fetch rule texts from the RAG ``/search`` endpoint as a numbered list.

        Args:
            ragurl: Base URL of the RAG service.

        Returns:
            A string with one ``"<n>. <text>;\\n"`` entry per result whose
            score is non-negative, ``""`` when the result list is empty, or
            ``None`` when the response has no results or the request failed.
        """
        try:
            search_data = {
                "collection_name": "smart_rule",
                "query_text": "",
                "search_mode": "hybrid",
                "limit": 100,
                "weight_dense": 0.7,
                "weight_sparse": 0.3,
                "filter_expr": "",
                "output_fields": ["text"],
            }
            response = requests.post(ragurl + "/search", json=search_data)
            results = response.json().get('results')
            if not results:
                # The original used for/else, which logged this message after
                # every successful loop; log only when nothing came back.
                self.logger.info(f"线程:执行获取规则时出错:{response}")
                return None if results is None else ""

            entries = []
            for rule in results:
                if rule['score'] >= 0:
                    entries.append(str(len(entries) + 1) + ". " + rule['entity'].get('text') + ";\n")
            return "".join(entries)
        except Exception as e:
            self.logger.exception(f"线程:执行获取规则时出错:{e}")
            return None

    def image_rule_chat(self, image_des, rule_text, ragurl, rag_mode, max_tokens):
        """Ask the chat endpoint whether the description matches any rule.

        Args:
            image_des: Image description text.
            rule_text: Rule text embedded in the prompt.
            ragurl: Base URL of the RAG service.
            rag_mode: Value sent as ``llm_name``.
            max_tokens: Value sent as ``gen_conf.max_tokens``.

        Returns:
            ``1`` if the reply, after removing ``<think>`` sections and
            whitespace, is longer than two characters (i.e. more than an
            empty ``[]``), ``0`` otherwise, or ``None`` if the call failed.
        """
        try:
            content = (
                f"图片描述内容为:\n{image_des}\n规则内容:\n{rule_text}。\n请验证图片描述中是否有符合规则的内容,不进行推理和think。返回结果格式为[xxx符合的规则id],如果没有返回[]")
            search_data = {
                "prompt": "",
                "messages": [
                    {
                        "role": "user",
                        "content": content,
                    }
                ],
                "llm_name": rag_mode,
                "stream": False,
                "gen_conf": {
                    "temperature": 0.7,
                    "max_tokens": max_tokens,
                },
            }
            response = requests.post(ragurl + "/chat", json=search_data)
            results = response.json().get('data')
            # Strip any reasoning section, then spaces, tabs and newlines.
            ret = re.sub(r'<think>.*?</think>', '', results, flags=re.DOTALL)
            ret = re.sub(r'[ \t\n]', '', ret)
            return 1 if len(ret) > 2 else 0
        except Exception as e:
            self.logger.exception(f"线程:执行规则匹配时出错:{image_des, rule_text, ragurl, rag_mode,e}")
            return None

    async def insert_json_data(self, ragurl, data):
        """Best-effort POST of one record to the RAG ``/insert_json_data`` endpoint.

        The request uses 0.3 s connect/read timeouts and failures are logged
        but never raised. The call itself is a blocking ``requests.post``.

        Args:
            ragurl: Base URL of the RAG service.
            data: Record dict wrapped into the ``smartrag`` payload.
        """
        try:
            payload = {'collection_name': "smartrag", "data": data, "description": ""}
            requests.post(ragurl + "/insert_json_data", json=payload, timeout=(0.3, 0.3))
        except Exception as e:
            # Deliberately non-fatal, but leave a trace instead of vanishing.
            self.logger.warning(f"线程:写入RAG数据时出错:地址:{ragurl}:{e}")
            return

    def _release_semaphore(self, future):
        """Done-callback: return the task's semaphore permit and log free slots."""
        self.semaphore.release()
        self.logger.info(f"释放线程 (剩余空闲: {self.semaphore._value}/{self.max_workers})")

    def shutdown(self):
        """Stop accepting work and drop the model pool.

        Does not wait for running tasks. A task still running keeps its own
        reference to its model until it finishes; tasks that start afterwards
        fail in ``_acquire_model`` because the pool is empty.
        """
        self.executor.shutdown(wait=False)
        # `del` on a loop variable only unbinds that name; clear the list so
        # the pool itself stops referencing the models.
        self.model_pool.clear()
        torch.cuda.empty_cache()

    def _acquire_model(self):
        """Return ``(model, lock)`` for the first model whose lock is free.

        Polls every 0.1 s until a lock can be taken.

        Raises:
            RuntimeError: If the model pool is empty, which would otherwise
                make this loop spin forever.
        """
        while True:
            if not self.model_pool:
                raise RuntimeError("model pool is empty; was shutdown() already called?")
            for model, lock in zip(self.model_pool, self.lock_pool):
                if lock.acquire(blocking=False):
                    return model, lock
            time.sleep(0.1)  # avoid a busy spin

    def _release_model(self, model):
        """Release the lock paired with ``model`` (matched by identity)."""
        for pooled, lock in zip(self.model_pool, self.lock_pool):
            if pooled is model:
                lock.release()
                break

    @staticmethod
    def _dedupe_segments(text, separator):
        """Split on ``separator``, drop blank and repeated pieces, join with '。'.

        Blank means empty after ``strip()``; duplicates are compared on the
        unstripped piece, and the first occurrence is kept.
        """
        seen = set()
        kept = []
        for piece in text.split(separator):
            if piece.strip() and piece not in seen:
                seen.add(piece)
                kept.append(piece)
        return '。'.join(kept)

    def remove_duplicate_lines(self, text):
        """Remove blank and repeated sentences, splitting and joining on '。'."""
        return self._dedupe_segments(text, '。')

    def remove_duplicate_lines_d(self, text):
        """Remove blank and repeated pieces split on ','.

        NOTE(review): the pieces are re-joined with '。', not ','; kept as in
        the original -- confirm that is intended.
        """
        return self._dedupe_segments(text, ',')

    def remove_duplicate_lines_n(self, text):
        """Remove blank and repeated lines split on newlines.

        NOTE(review): the lines are re-joined with '。', not a newline; kept
        as in the original -- confirm that is intended.
        """
        return self._dedupe_segments(text, '\n')
|