#2025/7/3
#新增qwen_detect_batch.py 用于生成批量图片的启动程序,目前是测试版本
#新增qwen_thread_batch.py 用于生成批量图片的多线程处理,目前是测试版本
| New file |
| | |
| | | from operator import itemgetter |
| | | import threading |
| | | import time as time_sel |
| | | from typing import Dict |
| | | from qwen_thread_batch import qwen_thread_batch |
| | | import requests |
| | | import os |
| | | import logging |
| | | from pymilvus import connections, Collection |
| | | from logging.handlers import RotatingFileHandler |
| | | import get_mem |
| | | |
class ThreadPool:
    """Start and supervise one polling worker thread per camera.

    Each worker repeatedly queries the Milvus ``smartobject`` collection for
    records of its camera with ``is_desc == 0``, re-upserts the newest ones with
    ``is_desc = 1`` (queued) and submits them as one batch to a
    ``qwen_thread_batch`` pool.
    """

    # Fields fetched from Milvus; every one of them is written back on re-upsert.
    _OUTPUT_FIELDS = (
        "id", "zh_desc_class", "text_vector", "bounding_box", "video_point_name", "task_id",
        "task_name", "event_level_id", "event_level_name",
        "video_point_id", "detect_num", "is_waning", "is_desc", "waning_value", "rule_id",
        "detect_id",
        "detect_time", "image_path", "image_desc_path", "video_path",
    )
    # Timeout (seconds) for the task/config HTTP endpoints, so a hung server cannot
    # freeze the supervisor loop indefinitely.
    _HTTP_TIMEOUT = 10

    def __init__(self):
        """Load ./conf.txt, configure logging, connect to Milvus and build the model pool.

        Config keys read here: milvusurl, milvusport, threadnum.
        """
        # Read the key=value configuration file.
        self.config = self._load_config('./conf.txt')

        # Configure logging first, so that anything logged while connecting to Milvus
        # or loading the models is already captured by the file/console handlers.
        log_dir = "logs"
        os.makedirs(log_dir, exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(filename)s:%(lineno)d - %(funcName)s() - %(levelname)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[
                # Size-rotated log file (max 10 MB, 3 backups).
                RotatingFileHandler(
                    filename=os.path.join(log_dir, 'start_log.log'),
                    maxBytes=10 * 1024 * 1024,  # 10MB
                    backupCount=3,
                    encoding='utf-8'
                ),
                # Also echo to the console.
                logging.StreamHandler()
            ]
        )

        self.threads: Dict[str, threading.Thread] = {}
        # One stop flag per camera thread; shutdown_all() sets them so workers really exit.
        self._stop_events: Dict[str, threading.Event] = {}
        self.lock = threading.Lock()

        # Connect to Milvus and load the collection.
        connections.connect("default", host=self.config.get("milvusurl"), port=self.config.get("milvusport"))
        self.collection = Collection(name="smartobject")
        self.collection.load()
        self.pool = qwen_thread_batch(int(self.config.get("threadnum")), self.config,"/home/debian/Qwen2.5-VL-3B-Instruct-GPTQ-Int4")
        # Whether the task configuration changed in the current supervisor iteration.
        self._isupdate = False

        # Initialise shared memory.
        get_mem.smem_init()

    @staticmethod
    def _load_config(path):
        """Parse a ``key=value`` text file into a dict of stripped strings.

        Blank lines and lines without '=' are ignored. Only the first '=' splits
        a line, so values may themselves contain '='.
        """
        config = {}
        with open(path, 'r', encoding='utf-8') as file:
            for line in file:
                line = line.strip()
                if not line or '=' not in line:
                    continue
                key, value = line.split('=', 1)
                config[key.strip()] = value.strip()
        return config

    def safe_start(self, target_func, camera_id):
        """Run ``target_func(camera_id)`` in a tracked daemon thread.

        A fresh stop event is registered for *camera_id* before the thread starts,
        so the target can read it immediately. Exceptions escaping *target_func*
        are logged, not propagated.

        Returns:
            The started ``threading.Thread``.
        """
        def wrapped():
            try:
                target_func(camera_id)
            except Exception as e:
                logging.error(f"线程异常: {str(e)}", exc_info=True)

        with self.lock:  # create/register threads atomically
            self._stop_events[camera_id] = threading.Event()
            t = threading.Thread(
                target=wrapped,
                daemon=True  # do not block interpreter exit
            )
            t.start()
            self.threads[camera_id] = t
        return t

    def worker(self, camera_id):
        """Poll Milvus for undescribed records of *camera_id* and submit them in batches.

        Runs until the camera's stop event (registered by ``safe_start``) is set.
        Each iteration queries ``is_desc == 0`` records, keeps the ``detectnum``
        newest by id, re-upserts each with ``is_waning = 0`` / ``is_desc = 1``
        and hands the batch to ``self.pool.submit``.

        Config keys: ``detectnum`` (batch size) and ``idleinterval`` (seconds to
        wait when nothing is pending or after an error; default 0.5).
        """
        # Capture this thread's own event: a later safe_start() for the same camera
        # replaces the dict entry but must not revive a worker that was told to stop.
        stop_event = self._stop_events.setdefault(camera_id, threading.Event())
        idle_interval = float(self.config.get("idleinterval", 0.5))
        while not stop_event.is_set():
            try:
                res_a = self.collection.query(
                    expr=f"is_desc == 0 and video_point_id=={camera_id}",
                    output_fields=list(self._OUTPUT_FIELDS),
                    consistency_level="Strong",
                    # NOTE(review): results are re-sorted below anyway; confirm whether
                    # pymilvus honours these two kwargs at all.
                    order_by_field="id",
                    order_by_type="desc"
                )
                if not res_a:
                    # Nothing pending: back off instead of hammering Milvus in a tight loop.
                    stop_event.wait(idle_interval)
                    continue

                # Keep the N newest records (N = detectnum); the old `[:N - 1]` kept N-1.
                batch_size = int(self.config.get("detectnum"))
                newest = sorted(res_a, key=itemgetter("id"), reverse=True)[:batch_size]
                res_data = []
                for res in newest:
                    data = {field: res[field] for field in self._OUTPUT_FIELDS}
                    data["is_waning"] = 0
                    data["is_desc"] = 1  # mark as queued so the next query skips it
                    image_id = self.collection.upsert(data).primary_keys
                    res['id'] = image_id[0]
                    res_data.append(res)
                if res_data:
                    self.pool.submit(res_data)
            except Exception as e:
                logging.error(f"{camera_id}线程错误:{e}")
                # Avoid spinning (and flooding the log) on a persistent failure.
                stop_event.wait(idle_interval)

    def isUpdate(self):
        """Return True when the ``isupdateurl`` endpoint reports ``data.isChange == 1``.

        Any error or non-200 response yields False.
        """
        try:
            url = self.config.get("isupdateurl")
            response = requests.get(url, timeout=self._HTTP_TIMEOUT)
            if response.status_code != 200:
                return False
            data = response.json().get("data")
            return data.get("isChange") == 1
        except Exception as e:
            logging.info(f"调用是否需要更新时出错:URL:{self.config.get('isupdateurl')}:{e}")
            return False

    def update_status(self):
        """POST to ``updatestatusurl`` to acknowledge a task update.

        Returns True on HTTP 200, otherwise (including errors) False.
        """
        try:
            url = self.config.get("updatestatusurl")
            response = requests.post(url, timeout=self._HTTP_TIMEOUT)
            return response.status_code == 200
        except Exception as e:
            logging.info(f"修改是否更新状态时出错:URL:{self.config.get('updatestatusurl')}:{e}")
            return False

    def shutdown_all(self) -> None:
        """Signal every worker to stop, wait up to 1 s for each, and forget them all.

        A worker still busy in a query or a blocking ``pool.submit`` may outlive the
        1 s join, but it exits at the end of its current iteration because its stop
        event is set.
        """
        with self.lock:
            for event in self._stop_events.values():
                event.set()
            for camera_id, thread in list(self.threads.items()):
                if thread.is_alive():
                    thread.join(timeout=1)
                del self.threads[camera_id]
            self._stop_events.clear()

    def getTaskconf(self,isupdate):
        """Fetch the camera task list from ``gettaskconfurl``.

        Args:
            isupdate: when True and the fetch succeeded, acknowledge the update
                via ``update_status()``.

        Returns:
            The response's ``data`` field on HTTP 200, otherwise ``[]``.
        """
        try:
            url = self.config.get("gettaskconfurl")
            response = requests.get(url, timeout=self._HTTP_TIMEOUT)
            if response.status_code != 200:
                return []
            data = response.json()
            if isupdate:
                self.update_status()
            return data.get("data")
        except Exception as e:
            logging.info(f"调用获取任务时出错:URL:{self.config.get('gettaskconfurl')}:{e}")
            return []
