#2025/7/10
#完善知识库部分,按配置的知识库从rag中获取数据;优化生成安全隐患和处理建议的提示语
2个文件已修改
164 ■■■■■ 已修改文件
qwen_detect.py 17 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
qwen_thread.py 147 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
qwen_detect.py
@@ -1,5 +1,4 @@
from operator import itemgetter
import torch
import threading
import time as time_sel
from typing import Dict
@@ -7,7 +6,6 @@
import requests
import os
import logging
from transformers import AutoProcessor, AutoModelForVision2Seq
from pymilvus import connections, Collection
from logging.handlers import RotatingFileHandler
import get_mem
@@ -43,7 +41,7 @@
        # 加载集合
        self.collection = Collection(name="smartobject")
        self.collection.load()
        self.pool = qwen_thread(int(self.config.get("threadnum")), self.config,self.config.get("qwenaddr"))
        self.pool = qwen_thread(self.config)
        #是否更新
        self._isupdate = False
@@ -96,13 +94,12 @@
                    output_fields=["id", "zh_desc_class", "text_vector", "bounding_box", "video_point_name", "task_id",
                                   "task_name", "event_level_id", "event_level_name",
                                   "video_point_id", "detect_num", "is_waning", "is_desc", "waning_value", "rule_id",
                                   "detect_id","knowledge_id",
                                   "detect_id","knowledge_id","suggestion",
                                   "detect_time", "image_path", "image_desc_path", "video_path"],
                    consistency_level="Strong",
                    order_by_field="id",  # 按id字段排序
                    order_by_type="desc"  # 降序排列
                )
                # 读取共享内存中的图片
                # image_id = get_mem.smem_read_frame_qianwen(camera_id)
                if len(res_a) > 0:
@@ -132,19 +129,17 @@
                            "image_desc_path": res['image_desc_path'],  # image_desc_path
                            "video_path": res['video_path'],
                            "text_vector": res['text_vector'],
                            "knowledge_id": res['knowledge_id']
                            "knowledge_id": res['knowledge_id'],
                            "suggestion": res['suggestion'],
                        }
                        # logging.info(f"读取图像成功: {res['id']}")
                        # 保存到milvus
                        image_id = self.collection.upsert(data).primary_keys
                        res['id'] = image_id[0]
                        # logging.info(f"读取图像成功: {image_id}")
                        image_id = self.pool.submit(res)
                        # image_id = pool.tark_do(image_id,self.config.get("ragurl"),self.config.get("ragmode"),self.config.get("max_tokens"))
                        # logging.info(f"处理图像成功: {image_id}")
                    sorted_results = None
                        self.pool.submit(res)
            except Exception as e:
                logging.info(f"{camera_id}线程错误:{e}")
            time_sel.sleep(0.01)
    #调用是否需要更新
    def isUpdate(self):
qwen_thread.py
@@ -17,31 +17,33 @@
class qwen_thread:
    def __init__(self, max_workers,config,model_path):
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.semaphore = threading.Semaphore(max_workers)
        self.max_workers = max_workers
    def __init__(self, config):
        self.config = config
        self.max_workers = int(config.get("threadnum"))
        self.executor = ThreadPoolExecutor(max_workers=int(config.get("threadnum")))
        self.semaphore = threading.Semaphore(int(config.get("threadnum")))
        # 初始化Milvus集合
        connections.connect("default", host=config.get("milvusurl"), port=config.get("milvusport"))
        # 加载集合
        self.collection = Collection(name="smartobject")
        self.collection.load()
        self.config = config
        self.model_pool = []
        self.lock_pool = [threading.Lock() for _ in range(max_workers)]
        for i in range(max_workers):
        self.lock_pool = [threading.Lock() for _ in range(int(config.get("threadnum")))]
        for i in range(int(config.get("threadnum"))):
            model = AutoModelForVision2Seq.from_pretrained(
                model_path,
                config.get("qwenaddr"),
                device_map=f"cuda:{config.get('cuda')}",
                trust_remote_code=True,
                use_safetensors=True,
                torch_dtype=torch.float16
            ).eval()
            self.model_pool.append(model)
        # 共享的处理器 (线程安全)
        self.processor = AutoProcessor.from_pretrained(model_path,use_fast=True)
        self.processor = AutoProcessor.from_pretrained(config.get("qwenaddr"), use_fast=True)
        # 创建实例专属logger
        self.logger = logging.getLogger(f"{self.__class__}_{id(self)}")
