| | |
| | | import time |
| | | from concurrent.futures import ThreadPoolExecutor |
| | | import threading |
| | | |
| | | import torch |
| | | from PIL import Image |
| | | from pymilvus import connections, Collection |
| | | from datetime import datetime |
| | | import os |
| | | import requests |
| | | import asyncio |
| | | import logging |
| | | import re |
| | | from logging.handlers import RotatingFileHandler |
| | | |
| | | from qwen_vl_utils import process_vision_info |
| | | from transformers import AutoModelForVision2Seq, AutoProcessor |
| | | |
| | | |
| | |
| | | # 加载集合 |
| | | self.collection = Collection(name="smartobject") |
| | | self.collection.load() |
| | | |
| | | if config.get('cuda') == None or config.get('cuda') == '0': |
| | | self.device = f"cuda" |
| | | else: |
| | | self.device = f"cuda:{config.get('cuda')}" |
| | | self.model_pool = [] |
| | | self.lock_pool = [threading.Lock() for _ in range(int(config.get("threadnum")))] |
| | | for i in range(int(config.get("threadnum"))): |
| | | model = AutoModelForVision2Seq.from_pretrained( |
| | | config.get("qwenaddr"), |
| | | device_map=f"cuda:{config.get('cuda')}", |
| | | device_map=self.device, |
| | | trust_remote_code=True, |
| | | use_safetensors=True, |
| | | torch_dtype=torch.float16 |
| | | |
| | | ).eval() |
| | | model = model.to(f"cuda:{config.get('cuda')}") |
| | | model = model.to(self.device) |
| | | self.model_pool.append(model) |
| | | |
| | | # 共享的处理器 (线程安全) |
| | |
| | | # 1. 从集合A获取向量和元数据 |
| | | is_waning = 0 |
| | | is_desc = 2 |
| | | |
| | | # 生成图片描述 |
| | | ks_time = datetime.now() |
| | | desc = self.image_desc(res) |
| | | desc_time = datetime.now() - ks_time |
| | | current_time = datetime.now() |
| | | risk_description = "" |
| | | suggestion = "" |
| | | # 图片描述生成成功 |
| | | if desc: |
| | | is_desc = 2 |
| | | # 调用规则匹配方法,判断是否预警 |
| | | is_waning = self.image_rule_chat(desc, res['waning_value'], ragurl,rag_mode,max_tokens) |
| | | is_waning = self.image_rule(res) |
| | | # 如果预警,则生成隐患描述和处理建议 |
| | | if is_waning == 1: |
| | | # 获取规章制度数据 |
| | |
| | | risk_description = self.image_rule_chat_with_detail(filedata, res['waning_value'], ragurl,rag_mode,max_tokens) |
| | | # 生成处理建议 |
| | | suggestion = self.image_rule_chat_suggestion(filedata, res['waning_value'], ragurl,rag_mode,max_tokens) |
| | | self.logger.info( |
| | | f"{res['video_point_id']}执行完毕:{res['id']}:是否预警{is_waning},安全隐患:{risk_description}\n处理建议:{suggestion}") |
| | | #self.logger.info(f"{res['video_point_id']}执行完毕:{res['id']}:是否预警{is_waning},安全隐患:{risk_description}\n处理建议:{suggestion}") |
| | | # 数据组 |
| | | data = { |
| | | "event_level_id": res['event_level_id'], # event_level_id |
| | | "event_level_name": res['event_level_name'], # event_level_id |
| | | "rule_id": res["rule_id"], |
| | | "video_point_id": res['video_point_id'], # video_point_id |
| | | "video_point_name": res['video_point_name'], |
| | | "is_waning": is_waning, |
| | | "is_desc": 1, |
| | | "zh_desc_class": res['zh_desc_class'], # text_vector |
| | | "bounding_box": res['bounding_box'], # bounding_box |
| | | "task_id": res['task_id'], # task_id |
| | | "task_name": res['task_name'], # task_id |
| | | "detect_id": res['detect_id'], # detect_id |
| | | "detect_time": res['detect_time'], # detect_time |
| | | "detect_num": res['detect_num'], |
| | | "waning_value": res['waning_value'], |
| | | "image_path": res['image_path'], # image_path |
| | | "image_desc_path": res['image_desc_path'], # image_desc_path |
| | | "video_path": res['video_path'], |
| | | "text_vector": res['text_vector'], |
| | | "risk_description": risk_description, |
| | | "suggestion": suggestion, |
| | | "knowledge_id": res['knowledge_id'] |
| | | } |
| | | self.collection.delete(f"id == {res['id']}") |
| | | # 保存到milvus |
| | | image_id = self.collection.insert(data).primary_keys |
| | | res['id'] = image_id[0] |
| | | # 图片描述生成成功 |
| | | desc = self.image_desc(res) |
| | | if desc: |
| | | is_desc = 2 |
| | | else: |
| | | is_desc = 3 |
| | | |
| | | # 数据组 |
| | | data = { |
| | | "event_level_id": res['event_level_id'], # event_level_id |
| | |
| | | # 调用rag |
| | | asyncio.run(self.insert_json_data(ragurl, data)) |
| | | rag_time = datetime.now() - current_time |
| | | self.logger.info(f"{res['video_point_id']}执行完毕:{image_id}运行结束总体用时:{datetime.now() - ks_time},图片描述用时{desc_time},RAG用时{rag_time}") |
| | | self.logger.info(f"{res['video_point_id']}执行完毕:{image_id}运行结束总体用时:{datetime.now() - ks_time},图片描述用时{desc_time},RAG用时{rag_time}") |
| | | if is_waning == 1: |
| | | self.logger.info(f"{res['video_point_id']}执行完毕:{image_id},图片描述:{desc}\n隐患:{risk_description}\n建议:{suggestion}") |
| | | except Exception as e: |
| | | self.logger.info(f"线程:执行模型解析时出错::{e}") |
| | | return 0 |
| | |
| | | return_tensors="pt", |
| | | ) |
| | | inputs = inputs.to(model.device) |
| | | with torch.inference_mode(),torch.cuda.amp.autocast(): |
| | | outputs = model.generate(**inputs,max_new_tokens=200) |
| | | with torch.inference_mode(), torch.amp.autocast(device_type=self.device, dtype=torch.float16): |
| | | outputs = model.generate(**inputs,max_new_tokens=200,do_sample=False,num_beams=1,temperature=None,top_p=None,top_k=1,use_cache=True,repetition_penalty=1.0) |
| | | generated_ids = outputs[:, len(inputs.input_ids[0]):] |
| | | image_text = self.processor.batch_decode( |
| | | generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True |
| | | ) |
| | | image_des = (image_text[0]).strip() |
| | | #self.logger.info(f"{res_data['video_point_id']}:{res_data['id']}:{res_data['detect_time']}:{image_des}") |
| | | return image_des |
| | | except Exception as e: |
| | | self.logger.info(f"线程:执行图片描述时出错:{e}") |
| | | finally: |
| | | # 4. 释放模型 |
| | | self._release_model(model) |
| | | torch.cuda.empty_cache() |
| | | |
| | | def image_rule(self, res_data): |
| | | try: |
| | | model, lock = self._acquire_model() |
| | | image = Image.open(res_data['image_desc_path']).convert("RGB").resize((600, 600), Image.Resampling.LANCZOS) |
| | | |
| | | messages = [ |
| | | { |
| | | "role": "user", |
| | | "content": [ |
| | | {"type": "image", "image": image}, |
| | | {"type": "text", "text": f"图片中是否有{res_data['waning_value']}?请回答yes或no"}, |
| | | ], |
| | | } |
| | | ] |
| | | |
| | | # Preparation for inference |
| | | text = self.processor.apply_chat_template( |
| | | messages, tokenize=False, add_generation_prompt=True |
| | | ) |
| | | image_inputs, video_inputs = process_vision_info(messages) |
| | | inputs = self.processor( |
| | | text=[text], |
| | | images=image_inputs, |
| | | videos=video_inputs, |
| | | padding=True, |
| | | return_tensors="pt", |
| | | ) |
| | | inputs = inputs.to(model.device) |
| | | |
| | | with torch.no_grad(): |
| | | outputs = model.generate(**inputs, max_new_tokens=10) |
| | | generated_ids = outputs[:, len(inputs.input_ids[0]):] |
| | | image_text = self.processor.batch_decode( |
| | | generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True |
| | | ) |
| | | |
| | | image_des = (image_text[0]).strip() |
| | | return image_des |
| | | except Exception as e: |
| | |
| | | } |
| | | response = requests.post(ragurl + "/chat", json=search_data) |
| | | results = response.json().get('data') |
| | | # self.logger.info(results) |
| | | ret = re.sub(r'<think>.*?</think>', '', results, flags=re.DOTALL) |
| | | ret = ret.replace(" ", "").replace("\t", "").replace("\n", "") |
| | | #self.logger.info(f"{rule_text}:{ret}") |
| | | is_waning = 0 |
| | | if len(ret) > 2: |
| | | is_waning = 1 |
| | |
| | | "max_tokens": max_tokens |
| | | } |
| | | } |
| | | #self.logger.info(content) |
| | | response = requests.post(ragurl + "/chat", json=search_data) |
| | | # 从json提取data字段内容 |
| | | ret = response.json()["data"] |