| | | |
| | | # 使用示例 |
if __name__ == "__main__":
    pool = ThreadPool()
    camera_data = pool.getTaskconf(False)

    def start_missing_threads(cameras):
        """Start a worker for every camera in *cameras* that has no tracked thread yet."""
        for camera in cameras or []:
            camera_id = camera.get("camera_id")
            if camera_id is None:
                logging.warning(f"skip camera entry without camera_id: {camera}")
                continue
            if pool.threads.get(camera_id):
                continue
            logging.info(f"开始创建{camera_id}线程")
            pool.safe_start(pool.worker, camera_id)
            logging.info(f"{camera_id}线程创建完毕")

    is_init = True
    while True:
        try:
            pool._isupdate = False
            # Re-fetch the camera task list only when the server reports a change.
            if pool.isUpdate():
                camera_data = pool.getTaskconf(True)
                pool._isupdate = True

            if pool._isupdate:
                # Restart all workers against the new task list. On the first iteration
                # this also covers the initial start, so no thread is started and then
                # immediately stopped.
                logging.info(f"更新线程开始")
                pool.shutdown_all()
                start_missing_threads(camera_data)
                logging.info(f"更新线程结束")
            elif is_init:
                start_missing_threads(camera_data)

            is_init = False
        except Exception as e:
            logging.error(f"主线程未知错误:{e}", exc_info=True)
        # Sleep on both paths so a persistent error cannot turn this into a busy loop.
        time_sel.sleep(1)