@@ -65,7 +67,7 @@
        acquired = self.semaphore.acquire(blocking=False)
        if not acquired:
            self.logger.info(f"线程池已满,等待空闲线程... (当前活跃: {self.max_workers - self.semaphore._value}/{self.max_workers})")
            #self.logger.info(f"线程池已满,等待空闲线程... (当前活跃: {self.max_workers - self.semaphore._value}/{self.max_workers})")
            # 阻塞等待直到有可用线程
            self.semaphore.acquire(blocking=True)
@@ -73,15 +75,11 @@
        future.add_done_callback(self._release_semaphore)
        return future
    def _wrap_task(self, res):
    def _wrap_task(self, res_a):
        try:
            #self.logger.info(f"处理: { res['id']}开始")
            current_time = datetime.now()
            image_id = self.tark_do(res, self.config.get("ragurl"), self.config.get("ragmode"), self.config.get("max_tokens"))
            self.logger.info(f"处理: { res['id']}完毕{image_id}:{datetime.now() - current_time}")
            return image_id
            self.tark_do(res_a, self.config.get("ragurl"), self.config.get("ragmode"), self.config.get("max_tokens"))
        except Exception as e:
            self.logger.info(f"任务 { res['id']} 处理出错: {e}")
            self.logger.info(f"处理出错: {e}")
            raise
    def tark_do(self,res,ragurl,rag_mode,max_tokens):
@@ -91,23 +89,25 @@
            is_desc = 2
            # 生成图片描述
            image_des = self.image_desc(res['image_desc_path'])
            ks_time = datetime.now()
            desc = self.image_desc(res)
            desc_time = datetime.now() - ks_time
            current_time = datetime.now()
            risk_description = ""
            suggestion = ""
            # 图片描述生成成功
            if image_des:
            if desc:
                is_desc = 2
                # 调用规则匹配方法,判断是否预警
                is_waning = self.image_rule_chat(image_des, res['waning_value'], ragurl, rag_mode, max_tokens)
                is_waning = self.image_rule_chat(desc, res['waning_value'], ragurl,rag_mode,max_tokens)
                # 如果预警,则生成隐患描述和处理建议
                if is_waning == 1:
                #if is_waning == 1:
                    # 获取规章制度数据
                    filedata = self.get_filedata(res['waning_value'], ragurl)
                filedata = self.get_filedata(res['waning_value'],res['suggestion'], ragurl)
                    # 生成隐患描述
                    risk_description = self.image_rule_chat_with_detail(filedata, res['waning_value'], ragurl, rag_mode)
                risk_description = self.image_rule_chat_with_detail(filedata, res['waning_value'], ragurl,rag_mode,max_tokens)
                    # 生成处理建议
                    suggestion = self.image_rule_chat_suggestion(res['waning_value'], ragurl, rag_mode)
                suggestion = self.image_rule_chat_suggestion(filedata, res['waning_value'], ragurl,rag_mode,max_tokens)
            else:
                is_desc = 3
@@ -119,9 +119,9 @@
                "rule_id": res["rule_id"],
                "video_point_id": res['video_point_id'],  # video_point_id
                "video_point_name": res['video_point_name'],
                "is_waning": is_waning,
                "is_waning": 1,
                "is_desc": is_desc,
                "zh_desc_class": image_des,  # text_vector
                "zh_desc_class": desc,  # text_vector
                "bounding_box": res['bounding_box'],  # bounding_box
                "task_id": res['task_id'],  # task_id
                "task_name": res['task_name'],  # task_id
