qwen_thread.py
@@ -17,11 +17,12 @@
class qwen_thread:
    def __init__(self, config):
    def __init__(self, config,logger):
        self.config = config
        self.max_workers = int(config.get("threadnum"))
        self.executor = ThreadPoolExecutor(max_workers=int(config.get("threadnum")))
        self.semaphore = threading.Semaphore(int(config.get("threadnum")))
        self.logger = logger
        # Initialize the Milvus collection
        connections.connect("default", host=config.get("milvusurl"), port=config.get("milvusport"))
@@ -45,22 +46,6 @@
        # Shared processor (thread-safe)
        self.processor = AutoProcessor.from_pretrained(config.get("qwenaddr"), use_fast=True)
        # Create a logger dedicated to this instance
        self.logger = logging.getLogger(f"{self.__class__}_{id(self)}")
        self.logger.setLevel(logging.INFO)
        # Avoid attaching the same handler more than once
        if not self.logger.handlers:
            handler = RotatingFileHandler(
                filename=os.path.join("logs", 'thread_log.log'),
                maxBytes=10 * 1024 * 1024,
                backupCount=3,
                encoding='utf-8'
            )
            formatter = logging.Formatter(
                '%(asctime)s - %(filename)s:%(lineno)d - %(funcName)s() - %(levelname)s: %(message)s'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
    def submit(self,res_a):
        # Try to acquire the semaphore (non-blocking)
@@ -101,25 +86,26 @@
                # Run the rule-matching method to decide whether to raise a warning
                is_waning = self.image_rule_chat(desc, res['waning_value'], ragurl,rag_mode,max_tokens)
                # If a warning is raised, generate the hazard description and handling suggestion
                #if is_waning == 1:
                # Fetch the rules-and-regulations data
                filedata = self.get_filedata(res['waning_value'],res['suggestion'], ragurl)
                # Generate the hazard description
                risk_description = self.image_rule_chat_with_detail(filedata, res['waning_value'], ragurl,rag_mode,max_tokens)
                # Generate the handling suggestion
                suggestion = self.image_rule_chat_suggestion(filedata, res['waning_value'], ragurl,rag_mode,max_tokens)
                if is_waning == 1:
                    # Fetch the rules-and-regulations data
                    filedata = self.get_filedata(res['waning_value'],res['suggestion'], ragurl)
                    # Generate the hazard description
                    risk_description = self.image_rule_chat_with_detail(filedata, res['waning_value'], ragurl,rag_mode,max_tokens)
                    # Generate the handling suggestion
                    suggestion = self.image_rule_chat_suggestion(filedata, res['waning_value'], ragurl,rag_mode,max_tokens)
                    self.logger.info(
                        f"{res['video_point_id']}执行完毕:{res['id']}:是否预警{is_waning},安全隐患:{risk_description}\n处理建议:{suggestion}")
            else:
                is_desc = 3
            # Assemble the record
            data = {
                "id": res['id'],
                "event_level_id": res['event_level_id'],  # event_level_id
                "event_level_name": res['event_level_name'],  # event_level_name
                "rule_id": res["rule_id"],
                "video_point_id": res['video_point_id'],  # video_point_id
                "video_point_name": res['video_point_name'],
                "is_waning": 1,
                "is_waning": is_waning,
                "is_desc": is_desc,
                "zh_desc_class": desc,  # text_vector
                "bounding_box": res['bounding_box'],  # bounding_box
@@ -137,9 +123,9 @@
                "suggestion": suggestion,
                "knowledge_id": res['knowledge_id']
            }
            self.collection.delete(f"id == {res['id']}")
            # Save to Milvus
            image_id = self.collection.upsert(data).primary_keys
            image_id = self.collection.insert(data).primary_keys
            data = {
                "id": str(image_id[0]),
                "video_point_id": res['video_point_id'],
@@ -154,7 +140,7 @@
            # Call the RAG service
            asyncio.run(self.insert_json_data(ragurl, data))
            rag_time = datetime.now() - current_time
            self.logger.info(f"{image_id}运行结束总体用时:{datetime.now() - ks_time},图片描述用时{desc_time},RAG用时{rag_time}")
            self.logger.info(f"{res['video_point_id']}执行完毕:{image_id}运行结束总体用时:{datetime.now() - ks_time},图片描述用时{desc_time},RAG用时{rag_time}")
        except Exception as e:
            self.logger.info(f"线程:执行模型解析时出错::{e}")
            return 0
@@ -187,7 +173,7 @@
            )
            inputs = inputs.to(model.device)
            with torch.inference_mode():
                outputs = model.generate(**inputs,max_new_tokens=100)
                outputs = model.generate(**inputs,max_new_tokens=200)
            generated_ids = outputs[:, len(inputs.input_ids[0]):]
            image_text = self.processor.batch_decode(
                generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True