qwen_thread_batch.py
@@ -13,7 +13,7 @@
import re
from logging.handlers import RotatingFileHandler
from transformers import AutoModelForVision2Seq, AutoProcessor
class qwen_thread_batch:
@@ -30,19 +30,13 @@
        self.config = config
        self.model_pool = []
        self.lock_pool = [threading.Lock() for _ in range(max_workers)]
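        # Load one fp16 replica of the model per worker on the GPU selected by config['cuda'].
        # Each replica is presumably handed out together with one lock from lock_pool via
        # _acquire_model, so a worker has exclusive use of its copy while generating.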
        for i in range(max_workers):
            model = AutoModelForVision2Seq.from_pretrained(
                model_path,
                device_map="cuda:1",
                device_map=f"cuda:{config.get('cuda')}",
                trust_remote_code=True,
                use_flash_attention_2=True,
                use_safetensors=True,
                torch_dtype=torch.float16
            ).eval()
            self.model_pool.append(model)
@@ -89,20 +83,34 @@
            raise
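    # Process one batch of detection results end-to-end: describe each image, match the description
    # against its alert rule, optionally generate hazard/suggestion text, then write the record to
    # Milvus and push a document to the RAG service.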
    def tark_do(self, res_a, ragurl, rag_mode, max_tokens):
        try:
            # 1. Fetch the vectors and metadata from collection A
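            # Status flags: is_waning 1 = rule matched (raise an alert), 0 = no alert;
            # is_desc 2 = image description generated, 3 = description generation failed.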
            is_waning = 0
            is_desc = 2
            # Generate image descriptions
            desc_list = self.image_desc(res_a)
            risk_description = ""
            suggestion = ""
            if desc_list:
                for desc, res in zip(desc_list, res_a):
                    # The image description was generated successfully
                    if desc:
                        # rule_text = self.get_rule(ragurl)
                        is_desc = 2
                        # Call the rule-matching method to decide whether this is an alert
                        is_waning = self.image_rule_chat(desc, res['waning_value'], ragurl, rag_mode, max_tokens)
                        # If an alert fired, generate the hazard description and handling suggestion
                        if is_waning == 1:
                            # Retrieve the relevant regulation text
                            filedata = self.get_filedata(res['waning_value'], ragurl)
                            # Generate the hazard description
                            risk_description = self.image_rule_chat_with_detail(filedata, res['waning_value'], ragurl, rag_mode)
                            # Generate the handling suggestion
                            suggestion = self.image_rule_chat_suggestion(res['waning_value'], ragurl, rag_mode)
                    else:
                        is_waning = 0
                        is_desc = 3
                    # Assemble the record
                    data = {
                        "id": res['id'],
                        "event_level_id": res['event_level_id'],  # event_level_id
@@ -123,39 +131,45 @@
                        "image_path": res['image_path'],  # image_path
                        "image_desc_path": res['image_desc_path'],  # image_desc_path
                        "video_path": res['video_path'],
                        "text_vector": res['text_vector']
                        "text_vector": res['text_vector'],
                        "risk_description": risk_description,
                        "suggestion": suggestion,
                        "knowledge_id": res['knowledge_id']
                    }
                    # Save to Milvus
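                    # upsert() returns a MutationResult; its primary_keys list holds the ids of the
                    # written rows, and the first id is used to key the RAG document below.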
                    image_id = self.collection.upsert(data).primary_keys
                    #self.logger.info(f"{res['id']}--{image_id}:{desc}")
                    logging.info(image_id)
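                    # Document sent to the RAG index, keyed by the Milvus primary key of the record just upserted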
                    data = {
                        "id": str(image_id[0]),
                        "video_point_id": res['video_point_id'],
                        "video_path": res["video_point_name"],
                        "zh_desc_class": desc,
                        "detect_time": res['detect_time'],
                        "image_path": f"{res['image_path']}",
                        "task_name": res["task_name"],
                        "event_level_name": res["event_level_name"],
                        "rtsp_address": f"{res['video_path']}"
                    }
                    # Send the document to the RAG service
                    asyncio.run(self.insert_json_data(ragurl, data))
        except Exception as e:
            self.logger.info(f"线程:执行模型解析时出错:任务:{e}")
            self.logger.info(f"线程:执行模型解析时出错::{e}")
            return 0
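    # Generate a description for every image in the batch on a single pooled model replica;
    # the images are loaded and resized in parallel beforehand.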
    def image_desc(self, res_data):
        try:
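            # Check out a model replica and its lock from the pool so this thread has exclusive use of it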
            model, lock = self._acquire_model()
            # 1. Load and resize the images in parallel
            def _load_image(path):
                return Image.open(path).convert("RGB").resize((448, 448), Image.Resampling.LANCZOS)
            with ThreadPoolExecutor(max_workers=4) as executor:
                image_data = list(executor.map(
                    _load_image,
                    [res['image_desc_path'] for res in res_data]
                ))
            messages = [
                {
@@ -178,17 +192,9 @@
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to("cuda:1")
            current_time = datetime.now()
            inputs = inputs.to(model.device)
            with torch.inference_mode():
                outputs = model.generate(**inputs, max_new_tokens=100)
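            # Keep only the newly generated tokens (drop the prompt) before decoding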
            generated_ids = outputs[:, len(inputs.input_ids[0]):]
            image_text = self.processor.batch_decode(
                generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
@@ -235,9 +241,9 @@
    def image_rule_chat(self, image_des, rule_text, ragurl, rag_mode, max_tokens):
        try:
            content = (
                f"图片描述内容为:\n{image_des}\n规则内容:\n{rule_text}。\n请验证图片描述中是否有符合规则的内容,不进行推理和think。返回结果格式为[xxx符合的规则id],如果没有返回[]")
            # self.logger.info(content)
            #self.logger.info(len(content))
            search_data = {
                "prompt": "",
@@ -256,7 +262,6 @@
            }
            response = requests.post(ragurl + "/chat", json=search_data)
            results = response.json().get('data')
            #self.logger.info(len(results))
            # self.logger.info(results)
            ret = re.sub(r'<think>.*?</think>', '', results, flags=re.DOTALL)
            ret = ret.replace(" ", "").replace("\t", "").replace("\n", "")
@@ -268,6 +273,102 @@
            self.logger.info(f"线程:执行规则匹配时出错:{image_des, rule_text, ragurl, rag_mode,e}")
            return None
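    # NOTE: the RAG /chat endpoint is assumed to return JSON with a 'data' field holding the model
    # reply; replies may be wrapped in <think>...</think> blocks, which the callers strip with a regex.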
    # Hazard description
    def image_rule_chat_with_detail(self, filedata, rule_text, ollama_url, ollama_mode="qwen2.5vl:3b"):
        # API call (note: llm_name is hardcoded below, so the ollama_mode parameter is currently unused)
        response = requests.post(
            # RAG/ollama service address
            url=f"{ollama_url}/chat",
            json={
                "prompt": "",
                # Request content
                "messages": [
                    {
                        "role": "user",
                        "content": f"请根据规章制度[{filedata}]\n查找[{rule_text}]的安全隐患描述,不进行推理和think。返回信息小于800字"
                    }
                ],
                # Model to use
                "llm_name": "qwen3:8b",
                "stream": False,    # Disable streaming output
                "gen_conf": {
                    "temperature": 0.7,  # Controls generation randomness
                    "max_tokens": 800   # Maximum output length
                }
            }
        )
        # Extract the 'data' field from the JSON response
        ret = response.json()["data"]
        #result = response.json()
        #ret = result.get("data") or result.get("message", {}).get("content", "")
        # Remove <think> tags and their content
        ret = re.sub(r'<think>.*?</think>', '', ret, flags=re.DOTALL)
        # Clean the string: strip spaces, tabs, newlines and asterisks
        ret = ret.replace(" ", "").replace("\t", "").replace("\n", "").replace("**","")
        print(ret)
        return ret
    # Handling suggestion
    def image_rule_chat_suggestion(self, rule_text, ollama_url, ollama_mode="qwen2.5vl:3b"):
        self.logger.info("----------------------------------------------------------------")
        # Request content
        content = (
            f"请根据违规内容[{rule_text}]\n进行返回处理违规建议,不进行推理和think。返回精准信息")
        response = requests.post(
            # RAG/ollama service address
            url=f"{ollama_url}/chat",
            json={
                # Model to use (hardcoded; the ollama_mode parameter is currently unused)
                "llm_name": "qwen3:8b",
                "messages": [
                    {"role": "user", "content": content}
                ],
                "stream": False  # Disable streaming output
            }
        )
        # Extract the 'data' field from the JSON response
        ret = response.json()["data"]
        #result = response.json()
        #ret = result.get("data") or result.get("message", {}).get("content", "")
        # Remove <think> tags and their content
        ret = re.sub(r'<think>.*?</think>', '', ret, flags=re.DOTALL)
        # Clean the string: strip spaces, tabs, newlines and asterisks
        ret = ret.replace(" ", "").replace("\t", "").replace("\n", "").replace("**","")
        print(ret)
        return ret
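    # get_filedata: query the 'smart_knowledge' collection via the RAG /search endpoint and concatenate
    # the 'text' field of every hit. The endpoint is assumed to return JSON shaped like
    # {'results': [{'entity': {'text': ...}}, ...]}, which is how the response is read below.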
    # Query the RAG service for knowledge-base content
    def get_filedata(self, searchtext, ragurl):
        search_data = {
            # Knowledge-base collection
            "collection_name": "smart_knowledge",
            # Query text
            "query_text": searchtext,
            # Search mode
            "search_mode": "hybrid",
            # Maximum number of results
            "limit": 100,
            # Dense vector search weight
            "weight_dense": 0.7,
            # Sparse vector search weight
            "weight_sparse": 0.3,
            # No filter expression
            "filter_expr": "",
            # Return only the 'text' field
            "output_fields": ["text"]
        }
        # POST the request to the ragurl + "/search" endpoint
        response = requests.post(ragurl + "/search", json=search_data)
        # Read the 'results' field from the response
        results = response.json().get('results')
        # Accumulate the text of every hit
        text = ""
        # Take the 'text' field from each result's 'entity' and join them with ";\n"
        for rule in results:
            text = text + rule['entity'].get('text') + ";\n"
        return text
    async def insert_json_data(self, ragurl, data):
        try:
            data = {'collection_name': "smartrag", "data": data, "description": ""}