@@ -140,12 +140,11 @@
            # 保存到milvus
            image_id = self.collection.upsert(data).primary_keys
            logging.info(image_id)
            data = {
                "id": str(image_id[0]),
                "video_point_id": res['video_point_id'],
                "video_path": res["video_point_name"],
                "zh_desc_class": image_des,
                "zh_desc_class": desc,
                "detect_time": res['detect_time'],
                "image_path": f"{res['image_path']}",
                "task_name": res["task_name"],
@@ -154,17 +153,17 @@
            }
            # 调用rag
            asyncio.run(self.insert_json_data(ragurl, data))
            return image_id
            rag_time = datetime.now() - current_time
            self.logger.info(f"{image_id}运行结束总体用时:{datetime.now() - ks_time},图片描述用时{desc_time},RAG用时{rag_time}")
        except Exception as e:
            self.logger.info(f"线程:执行模型解析时出错:任务:{res['id']} :{e}")
            self.logger.info(f"线程:执行模型解析时出错::{e}")
            return 0
    def image_desc(self, image_path):
    def image_desc(self, res_data):
        try:
            model, lock = self._acquire_model()
            # 2. 处理图像
            image = Image.open(image_path).convert("RGB")  # 替换为您的图片路径
            image = image.resize((600, 600), Image.Resampling.LANCZOS)  # 高质量缩放
            image = Image.open(res_data['image_desc_path']).convert("RGB").resize((600, 600), Image.Resampling.LANCZOS)
            messages = [
                {
                    "role": "user",
@@ -187,14 +186,8 @@
                return_tensors="pt",
            )
            inputs = inputs.to(model.device)
            current_time = datetime.now()
            outputs = model.generate(**inputs,
                                               max_new_tokens=300,
                                               do_sample=True,
                                               temperature=0.7,
                                               renormalize_logits=True
                                               )
            print(f"处理完毕:{datetime.now() - current_time}")
            with torch.inference_mode():
                outputs = model.generate(**inputs,max_new_tokens=100)
            generated_ids = outputs[:, len(inputs.input_ids[0]):]
            image_text = self.processor.batch_decode(
                generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
@@ -241,7 +234,6 @@
        try:
            content = (
                f"图片描述内容为:\n{image_des}\n规则内容:\n{rule_text}。\n请验证图片描述中是否有符合规则的内容,不进行推理和think。返回结果格式为[xxx符合的规则id],如果没有返回[]")
            # self.logger.info(content)
            #self.logger.info(len(content))
            search_data = {
                "prompt": "",
@@ -260,7 +252,6 @@
            }
            response = requests.post(ragurl + "/chat", json=search_data)
            results = response.json().get('data')
            #self.logger.info(len(results))
            # self.logger.info(results)
            ret = re.sub(r'<think>.*?</think>', '', results, flags=re.DOTALL)
            ret = ret.replace(" ", "").replace("\t", "").replace("\n", "")
@@ -273,72 +264,67 @@
            return None
    # 隐患描述
    def image_rule_chat_with_detail(self, filedata, rule_text, ollama_url, ollama_mode="qwen2.5vl:3b"):
    def image_rule_chat_with_detail(self,filedata, rule_text, ragurl, rag_mode,max_tokens):
        # API调用
        response = requests.post(
            # ollama地址
            url=f"{ollama_url}/chat",
            json={
        content = (
            f"规章制度为:[{filedata}]\n违反内容为:[{rule_text}]\n请查询违反内容在规章制度中的安全隐患,不进行推理和think,返回简短的文字信息")
        # self.logger.info(len(content))
        search_data = {
                "prompt": "",
                # 请求内容
                "messages": [
                    {
                        "role": "user",
                        "content": f"请根据规章制度[{filedata}]\n查找[{rule_text}]的安全隐患描述,不进行推理和think。返回信息小于800字"
                    "content": content
                    }
                ],
                # 指定模型
                "llm_name": "qwen3:8b",
                "stream": False,  # 关闭流式输出
            "llm_name": rag_mode,
            "stream": False,
                "gen_conf": {
                    "temperature": 0.7,  # 控制生成随机性
                    "max_tokens": 800  # 最大输出长度
                "temperature": 0.7,
                "max_tokens": max_tokens
                }
            }
        )
        response = requests.post(ragurl + "/chat", json=search_data)
        # 从json提取data字段内容
        ret = response.json()["data"]
        # result = response.json()
        # ret = result.get("data") or result.get("message", {}).get("content", "")
        # 移除<think>标签和内容
        ret = re.sub(r'<think>.*?</think>', '', ret, flags=re.DOTALL)
        # 字符串清理,移除空格,制表符,换行符,星号
        ret = ret.replace(" ", "").replace("\t", "").replace("\n", "").replace("**", "")
        print(ret)
        #print(f"安全隐患:{ret}")
        return ret
    # 处理建议
    def image_rule_chat_suggestion(self, rule_text, ollama_url, ollama_mode="qwen2.5vl:3b"):
        self.logger.info("----------------------------------------------------------------")
    def image_rule_chat_suggestion(self,filedata, rule_text, ragurl, rag_mode,max_tokens):
        # 请求内容
        content = (
            f"请根据违规内容[{rule_text}]\n进行返回处理违规建议,不进行推理和think。返回精准信息")
            f"规章制度为:[{filedata}]\n违反内容为:[{rule_text}]\n请查询违反内容在规章制度中的处理建议,不进行推理和think,返回简短的文字信息")
        response = requests.post(
            # ollama地址
            url=f"{ollama_url}/chat",
            url=f"{ragurl}/chat",
            json={
                # 指定模型
                "llm_name": "qwen3:8b",
                "llm_name": rag_mode,
                "messages": [
                    {"role": "user", "content": content}
                ],
                "stream": False  # 关闭流式输出
                "stream": False,  # 关闭流式输出
                "gen_conf": {
                    "temperature": 0.7,
                    "max_tokens": max_tokens
                }
            }
        )
        # 从json提取data字段内容
        ret = response.json()["data"]
        # result = response.json()
        # ret = result.get("data") or result.get("message", {}).get("content", "")
        # 移除<think>标签和内容
        ret = re.sub(r'<think>.*?</think>', '', ret, flags=re.DOTALL)
        # 字符串清理,移除空格,制表符,换行符,星号
        ret = ret.replace(" ", "").replace("\t", "").replace("\n", "").replace("**", "")
        print(ret)
        #print(f"处理建议:{ret}")
        return ret
    # RAG服务发送请求,获取知识库内容
    def get_filedata(self, searchtext, ragurl):
    def get_filedata(self, searchtext,filter_expr, ragurl):
        search_data = {
            # 知识库集合
            "collection_name": "smart_knowledge",
@@ -347,16 +333,17 @@
            # 搜索模式
            "search_mode": "hybrid",
            # 最多返回结果
            "limit": 100,
            "limit": 10,
            # 调密向量搜索权重
            "weight_dense": 0.7,
            "weight_dense": 0.9,
            # 稀疏向量搜索权重
            "weight_sparse": 0.3,
            "weight_sparse": 0.1,
            # 空字符串
            "filter_expr": "",
            "filter_expr": f"docnm_kwd in {filter_expr}",
            # 只返回 text 字段
            "output_fields": ["text"]
        }
        #print(search_data)
        # 向 ragurl + "/search" 端点发送POST请求
        response = requests.post(ragurl + "/search", json=search_data)
        # 从响应中获取'results'字段
@@ -366,7 +353,7 @@
        # 遍历所有结果规则(rule),将每条规则的'entity'中的'text'字段取出.
        for rule in results:
            text = text + rule['entity'].get('text') + ";\n"
        #print(text)
        return text
    async def insert_json_data(self, ragurl, data):
@@ -380,7 +367,7 @@
    def _release_semaphore(self, future):
        self.semaphore.release()
        self.logger.info(f"释放线程 (剩余空闲: {self.semaphore._value}/{self.max_workers})")
        #self.logger.info(f"释放线程 (剩余空闲: {self.semaphore._value}/{self.max_workers})")
    def shutdown(self):
        """安全关闭"""