| New file |
| | |
| | | import time |
| | | from concurrent.futures import ThreadPoolExecutor |
| | | import threading |
| | | |
| | | import torch |
| | | from PIL import Image |
| | | from pymilvus import connections, Collection |
| | | from datetime import datetime |
| | | import os |
| | | import requests |
| | | import asyncio |
| | | import logging |
| | | import re |
| | | from logging.handlers import RotatingFileHandler |
| | | |
| | | from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig |
| | | |
| | | |
class qwen_thread_batch:
    """Thread pool that captions image batches with a Qwen2.5-VL model and applies rules.

    ``submit`` is throttled by a semaphore, so at most ``max_workers`` batches are
    in flight. Each batch borrows one model replica from ``model_pool``. Results
    are written back to the Milvus ``smartobject`` collection. Successfully
    described images are also pushed to the RAG service.
    """

    # Fields copied unchanged from the incoming record when writing the result back;
    # is_waning, is_desc and zh_desc_class are filled in by tark_do.
    _RECORD_FIELDS = (
        "id", "event_level_id", "event_level_name", "rule_id", "video_point_id",
        "video_point_name", "bounding_box", "task_id", "task_name", "detect_id",
        "detect_time", "detect_num", "waning_value", "image_path", "image_desc_path",
        "video_path", "text_vector",
    )
    # Strips <think>...</think> reasoning blocks from the LLM reply.
    _THINK_RE = re.compile(r'<think>.*?</think>', re.DOTALL)

    def __init__(self, max_workers,config,model_path, device="cuda:1"):
        """Set up logging, connect to Milvus and load ``max_workers`` model replicas.

        Args:
            max_workers: number of worker threads and of model replicas.
            config: settings dict; milvusurl/milvusport are read here, and
                ragurl/ragmode/max_tokens are read when tasks run.
            model_path: local path of the model checkpoint.
            device: torch device for the replicas and their inputs (defaults to the
                previously hard-coded ``"cuda:1"``).
        """
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.semaphore = threading.Semaphore(max_workers)
        self.max_workers = max_workers
        self.config = config
        self.device = device

        # Per-instance logger; the "logs" directory is created here so the class
        # also works when it is not constructed by ThreadPool.
        os.makedirs("logs", exist_ok=True)
        self.logger = logging.getLogger(f"{self.__class__}_{id(self)}")
        self.logger.setLevel(logging.INFO)
        if not self.logger.handlers:  # avoid duplicate handlers
            handler = RotatingFileHandler(
                filename=os.path.join("logs", 'thread_log.log'),
                maxBytes=10 * 1024 * 1024,
                backupCount=3,
                encoding='utf-8'
            )
            formatter = logging.Formatter(
                '%(asctime)s - %(filename)s:%(lineno)d - %(funcName)s() - %(levelname)s: %(message)s'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        # Connect to Milvus and load the collection.
        connections.connect("default", host=config.get("milvusurl"), port=config.get("milvusport"))
        self.collection = Collection(name="smartobject")
        self.collection.load()

        # One lock per replica; a replica is in use while its lock is held.
        self.lock_pool = [threading.Lock() for _ in range(max_workers)]
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True
        )
        # NOTE(review): the caller passes a GPTQ-Int4 checkpoint; depending on the
        # transformers version a BitsAndBytes quantization_config may be ignored or
        # rejected for a pre-quantized model — confirm which quantization is applied.
        self.model_pool = [
            AutoModelForVision2Seq.from_pretrained(
                model_path,
                device_map=device,
                trust_remote_code=True,
                quantization_config=quant_config,
                # Replaces the deprecated `use_flash_attention_2=True` kwarg.
                attn_implementation="flash_attention_2",
            ).eval()
            for _ in range(max_workers)
        ]

        # Processor shared by all replicas.
        self.processor = AutoProcessor.from_pretrained(model_path,use_fast=True)

    def submit(self,res_a):
        """Queue one batch of records, blocking while ``max_workers`` batches are in flight.

        Returns:
            The ``Future`` of the queued task. Its slot is released when it finishes.
        """
        self.semaphore.acquire()
        try:
            future = self.executor.submit(self._wrap_task, res_a)
        except BaseException:
            # Executor rejected the task (e.g. after shutdown): give the slot back.
            self.semaphore.release()
            raise
        future.add_done_callback(self._release_semaphore)
        return future

    def _wrap_task(self, res_a):
        """Executor entry point: run ``tark_do`` with the RAG settings from config."""
        try:
            self.tark_do(res_a, self.config.get("ragurl"), self.config.get("ragmode"), self.config.get("max_tokens"))
        except Exception as e:
            self.logger.error(f"处理出错: {e}", exc_info=True)
            raise

    def tark_do(self,res_a,ragurl,rag_mode,max_tokens):
        """Describe a batch of images, match rules and write each result back.

        For every record, ``is_desc`` becomes 2 when a description was produced
        and 3 when it was not. If the whole batch fails to describe, every record
        is marked 3, so none is left queued (``is_desc == 1``) forever. A failure
        on one record is logged and does not abort the rest of the batch.

        Returns:
            0
        """
        try:
            current_time = datetime.now()
            desc_list = self.image_desc(res_a)
            if not desc_list:
                if res_a:
                    self.logger.info(f"线程:图片描述失败,本批记录标记为描述失败:{len(res_a)}")
                desc_list = [""] * len(res_a)
            for desc, res in zip(desc_list, res_a):
                try:
                    self._save_result(res, desc, ragurl, rag_mode, max_tokens)
                except Exception as e:
                    self.logger.info(f"线程:执行模型解析时出错:任务:{e}")
            self.logger.info(f"处理完毕:{datetime.now() - current_time}:{len(res_a)}")
        except Exception as e:
            self.logger.info(f"线程:执行模型解析时出错:任务:{e}")
        return 0

    def _save_result(self, res, desc, ragurl, rag_mode, max_tokens):
        """Rule-match one description, upsert the record, and push it to RAG if described."""
        if desc:
            is_waning = self.image_rule_chat(desc, res['waning_value'], ragurl, rag_mode, max_tokens)
            # image_rule_chat returns None when the rule service fails; store "no alarm"
            # rather than writing None into the Milvus field.
            if is_waning is None:
                is_waning = 0
            is_desc = 2
        else:
            is_waning = 0
            is_desc = 3
        data = {field: res[field] for field in self._RECORD_FIELDS}
        data["is_waning"] = is_waning
        data["is_desc"] = is_desc
        data["zh_desc_class"] = desc
        image_id = self.collection.upsert(data).primary_keys
        if is_desc == 2:
            rag_data = {
                "id": str(image_id[0]),
                "video_point_id": res['video_point_id'],
                # NOTE(review): "video_path" is filled from video_point_name while the
                # actual video path goes to "rtsp_address" — confirm this is intended.
                "video_path": res["video_point_name"],
                "zh_desc_class": desc,
                "detect_time": res['detect_time'],
                "image_path": f"{res['image_path']}",
                "task_name": res["task_name"],
                "event_level_name": res["event_level_name"],
                "rtsp_address": f"{res['video_path']}"
            }
            asyncio.run(self.insert_json_data(ragurl, rag_data))

    @staticmethod
    def _load_image(path, size=(448, 448)):
        """Open *path*, convert to RGB and resize; the file handle is closed on return."""
        with Image.open(path) as img:
            return img.convert("RGB").resize(size, Image.Resampling.LANCZOS)

    def image_desc(self, res_data):
        """Caption each record's ``image_desc_path`` image in one batched generate call.

        Returns:
            A list of description strings in input order, ``[]`` for an empty
            batch, or None if anything fails (the error is logged).
        """
        if not res_data:
            return []
        # Acquired outside the try so the finally clause always has a model to release.
        model, _lock = self._acquire_model()
        try:
            image_data = [self._load_image(str(res['image_desc_path'])) for res in res_data]
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": "请详细描述图片中的目标信息及特征。返回格式为整段文字描述"},
                    ],
                }
            ]
            text = self.processor.apply_chat_template(
                messages, add_generation_prompt=True
            )
            # Same prompt once per image.
            inputs = self.processor(
                text=[text] * len(image_data),
                images=[image_data],
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.device)
            with torch.inference_mode():
                # Greedy decoding (do_sample=False); temperature/top_k are kept as in
                # the original call but do not affect greedy output.
                outputs = model.generate(**inputs,
                                         max_new_tokens=50,
                                         do_sample=False,
                                         temperature=0.7,
                                         top_k=40,
                                         num_beams=1,
                                         repetition_penalty=1.1
                                         )
            # Drop the prompt tokens; all rows are padded to the same prompt length.
            generated_ids = outputs[:, inputs.input_ids.shape[1]:]
            return list(self.processor.batch_decode(
                generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
            ))
        except Exception as e:
            self.logger.info(f"线程:执行图片描述时出错:{e}")
            return None
        finally:
            self._release_model(model)
            torch.cuda.empty_cache()

    def get_rule(self,ragurl):
        """Fetch rules from the RAG ``/search`` endpoint as a numbered text list.

        Each rule with ``score >= 0`` becomes a line ``"<n>. <text>;\\n"``.

        Returns:
            The rule text (possibly empty), or None on error.
        """
        try:
            search_data = {
                "collection_name": "smart_rule",
                "query_text": "",
                "search_mode": "hybrid",
                "limit": 100,
                "weight_dense": 0.7,
                "weight_sparse": 0.3,
                "filter_expr": "",
                "output_fields": ["text"]
            }
            response = requests.post(ragurl + "/search", json=search_data)
            results = response.json().get('results') or []
            rule_text = ""
            ruleid = 1
            for rule in results:
                if rule['score'] >= 0:
                    rule_text = rule_text + str(ruleid) + ". " + rule['entity'].get('text') + ";\n"
                    ruleid = ruleid + 1
                else:
                    self.logger.info(f"线程:执行获取规则时出错:{response}")
            return rule_text
        except Exception as e:
            self.logger.info(f"线程:执行获取规则时出错:{e}")
            return None

    def image_rule_chat(self, image_des,rule_text, ragurl, rag_mode,max_tokens):
        """Ask the RAG ``/chat`` LLM whether *image_des* matches *rule_text*.

        The model is asked to reply ``[ids]`` or ``[]``. After stripping
        ``<think>`` blocks, spaces, tabs and newlines, any reply longer than two
        characters counts as a match.

        Returns:
            1 on a match, 0 otherwise, or None if the request or parsing fails.
        """
        try:
            content = (
                f"图片描述内容为:\n{image_des}\n规则内容:\n{rule_text}。\n请验证图片描述中是否有符合规则的内容,不进行推理和think。返回结果格式为[xxx符合的规则id],如果没有返回[]")
            search_data = {
                "prompt": "",
                "messages": [
                    {
                        "role": "user",
                        "content": content
                    }
                ],
                "llm_name": rag_mode,
                "stream": False,
                "gen_conf": {
                    "temperature": 0.7,
                    "max_tokens": max_tokens
                }
            }
            response = requests.post(ragurl + "/chat", json=search_data)
            results = response.json().get('data')
            ret = self._THINK_RE.sub('', results)
            ret = ret.replace(" ", "").replace("\t", "").replace("\n", "")
            return 1 if len(ret) > 2 else 0
        except Exception as e:
            self.logger.info(f"线程:执行规则匹配时出错:{image_des, rule_text, ragurl, rag_mode,e}")
            return None

    async def insert_json_data(self, ragurl, data):
        """Best-effort push of one record to the RAG ``/insert_json_data`` endpoint.

        Writes into collection ``smartrag``. The 0.3 s connect/read timeout makes
        this effectively fire-and-forget. Failures, such as timeouts, are logged
        at DEBUG level and swallowed. The HTTP call itself is blocking despite the
        ``async`` signature.
        """
        payload = {'collection_name': "smartrag", "data": data, "description": ""}
        try:
            requests.post(ragurl + "/insert_json_data", json=payload, timeout=(0.3, 0.3))
        except Exception as e:
            self.logger.debug(f"insert_json_data failed: {ragurl}: {e}")

    def _release_semaphore(self, future):
        """Done-callback: return the submit() slot held by *future*."""
        self.semaphore.release()

    def shutdown(self):
        """Stop accepting work, cancel queued batches and drop the model replicas."""
        # cancel_futures (Python 3.9+) drops batches that have not started. Their
        # done-callbacks still fire, so semaphore slots are returned.
        self.executor.shutdown(wait=False, cancel_futures=True)
        # `for model in pool: del model` only unbound the loop variable; clearing the
        # list actually drops the pool's references to the replicas.
        self.model_pool.clear()
        torch.cuda.empty_cache()

    def _acquire_model(self):
        """Return ``(model, lock)`` for the first free replica, polling every 0.1 s.

        Raises:
            RuntimeError: if the pool is empty (e.g. after ``shutdown``), instead
            of waiting forever.
        """
        while True:
            if not self.model_pool:
                raise RuntimeError("model pool is empty")
            for model, lock in zip(self.model_pool, self.lock_pool):
                if lock.acquire(blocking=False):
                    return model, lock
            time.sleep(0.1)  # avoid a hot spin while all replicas are busy

    def _release_model(self, model):
        """Release the lock of the replica that is *model* (identity comparison)."""
        for candidate, lock in zip(self.model_pool, self.lock_pool):
            if candidate is model:
                lock.release()
                break

    @staticmethod
    def _dedupe_segments(text, sep):
        """Split *text* on *sep*, drop blank and repeated segments, re-join with '。'.

        The first occurrence of a segment wins. Duplicates are compared on the
        unstripped segment text.
        """
        seen = set()
        kept = []
        for segment in text.split(sep):
            if segment.strip() and segment not in seen:
                seen.add(segment)
                kept.append(segment)
        return '。'.join(kept)

    def remove_duplicate_lines(self,text):
        """Deduplicate '。'-separated sentences of *text*."""
        return self._dedupe_segments(text, '。')

    def remove_duplicate_lines_d(self,text):
        """Deduplicate ','-separated (ASCII comma) segments; survivors are joined with '。'."""
        return self._dedupe_segments(text, ',')

    def remove_duplicate_lines_n(self,text):
        """Deduplicate newline-separated lines; survivors are joined with '。'."""
        return self._dedupe_segments(text, '\n')
| | | |