import asyncio
import logging
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime

import requests
import torch
from PIL import Image
from pymilvus import Collection, connections
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForVision2Seq, AutoProcessor

class qwen_thread:
    """Thread-pool worker around a pool of Qwen vision-language models.

    Each submitted detection record is checked against a configured warning
    condition by a Qwen VL model; the verdict is written back to the
    ``smartobject`` Milvus collection (delete + re-insert, since Milvus has
    no in-place update).
    """

    def __init__(self, config, logger):
        """Create the worker pool, connect to Milvus and load the models.

        :param config: mapping-like config. Keys read here: ``threadnum``
            (worker/semaphore count), ``milvusurl``/``milvusport``,
            ``cuda`` (device index; ``None`` or ``'0'`` -> default device),
            ``qwenwarning`` (number of model replicas — presumably a config
            naming quirk; confirm with the config owner) and ``qwenaddr``
            (model path or hub id).
        :param logger: pre-configured ``logging.Logger``.
        """
        self.config = config
        worker_count = int(config.get("threadnum"))
        self.max_workers = worker_count
        self.executor = ThreadPoolExecutor(max_workers=worker_count)
        # The semaphore caps in-flight tasks at the pool size so submit()
        # blocks instead of queueing unboundedly.
        self.semaphore = threading.Semaphore(worker_count)
        self.logger = logger

        # Connect to Milvus and load the target collection into memory.
        connections.connect("default", host=config.get("milvusurl"), port=config.get("milvusport"))
        self.collection = Collection(name="smartobject")
        self.collection.load()

        # Device selection: unset or '0' means the default CUDA device.
        if config.get('cuda') is None or config.get('cuda') == '0':
            self.device = "cuda"
        else:
            self.device = f"cuda:{config.get('cuda')}"

        # One model replica per slot, each guarded by its own lock so a
        # replica is never used by two threads at once.
        pool_size = int(config.get("qwenwarning"))
        self.model_pool = []
        self.lock_pool = [threading.Lock() for _ in range(pool_size)]
        for _ in range(pool_size):
            model = AutoModelForVision2Seq.from_pretrained(
                config.get("qwenaddr"),
                device_map=self.device,
                trust_remote_code=True,
                use_safetensors=True,
                torch_dtype=torch.float16,
            ).eval()
            # device_map already placed the model on self.device; the
            # original's extra .to(self.device) was a no-op and is dropped.
            self.model_pool.append(model)

        # Shared processor (tokenizer + image processor); used read-only by
        # the worker threads.
        self.processor = AutoProcessor.from_pretrained(config.get("qwenaddr"), use_fast=True)

    def submit(self, res_a):
        """Schedule one detection record; blocks while all workers are busy.

        :param res_a: detection record (dict) forwarded to tark_do().
        :return: the Future of the scheduled task.
        """
        # A single blocking acquire replaces the original's redundant
        # try-nonblocking-then-block dance (its only purpose was a log line
        # that had already been commented out). Released in _release_semaphore.
        self.semaphore.acquire()
        future = self.executor.submit(self._wrap_task, res_a)
        future.add_done_callback(self._release_semaphore)
        return future

    def _wrap_task(self, res_a):
        """Run tark_do with the configured RAG parameters; log and re-raise on failure."""
        try:
            self.tark_do(res_a, self.config.get("ragurl"), self.config.get("ragmode"), self.config.get("max_tokens"))
        except Exception as e:
            self.logger.info(f"处理出错: {e}")
            raise

    def tark_do(self, res, ragurl, rag_mode, max_tokens):
        """Evaluate the warning rule for one record and rewrite it in Milvus.

        The record is deleted and re-inserted with the rule verdict and
        ``is_desc=5`` ("already warned"); ``res['id']`` is updated to the new
        primary key.

        :param res: detection record dict; its fields are mirrored into the
            Milvus row.
        :param ragurl: unused here; kept for interface compatibility.
        :param rag_mode: unused here; kept for interface compatibility.
        :param max_tokens: unused here; kept for interface compatibility.
        :return: None on success, 0 on failure (best-effort contract —
            callers do not rely on exceptions).
        """
        try:
            ks_time = datetime.now()
            risk_description = ""
            suggestion = ""
            # Rule check via the VL model: 1 = warning triggered, 0 = not.
            is_waning = self.image_rule(res)
            self.logger.info(f"预警规则规则规则is_waning:{is_waning}")
            # NOTE(review): key spellings "is_waning"/"waning_value" follow
            # the existing collection schema and must not be "corrected".
            data = {
                "event_level_id": res['event_level_id'],
                "event_level_name": res['event_level_name'],
                "rule_id": res["rule_id"],
                "video_point_id": res['video_point_id'],
                "video_point_name": res['video_point_name'],
                "is_waning": is_waning,
                "is_desc": 5,  # 5 = already warned
                "zh_desc_class": res['zh_desc_class'],
                "bounding_box": res['bounding_box'],
                "task_id": res['task_id'],
                "task_name": res['task_name'],
                "detect_id": res['detect_id'],
                "detect_time": res['detect_time'],
                "detect_num": res['detect_num'],
                "waning_value": res['waning_value'],
                "image_path": res['image_path'],
                "image_desc_path": res['image_desc_path'],
                "video_path": res['video_path'],
                "text_vector": res['text_vector'],
                "risk_description": risk_description,
                "suggestion": suggestion,
                "knowledge_id": res['knowledge_id'],
            }
            # Milvus cannot update in place: delete the old row, insert the new.
            self.collection.delete(f"id == {res['id']}")
            image_id = self.collection.insert(data).primary_keys
            res['id'] = image_id[0]
            self.logger.info(f"{res['video_point_id']}预警执行完毕:{image_id}运行结束总体用时:{datetime.now() - ks_time}")
            return None
        except Exception as e:
            self.logger.info(f"线程:执行模型解析时出错::{e}")
            return 0

    def image_rule(self, res_data):
        """Ask a pooled Qwen VL model whether the warning condition is visible.

        :param res_data: record providing 'image_desc_path' (image file) and
            'waning_value' (condition text inserted into the prompt).
        :return: 1 if the model answers yes, 0 on "no" or on any error.
        """
        self.logger.info(f"预警规则规则规则等级分类就是打裂缝多少积分")
        model = None
        try:
            model, lock = self._acquire_model()
            # Context manager closes the file handle promptly instead of
            # leaking it until GC (the original never closed it).
            with Image.open(res_data['image_desc_path']) as img:
                image = img.convert("RGB").resize((600, 600), Image.Resampling.LANCZOS)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": f"请检测图片中是否{res_data['waning_value']}?请回答yes或no"},
                    ],
                }
            ]

            # Build the chat prompt and the multimodal tensors.
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(model.device)

            with torch.no_grad():
                outputs = model.generate(**inputs, max_new_tokens=10)
            # Strip the prompt tokens; keep only the newly generated answer.
            generated_ids = outputs[:, len(inputs.input_ids[0]):]
            image_text = self.processor.batch_decode(
                generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )

            upper_text = image_text[0].strip().upper()
            self.logger.info(f"预警规则规则规则:{upper_text}")
            return 1 if "YES" in upper_text else 0
        except Exception as e:
            self.logger.info(f"线程:执行图片描述时出错:{e}")
            return 0
        finally:
            # Only release when acquisition succeeded; the original referenced
            # `model` unconditionally and would raise UnboundLocalError here
            # if _acquire_model()/Image.open failed, masking the real error.
            if model is not None:
                self._release_model(model)
            torch.cuda.empty_cache()

    def _release_semaphore(self, future):
        """Done-callback: free the worker slot taken in submit()."""
        self.semaphore.release()

    def shutdown(self):
        """Shut down the executor and drop all model references."""
        self.executor.shutdown(wait=False)
        # clear() actually drops the pool's references; the original's
        # `del model` in a loop only unbound the loop variable each pass
        # and freed nothing.
        self.model_pool.clear()
        torch.cuda.empty_cache()

    def _acquire_model(self):
        """Poll the pool until some replica's lock can be taken.

        :return: ``(model, lock)`` with the lock already held by the caller.
        """
        while True:
            for model, lock in zip(self.model_pool, self.lock_pool):
                if lock.acquire(blocking=False):
                    return model, lock
            time.sleep(0.1)  # avoid busy-spinning the CPU

    def _release_model(self, model):
        """Release the lock guarding *model* back to the pool."""
        for i, m in enumerate(self.model_pool):
            # Identity comparison: we want this exact replica, not a module
            # that happens to compare equal (the original used `==`).
            if m is model:
                self.lock_pool[i].release()
                break