""" Sub-query Decomposition RAG:將複雜問題拆解成子問題後檢索 """ from typing import List, Dict, Optional from .retrievers.reranker import RAGPipeline from .prompt_formatter import PromptFormatter from .llm_integration import OllamaLLM import hashlib import time import logging from concurrent.futures import ThreadPoolExecutor, as_completed logger = logging.getLogger(__name__) class SubQueryDecompositionRAG: """使用子問題拆解的 RAG 系統""" def __init__( self, rag_pipeline: RAGPipeline, llm: OllamaLLM, max_sub_queries: int = 3, top_k_per_subquery: int = 5, enable_parallel: bool = True ): """ 初始化 Sub-query Decomposition RAG Args: rag_pipeline: 現有的 RAG 管線實例 llm: LLM 實例(用於生成子問題) max_sub_queries: 最多生成的子問題數量 top_k_per_subquery: 每個子問題檢索的結果數量 enable_parallel: 是否並行處理子查詢 """ self.rag_pipeline = rag_pipeline self.llm = llm self.max_sub_queries = max_sub_queries self.top_k_per_subquery = top_k_per_subquery self.enable_parallel = enable_parallel def _generate_sub_queries(self, question: str) -> List[str]: """ 將原始問題拆解成子問題 Args: question: 原始問題 Returns: 子問題列表 """ # 檢測語言 is_chinese = PromptFormatter.detect_language(question) == "zh" if is_chinese: prompt = f"""你是一個專業助理。請將以下原始問題拆解成最多 {self.max_sub_queries} 個具體的子問題,以便進行資料搜尋。 每個子問題應專注於原始問題的一個特定面向。請以換行符號分隔問題。 原始問題: {question} 子問題清單:""" else: prompt = f"""You are a professional assistant. Please decompose the following original question into at most {self.max_sub_queries} specific sub-questions for information retrieval. Each sub-question should focus on a specific aspect of the original question. Please separate questions with newlines. Original question: {question} Sub-question list:""" try: response = self.llm.generate( prompt=prompt, temperature=0.3, # 降低溫度以獲得更穩定的結果 max_tokens=500 ) # 解析子問題 sub_queries = [ q.strip() for q in response.strip().split("\n") if q.strip() and not q.strip().startswith("#") ] # 移除編號前綴(如 "1. ", "1) " 等) cleaned_queries = [] for q in sub_queries: # 移除開頭的編號 q = q.lstrip("0123456789. )") q = q.strip() if q: cleaned_queries.append(q) # 限制數量 cleaned_queries = cleaned_queries[:self.max_sub_queries] # 如果沒有生成子問題,使用原始問題 if not cleaned_queries: logger.warning("⚠️ 未生成子問題,使用原始問題") cleaned_queries = [question] return cleaned_queries except Exception as e: logger.error(f"⚠️ 生成子問題時出錯: {e}") # 回退到原始問題 return [question] def _get_doc_id(self, doc: Dict) -> str: """ 生成文檔的唯一標識符 Args: doc: 文檔字典 Returns: 唯一 ID """ metadata = doc.get("metadata", {}) content = doc.get("content", "") # 使用 metadata 中的唯一標識(如果有的話) if "arxiv_id" in metadata and "chunk_index" in metadata: return f"{metadata['arxiv_id']}_{metadata['chunk_index']}" elif "file_path" in metadata and "chunk_index" in metadata: return f"{metadata['file_path']}_{metadata['chunk_index']}" else: # 回退到內容的 hash content_hash = hashlib.md5(content.encode()).hexdigest()[:16] return f"doc_{content_hash}" def _retrieve_for_subquery( self, sub_query: str, metadata_filter: Optional[Dict] = None ) -> List[Dict]: """ 針對單個子問題進行檢索 Args: sub_query: 子問題 metadata_filter: 可選的 metadata 過濾條件 Returns: 檢索結果列表 """ try: results = self.rag_pipeline.query( text=sub_query, top_k=self.top_k_per_subquery, metadata_filter=metadata_filter, enable_rerank=True ) return results except Exception as e: logger.error(f"⚠️ 檢索子問題 '{sub_query}' 時出錯: {e}") return [] def _get_unique_documents( self, sub_queries: List[str], metadata_filter: Optional[Dict] = None ) -> List[Dict]: """ 針對所有子問題進行檢索,並移除重複的檔案 Args: sub_queries: 子問題列表 metadata_filter: 可選的 metadata 過濾條件 Returns: 去重後的文檔列表 """ unique_docs = {} if self.enable_parallel and len(sub_queries) > 1: # 並行處理子查詢 logger.info(f"🔄 並行處理 {len(sub_queries)} 個子查詢...") with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor: future_to_query = { executor.submit(self._retrieve_for_subquery, q, metadata_filter): q for q in sub_queries } for future in as_completed(future_to_query): sub_query = future_to_query[future] try: docs = future.result() logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(docs)} 個結果") for doc in docs: doc_id = self._get_doc_id(doc) if doc_id not in unique_docs: unique_docs[doc_id] = doc else: # 如果已存在,保留分數更高的 existing_score = unique_docs[doc_id].get( 'rerank_score', unique_docs[doc_id].get('hybrid_score', 0) ) new_score = doc.get( 'rerank_score', doc.get('hybrid_score', 0) ) if new_score > existing_score: unique_docs[doc_id] = doc except Exception as e: logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}") else: # 串行處理 logger.info(f"🔄 串行處理 {len(sub_queries)} 個子查詢...") for sub_query in sub_queries: docs = self._retrieve_for_subquery(sub_query, metadata_filter) logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(docs)} 個結果") for doc in docs: doc_id = self._get_doc_id(doc) if doc_id not in unique_docs: unique_docs[doc_id] = doc else: # 保留分數更高的 existing_score = unique_docs[doc_id].get( 'rerank_score', unique_docs[doc_id].get('hybrid_score', 0) ) new_score = doc.get( 'rerank_score', doc.get('hybrid_score', 0) ) if new_score > existing_score: unique_docs[doc_id] = doc # 按分數排序 result_list = list(unique_docs.values()) result_list.sort( key=lambda x: x.get('rerank_score', x.get('hybrid_score', 0)), reverse=True ) return result_list def query( self, question: str, top_k: int = 5, metadata_filter: Optional[Dict] = None, return_sub_queries: bool = False ) -> Dict: """ 執行 Sub-query Decomposition RAG 查詢 Args: question: 原始問題 top_k: 返回前 k 個結果 metadata_filter: 可選的 metadata 過濾條件 return_sub_queries: 是否在結果中包含子問題列表 Returns: 包含檢索結果和統計資訊的字典 """ start_time = time.time() # 第一步:產生子問題 logger.info(f"🔍 拆解問題: '{question}'") sub_queries = self._generate_sub_queries(question) logger.info(f"✅ 生成 {len(sub_queries)} 個子問題:") for i, sq in enumerate(sub_queries, 1): logger.info(f" {i}. {sq}") # 第二步:檢索並去重 logger.info(f"📚 檢索相關文檔...") docs = self._get_unique_documents(sub_queries, metadata_filter) logger.info(f"✅ 找到 {len(docs)} 個唯一文檔(去重後)") # 第三步:返回前 top_k 個結果 final_results = docs[:top_k] elapsed_time = time.time() - start_time result = { "results": final_results, "total_docs_found": len(docs), "sub_queries": sub_queries if return_sub_queries else None, "elapsed_time": elapsed_time } return result def generate_answer( self, question: str, formatter: PromptFormatter, top_k: int = 5, metadata_filter: Optional[Dict] = None, document_type: str = "general", return_sub_queries: bool = False ) -> Dict: """ 完整的 Sub-query Decomposition RAG 流程:檢索 + 生成答案 Args: question: 原始問題 formatter: Prompt 格式化器 top_k: 返回前 k 個結果用於生成答案 metadata_filter: 可選的 metadata 過濾條件 document_type: 文檔類型 ("paper", "cv", "general") return_sub_queries: 是否在結果中包含子問題列表 Returns: 包含檢索結果、生成的答案和統計資訊的字典 """ # 檢索 retrieval_result = self.query( question=question, top_k=top_k, metadata_filter=metadata_filter, return_sub_queries=return_sub_queries ) if not retrieval_result["results"]: return { **retrieval_result, "answer": "抱歉,未找到相關文檔來回答此問題。", "formatted_context": None } # 格式化上下文 formatted_context = formatter.format_context( retrieval_result["results"], document_type=document_type ) # 創建 prompt prompt = formatter.create_prompt( question, formatted_context, document_type=document_type ) # 生成回答 logger.info("🤖 生成回答中...") answer_start = time.time() try: answer = self.llm.generate( prompt=prompt, temperature=0.7, max_tokens=2048 ) answer_time = time.time() - answer_start logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)") except Exception as e: logger.error(f"❌ 生成回答時出錯: {e}") answer = f"生成回答時出錯: {e}" answer_time = time.time() - answer_start return { **retrieval_result, "answer": answer, "formatted_context": formatted_context, "answer_time": answer_time, "total_time": retrieval_result["elapsed_time"] + answer_time }