Commit 979763a
Parent(s): ee8f3fd
combine src of advanced RAG
Changed files:
- deep_agent_rag/rag/private_file_rag.py +26 -161
- src/__init__.py +37 -0
- src/document_processor.py +590 -0
- src/hybrid_subquery_hyde_rag.py +399 -0
- src/hyde_rag.py +235 -0
- src/llm_integration.py +246 -0
- src/prompt_formatter.py +395 -0
- src/retrievers/__init__.py +17 -0
- src/retrievers/base.py +32 -0
- src/retrievers/bm25_retriever.py +127 -0
- src/retrievers/hybrid_search.py +298 -0
- src/retrievers/reranker.py +448 -0
- src/retrievers/vector_retriever.py +254 -0
- src/step_back_rag.py +305 -0
- src/subquery_rag.py +361 -0
- src/triple_hybrid_rag.py +467 -0
deep_agent_rag/rag/private_file_rag.py  CHANGED

@@ -24,34 +24,26 @@ from langchain_core.messages import HumanMessage
 from .llm_adapter import LangChainLLMAdapter
 from .adaptive_rag_selector import AdaptiveRAGSelector, RAGMethod
 
-# Add Learn_RAG to the Python path
-# Compute the Learn_RAG path (it sits in the same directory as Deep_Agentic_AI_Tool)
+# Add the project root to the Python path (so the src modules can be imported)
+# Walk up from deep_agent_rag/rag/private_file_rag.py to the Deep_Agentic_AI_Tool root
 current_file = Path(__file__).resolve()
 # Walk up from deep_agent_rag/rag/private_file_rag.py to the Deep_Agentic_AI_Tool root
+# private_file_rag.py -> rag/ -> deep_agent_rag/ -> Deep_Agentic_AI_Tool/
 deep_agent_root = current_file.parent.parent.parent.parent
-learn_rag_path = deep_agent_root.parent / "Learn_RAG"
-
-# If Learn_RAG is not in the expected location, try other likely locations
-if not learn_rag_path.exists():
-    # Try the parent of the current working directory
-    cwd = Path.cwd()
-    learn_rag_path = cwd.parent / "Learn_RAG"
-
-    if not learn_rag_path.exists():
-        # Fall back to a hard-coded absolute path
-        learn_rag_path = Path("/Users/matthuang/Desktop/Learn_RAG")
 
-# Add the Learn_RAG directory itself to the Python path (the src package lives in Learn_RAG/src/)
-if learn_rag_path.exists() and learn_rag_path.is_dir():
-    if str(learn_rag_path) not in sys.path:
-        sys.path.insert(0, str(learn_rag_path))
-    print(f"✓ Found Learn_RAG project: {learn_rag_path}")
-    print(f"  Python path added: {learn_rag_path}")
+# Check that the src directory exists (it should live under the project root)
+src_path = deep_agent_root / "src"
+if src_path.exists() and src_path.is_dir():
+    # Add the project root (not the src directory itself) to the Python path,
+    # so that `from src.xxx import xxx` imports resolve
+    if str(deep_agent_root) not in sys.path:
+        sys.path.insert(0, str(deep_agent_root))
+    print(f"✓ Found local src modules: {src_path}")
+    print(f"  Project root added to Python path: {deep_agent_root}")
 else:
-    print(f"⚠️ Could not find the Learn_RAG project")
-    print(f"  Path tried: {learn_rag_path}")
-    print(f"  Please make sure the Learn_RAG project is at: {deep_agent_root.parent / 'Learn_RAG'}")
+    print(f"⚠️ Could not find the src directory")
+    print(f"  Expected path: {src_path}")
+    print(f"  Project root: {deep_agent_root}")
 
 # Try to import the Learn_RAG modules
 # Note: document_processor.py imports arxiv at the top level, so the dependencies must be installed first

@@ -78,13 +70,13 @@ try:
 
     if missing_deps:
         print(f"⚠️ Missing the following dependency packages: {', '.join(missing_deps)}")
-        print(f"\n💡 Please install the Learn_RAG project dependencies:")
+        print(f"\n💡 Please install the dependencies required by the RAG system:")
         print(f"  Option 1: use pip")
         print(f"  pip install {' '.join(missing_deps)}")
-        print(f"\n  Option 2: use uv (recommended if Learn_RAG uses uv)")
-        print(f"  cd {learn_rag_path}")
+        print(f"\n  Option 2: use uv (recommended)")
+        print(f"  cd {deep_agent_root}")
         print(f"  uv sync")
-        print(f"\n  Option 3: install all Learn_RAG dependencies")
+        print(f"\n  Option 3: install all dependencies")
         print(f"  pip install arxiv langchain-community langchain-text-splitters chromadb sentence-transformers rank-bm25 pypdf docx2txt langchain-experimental")
         LEARN_RAG_AVAILABLE = False
     else:

@@ -104,22 +96,23 @@ try:
         # OllamaLLM is no longer imported, because we use the unified Deep_Agentic_AI_Tool LLM system (get_llm())
         # from src.llm_integration import OllamaLLM
         LEARN_RAG_AVAILABLE = True
-        print("✓ Successfully imported the Learn_RAG modules (including the advanced RAG methods)")
+        print("✓ Successfully imported the RAG modules (locally integrated version, including the advanced RAG methods)")
 
 except ImportError as e:
     error_msg = str(e)
-    print(f"⚠️ Could not import the Learn_RAG modules: {error_msg}")
-    print(f"\n💡 Please install the Learn_RAG project dependencies:")
+    print(f"⚠️ Could not import the RAG modules: {error_msg}")
+    print(f"\n💡 Please install the dependencies required by the RAG system:")
     print(f"  pip install arxiv langchain-community langchain-text-splitters chromadb sentence-transformers rank-bm25 pypdf docx2txt langchain-experimental")
     print(f"\n  Or:")
-    print(f"  cd {learn_rag_path}")
+    print(f"  cd {deep_agent_root}")
     print(f"  uv sync")
     LEARN_RAG_AVAILABLE = False
 except Exception as e:
     error_msg = str(e)
-    print(f"⚠️ Error while importing the Learn_RAG modules: {error_msg}")
+    print(f"⚠️ Error while importing the RAG modules: {error_msg}")
     print(f"  Current Python path: {sys.path[:3]}")
-    print(f"  Learn_RAG path: {learn_rag_path}")
+    print(f"  Project root: {deep_agent_root}")
+    print(f"  src directory: {src_path}")
     LEARN_RAG_AVAILABLE = False

@@ -1236,134 +1229,6 @@ def reset_private_rag_instance():
     """Reset the global instance"""
     global _private_rag_instance
     _private_rag_instance = None
-
-
-"""
-Private-file RAG system
-Integrates the Learn_RAG features: upload private files (PDF, DOCX, TXT) and answer questions with RAG
-
-LLM selection strategy:
-- Prefer the Groq API (if an API key is configured)
-- Then Ollama (if the service is running)
-- Finally the local MLX model (as a fallback)
-"""
-import os
-import sys
-import time
-from pathlib import Path
-from typing import Optional, Dict, List, Tuple
-import tempfile
-import shutil
-
-# Import the Deep_Agentic_AI_Tool LLM utilities
-# This reuses the unified LLM priority strategy (Groq -> Ollama -> MLX)
-from ..utils.llm_utils import get_llm
-from langchain_core.messages import HumanMessage
-
-# Import the LLM adapter and the adaptive selector
-from .llm_adapter import LangChainLLMAdapter
-from .adaptive_rag_selector import AdaptiveRAGSelector, RAGMethod
-
-# Add Learn_RAG to the Python path
-# Compute the Learn_RAG path (it sits in the same directory as Deep_Agentic_AI_Tool)
-current_file = Path(__file__).resolve()
-# Walk up from deep_agent_rag/rag/private_file_rag.py to the Deep_Agentic_AI_Tool root
-deep_agent_root = current_file.parent.parent.parent.parent
-learn_rag_path = deep_agent_root.parent / "Learn_RAG"
-
-# If Learn_RAG is not in the expected location, try other likely locations
-if not learn_rag_path.exists():
-    # Try the parent of the current working directory
-    cwd = Path.cwd()
-    learn_rag_path = cwd.parent / "Learn_RAG"
-
-    if not learn_rag_path.exists():
-        # Fall back to a hard-coded absolute path
-        learn_rag_path = Path("/Users/matthuang/Desktop/Learn_RAG")
-
-# Add the Learn_RAG directory to the Python path (so the src package can be imported)
-# Note: the Learn_RAG directory itself must be on the path, because the src package lives in Learn_RAG/src/
-if learn_rag_path.exists() and learn_rag_path.is_dir():
-    if str(learn_rag_path) not in sys.path:
-        sys.path.insert(0, str(learn_rag_path))
-    print(f"✓ Found Learn_RAG project: {learn_rag_path}")
-    print(f"  Python path added: {learn_rag_path}")
-else:
-    print(f"⚠️ Could not find the Learn_RAG project")
-    print(f"  Path tried: {learn_rag_path}")
-    print(f"  Please make sure the Learn_RAG project is at: {deep_agent_root.parent / 'Learn_RAG'}")
-
-# Try to import the Learn_RAG modules
-# Note: document_processor.py imports arxiv at the top level, so the dependencies must be installed first
-try:
-    # First check whether the required dependencies are installed
-    import importlib
-
-    required_deps = {
-        "arxiv": "arxiv",
-        "langchain_community": "langchain-community",
-        "langchain_text_splitters": "langchain-text-splitters",
-        "chromadb": "chromadb",
-        "sentence_transformers": "sentence-transformers",
-        "rank_bm25": "rank-bm25",
-        "pypdf": "pypdf",
-    }
-
-    missing_deps = []
-    for module_name, package_name in required_deps.items():
-        try:
-            importlib.import_module(module_name)
-        except ImportError:
-            missing_deps.append(package_name)
-
-    if missing_deps:
-        print(f"⚠️ Missing the following dependency packages: {', '.join(missing_deps)}")
-        print(f"\n💡 Please install the Learn_RAG project dependencies:")
-        print(f"  Option 1: use pip")
-        print(f"  pip install {' '.join(missing_deps)}")
-        print(f"\n  Option 2: use uv (recommended if Learn_RAG uses uv)")
-        print(f"  cd {learn_rag_path}")
-        print(f"  uv sync")
-        print(f"\n  Option 3: install all Learn_RAG dependencies")
-        print(f"  pip install arxiv langchain-community langchain-text-splitters chromadb sentence-transformers rank-bm25 pypdf docx2txt langchain-experimental")
-        LEARN_RAG_AVAILABLE = False
-    else:
-        # All dependencies are installed; try to import the modules
-        from src.document_processor import DocumentProcessor
-        from src.retrievers.bm25_retriever import BM25Retriever
-        from src.retrievers.vector_retriever import VectorRetriever
-        from src.retrievers.hybrid_search import HybridSearch
-        from src.retrievers.reranker import Reranker, RAGPipeline
-        from src.prompt_formatter import PromptFormatter
-        # Import the advanced RAG methods
-        from src.subquery_rag import SubQueryDecompositionRAG
-        from src.hyde_rag import HyDERAG
-        from src.step_back_rag import StepBackRAG
-        from src.hybrid_subquery_hyde_rag import HybridSubqueryHyDERAG
-        from src.triple_hybrid_rag import TripleHybridRAG
-        # OllamaLLM is no longer imported, because we use the unified Deep_Agentic_AI_Tool LLM system (get_llm())
-        # from src.llm_integration import OllamaLLM
-        LEARN_RAG_AVAILABLE = True
-        print("✓ Successfully imported the Learn_RAG modules (including the advanced RAG methods)")
-
-except ImportError as e:
-    error_msg = str(e)
-    print(f"⚠️ Could not import the Learn_RAG modules: {error_msg}")
-    print(f"\n💡 Please install the Learn_RAG project dependencies:")
-    print(f"  pip install arxiv langchain-community langchain-text-splitters chromadb sentence-transformers rank-bm25 pypdf docx2txt langchain-experimental")
-    print(f"\n  Or:")
-    print(f"  cd {learn_rag_path}")
-    print(f"  uv sync")
-    LEARN_RAG_AVAILABLE = False
-except Exception as e:
-    error_msg = str(e)
-    print(f"⚠️ Error while importing the Learn_RAG modules: {error_msg}")
-    print(f"  Current Python path: {sys.path[:3]}")
-    print(f"  Learn_RAG path: {learn_rag_path}")
-    LEARN_RAG_AVAILABLE = False
-
-
 class PrivateFileRAG:
     """
     Private-file RAG system manager
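For reference, a minimal sketch of what the new bootstrap achieves (the layout and file placement below are assumed for illustration, not taken from the commit): once the project root is on sys.path, `from src....` imports resolve against the vendored copy instead of an external Learn_RAG checkout.

# Sketch, assuming this script sits at the project root next to the vendored src/ package.
import sys
from pathlib import Path

project_root = Path(__file__).resolve().parent  # hypothetical placement
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from src.document_processor import DocumentProcessor  # now imports the local copy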
src/__init__.py  ADDED

@@ -0,0 +1,37 @@
"""
RAG system module package
"""
from .document_processor import DocumentProcessor
from .retrievers import (
    BaseRetriever,
    BM25Retriever,
    VectorRetriever,
    HybridSearch,
    Reranker,
    RAGPipeline,
)
from .prompt_formatter import PromptFormatter
from .llm_integration import OllamaLLM
from .subquery_rag import SubQueryDecompositionRAG
from .hyde_rag import HyDERAG
from .hybrid_subquery_hyde_rag import HybridSubqueryHyDERAG
from .step_back_rag import StepBackRAG
from .triple_hybrid_rag import TripleHybridRAG

__all__ = [
    "DocumentProcessor",
    "BaseRetriever",
    "BM25Retriever",
    "VectorRetriever",
    "HybridSearch",
    "Reranker",
    "RAGPipeline",
    "PromptFormatter",
    "OllamaLLM",
    "SubQueryDecompositionRAG",
    "HyDERAG",
    "HybridSubqueryHyDERAG",
    "StepBackRAG",
    "TripleHybridRAG",
]
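Since __init__.py re-exports everything in __all__, downstream code can import from the package root; a sketch (assumes the project root is on sys.path, as the private_file_rag.py bootstrap arranges):

# Sketch: one-stop imports via the package root.
from src import DocumentProcessor, HyDERAG, RAGPipeline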
src/document_processor.py  ADDED

@@ -0,0 +1,590 @@
"""
Document processing module: load arXiv papers and split their text
Supports local files: PDF, DOCX, TXT
Supports two chunking strategies:
1. Character chunking (default): fixed-size chunks by character count, fast
2. Semantic chunking (optional): chunks by semantic similarity, preserves semantic coherence
"""
from typing import List, Dict, Optional, Any
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pathlib import Path
import os
import arxiv
import re

# Try to import the semantic chunker (requires langchain-experimental)
try:
    from langchain_experimental.text_splitter import SemanticChunker
    SEMANTIC_CHUNKER_AVAILABLE = True
except ImportError:
    SEMANTIC_CHUNKER_AVAILABLE = False


class DocumentProcessor:
    """
    Processes arXiv paper documents: splitting and preparation

    Supports two chunking modes:
    - Character chunking (default): fast and stable, suitable for most scenarios
    - Semantic chunking (optional): smarter and preserves semantic coherence, but needs extra dependencies and compute time
    """

    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        embeddings: Optional[Any] = None,  # optional: embedding model for semantic chunking
        use_semantic_chunking: bool = False,  # whether to use semantic chunking
        breakpoint_threshold_amount: float = 1.5,  # semantic chunking sensitivity (standard-deviation multiple)
        min_chunk_size: int = 100  # minimum chunk size for semantic chunking (characters)
    ):
        """
        Initialize the document processor

        Args:
            chunk_size: size of each chunk (characters), character-chunking mode only
            chunk_overlap: overlap between chunks (characters), character-chunking mode only
            embeddings: embedding model object used to compute semantic distance (optional)
                Required when use_semantic_chunking=True
            use_semantic_chunking: whether to use semantic chunking
                True: semantic chunking (embeddings required)
                False: character chunking (default)
            breakpoint_threshold_amount: sensitivity of semantic chunking
                Larger values produce fewer (larger) chunks
                Smaller values produce more (smaller) chunks
                Suggested range: 1.0 - 2.0, default 1.5
            min_chunk_size: minimum chunk size for semantic chunking (characters)
                Chunks smaller than this are merged into neighboring chunks
                Default 100 characters
        """
        self.embeddings = embeddings
        self.use_semantic_chunking = use_semantic_chunking
        self.min_chunk_size = min_chunk_size

        # If semantic chunking was requested
        if use_semantic_chunking:
            # Check that the required dependency is installed
            if not SEMANTIC_CHUNKER_AVAILABLE:
                raise ImportError(
                    "Semantic chunking requires the langchain-experimental package.\n"
                    "Run: pip install langchain-experimental\n"
                    "Or use character chunking (use_semantic_chunking=False)"
                )

            # Check that embeddings were provided
            if embeddings is None:
                raise ValueError(
                    "The embeddings argument is required for semantic chunking.\n"
                    "Example:\n"
                    "  from langchain_community.embeddings import HuggingFaceEmbeddings\n"
                    "  embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')\n"
                    "  processor = DocumentProcessor(embeddings=embeddings, use_semantic_chunking=True)"
                )

            # Initialize the semantic chunker
            # Uses the "standard deviation" strategy: split when the semantic distance between adjacent
            # sentences exceeds the mean distance by the given number of standard deviations
            self.text_splitter = SemanticChunker(
                embeddings,
                breakpoint_threshold_type="standard_deviation",
                breakpoint_threshold_amount=breakpoint_threshold_amount
            )
            print(f"✓ Semantic chunking mode (sensitivity: {breakpoint_threshold_amount}, min chunk size: {min_chunk_size} characters)")
        else:
            # Traditional character chunking (default mode)
            self.text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
            )
            print(f"✓ Character chunking mode (size: {chunk_size} characters, overlap: {chunk_overlap} characters)")

    def _post_process_chunks(self, chunks: List[str]) -> List[str]:
        """
        Post-process chunks: filter and merge chunks that are too small

        Semantic chunking can produce some very small chunks (for example only a few words)
        that may not carry enough context. This method:
        1. Merges chunks smaller than min_chunk_size into neighboring chunks
        2. Ensures the final chunks all have a reasonable size

        Args:
            chunks: raw chunk list (as produced by the splitter)

        Returns:
            Processed chunk list (after filtering and merging)
        """
        # Character chunking needs no post-processing (chunks already have a fixed size)
        if not self.use_semantic_chunking:
            return chunks

        # Nothing to do for an empty list
        if not chunks:
            return chunks

        processed = []
        current_small_chunk = ""  # accumulated small chunk

        for chunk in chunks:
            chunk_stripped = chunk.strip()
            chunk_length = len(chunk_stripped)

            # If the current chunk is too small, try to merge it with the next one
            if chunk_length < self.min_chunk_size:
                # Accumulate into the temporary variable
                if current_small_chunk:
                    current_small_chunk += "\n\n" + chunk
                else:
                    current_small_chunk = chunk
            else:
                # The current chunk is large enough
                # If there is an accumulated small chunk, handle it first
                if current_small_chunk:
                    current_small_chunk_stripped = current_small_chunk.strip()
                    if len(current_small_chunk_stripped) >= self.min_chunk_size:
                        # Large enough after accumulation: keep as a standalone chunk
                        processed.append(current_small_chunk)
                    else:
                        # Still too small: merge into the previous chunk (if any)
                        if processed:
                            processed[-1] += "\n\n" + current_small_chunk
                        else:
                            # No previous chunk: keep it anyway
                            processed.append(current_small_chunk)
                    current_small_chunk = ""

                # Append the current, sufficiently large chunk
                processed.append(chunk)

        # Handle the final accumulated small chunk
        if current_small_chunk:
            current_small_chunk_stripped = current_small_chunk.strip()
            if len(current_small_chunk_stripped) >= self.min_chunk_size:
                # Large enough: keep as a standalone chunk
                processed.append(current_small_chunk)
            elif processed:
                # Too small: merge into the last chunk
                processed[-1] += "\n\n" + current_small_chunk
            else:
                # No other chunks: keep it anyway
                processed.append(current_small_chunk)

        return processed

    def fetch_papers(self, query: str, max_results: int = 10) -> List[Dict]:
        """
        Fetch papers from arXiv

        Args:
            query: search query (e.g. "cat:cs.AI")
            max_results: maximum number of results

        Returns:
            A list of papers, each with title, abstract, and other metadata
        """
        search = arxiv.Search(
            query=query,
            max_results=max_results,
            sort_by=arxiv.SortCriterion.SubmittedDate
        )

        papers = []
        for paper in search.results():
            papers.append({
                "title": paper.title,
                "authors": [author.name for author in paper.authors],
                "summary": paper.summary,
                "published": str(paper.published),
                "arxiv_id": paper.entry_id.split('/')[-1],
                "arxiv_url": paper.entry_id,
                "pdf_url": paper.pdf_url,
                "categories": paper.categories,
            })

        return papers

    def process_documents(self, papers: List[Dict]) -> List[Dict]:
        """
        Process papers, splitting each one into chunks

        Args:
            papers: list of papers

        Returns:
            Processed document chunks, each with content and metadata
        """
        documents = []

        for paper in papers:
            # Combine the paper's full text (title + abstract)
            # Keep the \n\n newlines as structural hints for semantic breakpoints
            full_text = f"Title: {paper['title']}\n\nAbstract: {paper['summary']}"

            # Split the text (character or semantic mode, per configuration)
            chunks = self.text_splitter.split_text(full_text)

            # Post-process: filter and merge overly small chunks (semantic mode only)
            chunks = self._post_process_chunks(chunks)

            # Create a document object per chunk
            for i, chunk in enumerate(chunks):
                doc = {
                    "content": chunk,
                    "metadata": {
                        "title": paper['title'],
                        "arxiv_id": paper['arxiv_id'],
                        "arxiv_url": paper['arxiv_url'],
                        "pdf_url": paper['pdf_url'],
                        "authors": paper['authors'],
                        "published": paper['published'],
                        "categories": paper['categories'],
                        "chunk_index": i,
                        "total_chunks": len(chunks),
                        "chunking_method": "semantic" if self.use_semantic_chunking else "character"
                    }
                }
                documents.append(doc)

        return documents

    def get_texts_and_metadatas(self, documents: List[Dict]):
        """
        Extract the texts and metadata from a document list

        Args:
            documents: document list

        Returns:
            (texts, metadatas) tuple
        """
        texts = [doc["content"] for doc in documents]
        metadatas = [doc["metadata"] for doc in documents]
        return texts, metadatas

    @staticmethod
    def clean_extracted_text(text: str) -> str:
        """
        Clean text extracted from PDF/DOCX: remove excess whitespace and repair per-character line breaks

        Some PDF extraction tools insert spaces or newlines between individual characters,
        especially in Chinese text. This method:
        1. Repairs the "one character per line" problem (merging single-character lines)
        2. Removes excess spaces between Chinese characters
        3. Keeps spaces between English words
        4. Keeps appropriate spaces around punctuation
        5. Keeps genuine paragraph separation

        Args:
            text: raw extracted text

        Returns:
            Cleaned text
        """
        if not text:
            return text

        # Step 0: repair the "one character per line" problem
        # Detect the pattern where each line holds a single character (a Chinese character,
        # punctuation mark, or single letter/digit) and merge such lines into continuous text
        lines = text.split('\n')
        merged_lines = []
        i = 0

        def is_single_char_line(line: str) -> bool:
            """
            Decide whether a line is a single-character line
            Heuristic: stripped length <= 3 (possibly one character plus punctuation or spaces)
            """
            stripped = line.strip()
            if not stripped:
                return False  # blank lines do not count
            # Stripped length <= 3 and essentially a single Chinese character, punctuation mark, or letter/digit
            if len(stripped) <= 3:
                # After removing all spaces, a length <= 2 counts as a single-character line
                no_space = stripped.replace(' ', '')
                if len(no_space) <= 2:
                    return True
            return False

        while i < len(lines):
            line = lines[i]
            stripped_line = line.strip()

            # The current line is a single-character line
            if is_single_char_line(line):
                # Collect consecutive single-character lines (including blank lines, which may just be separators)
                merged_chars = []
                j = i
                consecutive_single_chars = 0

                while j < len(lines):
                    current_line = lines[j]
                    current_stripped = current_line.strip()

                    if is_single_char_line(current_line):
                        # Single-character line: collect the character (spaces removed)
                        char = current_stripped.replace(' ', '')
                        if char:
                            merged_chars.append(char)
                        consecutive_single_chars += 1
                        j += 1
                    elif not current_stripped:
                        # Blank line: if single characters came before and may still follow, skip it
                        # Check whether the next line is also a single-character line
                        if j + 1 < len(lines) and is_single_char_line(lines[j + 1]):
                            # More single characters after the blank line: skip it and keep collecting
                            j += 1
                        else:
                            # No more single characters after the blank line: stop collecting
                            break
                    else:
                        # Normal line: stop collecting
                        break

                # If several single characters were collected, merge them
                if len(merged_chars) > 1:
                    merged_text = ''.join(merged_chars)
                    merged_lines.append(merged_text)
                    i = j
                    continue
                elif len(merged_chars) == 1 and consecutive_single_chars > 1:
                    # Only one character but several lines (possibly caused by spaces): merge as well
                    merged_text = ''.join(merged_chars)
                    merged_lines.append(merged_text)
                    i = j
                    continue
                else:
                    # Just one single character on one line: keep it as-is
                    if merged_chars:
                        merged_lines.append(merged_chars[0])
                    i = j
                    continue
            else:
                # Normal line: append directly
                if stripped_line:  # non-blank line
                    merged_lines.append(stripped_line)
                i += 1

        # Reassemble the text
        text = '\n'.join(merged_lines)

        # Step 0.5: handle possible leftovers
        # If single-character lines remain (missed by the first pass), process them again
        lines = text.split('\n')
        final_lines = []
        i = 0
        while i < len(lines):
            line = lines[i].strip()
            if is_single_char_line(line):
                # Collect consecutive single-character lines once more
                merged_chars = []
                j = i
                while j < len(lines) and is_single_char_line(lines[j]):
                    char = lines[j].strip().replace(' ', '')
                    if char:
                        merged_chars.append(char)
                    j += 1

                if len(merged_chars) > 1:
                    final_lines.append(''.join(merged_chars))
                    i = j
                else:
                    if merged_chars:
                        final_lines.append(merged_chars[0])
                    i = j
            else:
                if line:
                    final_lines.append(line)
                i += 1

        text = '\n'.join(final_lines)

        # 1. Remove spaces between Chinese characters
        # Pattern: Chinese character + spaces + Chinese character
        chinese_char_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])'
        text = re.sub(chinese_char_pattern, r'\1\2', text)

        # 2. Remove excess spaces between Chinese text and punctuation
        # Chinese + spaces + punctuation
        chinese_punct_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([,。、;:!?""''()【】《》])'
        text = re.sub(chinese_punct_pattern, r'\1\2', text)

        # Punctuation + spaces + Chinese
        # Use re.escape to handle the punctuation correctly and avoid escape-sequence warnings
        punct_chars = ',。、;:!?""''()【】《》'
        punct_chinese_pattern = f'([{re.escape(punct_chars)}])\\s+([\\u4e00-\\u9fff\\u3400-\\u4dbf\\uf900-\\ufaff])'
        text = re.sub(punct_chinese_pattern, r'\1\2', text)

        # 3. Remove excess spaces between digits and Chinese (e.g. "500 公里" -> "500公里")
        number_chinese_pattern = r'(\d+)\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])'
        text = re.sub(number_chinese_pattern, r'\1\2', text)
        chinese_number_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+(\d+)'
        text = re.sub(chinese_number_pattern, r'\1\2', text)

        # 4. Remove stray spaces inside mixed text (e.g. "Nebula-X 跨次 元量" -> "Nebula-X 跨次元量")
        # while keeping the spaces between English words
        # Match: Chinese character + spaces + Chinese character (spaces removed only between Chinese pairs)
        mixed_space_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])'
        text = re.sub(mixed_space_pattern, r'\1\2', text)

        # 5. Collapse runs of spaces (keep single spaces, used between English words)
        text = re.sub(r' +', ' ', text)

        # 6. Trim leading/trailing spaces per line (keeping the newlines)
        lines = text.split('\n')
        cleaned_lines = [line.strip() for line in lines]
        text = '\n'.join(cleaned_lines)

        # 7. Collapse runs of newlines (keep at most two, for paragraph separation)
        text = re.sub(r'\n{3,}', '\n\n', text)

        # 8. Fix possible leftovers: remove any remaining spaces between Chinese characters
        # (a second pass over cases the earlier substitutions may have missed)
        text = re.sub(r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])', r'\1\2', text)

        return text

    def load_from_file(self, file_path: str) -> Dict:
        """
        Load a document from a local file (supports PDF, DOCX, TXT, etc.)

        Args:
            file_path: path to the file

        Returns:
            Document dict with content and metadata
        """
        file_path = Path(file_path)

        if not file_path.exists():
            raise FileNotFoundError(f"File does not exist: {file_path}")

        file_ext = file_path.suffix.lower()
        file_name = file_path.stem
        file_size = os.path.getsize(file_path)

        # Pick a loader based on the file type
        if file_ext == '.pdf':
            try:
                from langchain_community.document_loaders import PyPDFLoader
                loader = PyPDFLoader(str(file_path))
                pages = loader.load()
                # Merge all pages
                full_text = "\n\n".join([page.page_content for page in pages])
                # Clean the extracted text (remove excess whitespace)
                full_text = self.clean_extracted_text(full_text)
            except ImportError:
                raise ImportError(
                    "pypdf is required to handle PDF files: pip install pypdf"
                )

        elif file_ext in ['.docx', '.doc']:
            try:
                from langchain_community.document_loaders import Docx2txtLoader
                loader = Docx2txtLoader(str(file_path))
                pages = loader.load()
                full_text = "\n\n".join([page.page_content for page in pages])
                # Clean the extracted text (remove excess whitespace)
                full_text = self.clean_extracted_text(full_text)
            except ImportError:
                raise ImportError(
                    "docx2txt is required to handle DOCX files: pip install docx2txt"
                )

        elif file_ext == '.txt':
            # Try several encodings
            encodings = ['utf-8', 'gbk', 'big5', 'latin-1']
            full_text = None
            for encoding in encodings:
                try:
                    with open(file_path, 'r', encoding=encoding) as f:
                        full_text = f.read()
                    break
                except UnicodeDecodeError:
                    continue

            if full_text is None:
                raise ValueError(f"Could not read the file; none of the attempted encodings worked: {encodings}")

        else:
            raise ValueError(
                f"Unsupported file type: {file_ext}\n"
                f"Supported formats: .pdf, .docx, .doc, .txt"
            )

        if not full_text or len(full_text.strip()) == 0:
            raise ValueError(f"File is empty or no text could be extracted: {file_path}")

        return {
            "title": file_name,
            "content": full_text,
            "file_path": str(file_path),
            "file_type": file_ext,
            "file_size": file_size,
        }

    def process_file(self, file_path: str) -> List[Dict]:
        """
        Process a single file, splitting it into chunks

        Args:
            file_path: path to the file

        Returns:
            List of processed document chunks
        """
        # Load the file
        file_doc = self.load_from_file(file_path)

        # Split the text (character or semantic mode, per configuration)
        chunks = self.text_splitter.split_text(file_doc["content"])

        # Post-process: filter and merge overly small chunks (semantic mode only)
        chunks = self._post_process_chunks(chunks)

        if not chunks:
            raise ValueError(f"No content after splitting the file: {file_path}")

        # Create the document chunks
        documents = []
        for i, chunk in enumerate(chunks):
            doc = {
                "content": chunk,
                "metadata": {
                    "title": file_doc["title"],
                    "file_path": file_doc["file_path"],
                    "file_type": file_doc["file_type"],
                    "file_size": file_doc["file_size"],
                    "chunk_index": i,
                    "total_chunks": len(chunks),
                    "chunking_method": "semantic" if self.use_semantic_chunking else "character"
                }
            }
            documents.append(doc)

        return documents

    def process_files(self, file_paths: List[str]) -> List[Dict]:
        """
        Process several files

        Args:
            file_paths: list of file paths

        Returns:
            Document chunks from all files
        """
        all_documents = []
        for file_path in file_paths:
            try:
                print(f"Processing file: {file_path}")
                documents = self.process_file(file_path)
                all_documents.extend(documents)
                print(f"  ✓ Created {len(documents)} chunks")
            except Exception as e:
                print(f"  ✗ Failed to process file: {file_path}")
                print(f"    Error: {e}")
                continue

        return all_documents
src/hybrid_subquery_hyde_rag.py  ADDED

@@ -0,0 +1,399 @@
"""
Hybrid Sub-query + HyDE RAG: fuses Sub-query Decomposition with HyDE
Combines the strengths of both methods to improve retrieval precision
"""
from typing import List, Dict, Optional
from .retrievers.reranker import RAGPipeline
from .retrievers.vector_retriever import VectorRetriever
from .prompt_formatter import PromptFormatter
from .llm_integration import OllamaLLM
import hashlib
import time
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

logger = logging.getLogger(__name__)


class HybridSubqueryHyDERAG:
    """RAG system that fuses Sub-query Decomposition and HyDE"""

    def __init__(
        self,
        rag_pipeline: RAGPipeline,
        vector_retriever: VectorRetriever,
        llm: OllamaLLM,
        max_sub_queries: int = 3,
        top_k_per_subquery: int = 5,
        hypothetical_length: int = 200,
        temperature_subquery: float = 0.3,
        temperature_hyde: float = 0.7,
        enable_parallel: bool = True
    ):
        """
        Initialize the hybrid RAG

        Args:
            rag_pipeline: RAG pipeline instance
            vector_retriever: vector retriever
            llm: LLM instance
            max_sub_queries: maximum number of sub-questions to generate
            top_k_per_subquery: number of results retrieved per sub-question
            hypothetical_length: target length of the hypothetical document (characters)
            temperature_subquery: temperature for sub-question generation (lower, more stable)
            temperature_hyde: temperature for hypothetical-document generation (higher, richer terminology)
            enable_parallel: whether to process sub-questions in parallel
        """
        self.rag_pipeline = rag_pipeline
        self.vector_retriever = vector_retriever
        self.llm = llm
        self.max_sub_queries = max_sub_queries
        self.top_k_per_subquery = top_k_per_subquery
        self.hypothetical_length = hypothetical_length
        self.temperature_subquery = temperature_subquery
        self.temperature_hyde = temperature_hyde
        self.enable_parallel = enable_parallel

    def _generate_sub_queries(self, question: str) -> List[str]:
        """
        Generate sub-questions (same as SubQueryDecompositionRAG)

        Args:
            question: original question

        Returns:
            List of sub-questions
        """
        is_chinese = PromptFormatter.detect_language(question) == "zh"

        if is_chinese:
            prompt = f"""你是一個專業助理。請將以下原始問題拆解成最多 {self.max_sub_queries} 個具體的子問題,以便進行資料搜尋。
每個子問題應專注於原始問題的一個特定面向。請以換行符號分隔問題。

原始問題: {question}

子問題清單:"""
        else:
            prompt = f"""You are a professional assistant. Please decompose the following original question into at most {self.max_sub_queries} specific sub-questions for information retrieval.
Each sub-question should focus on a specific aspect of the original question. Please separate questions with newlines.

Original question: {question}

Sub-question list:"""

        try:
            response = self.llm.generate(
                prompt=prompt,
                temperature=self.temperature_subquery,
                max_tokens=500
            )

            sub_queries = [
                q.strip()
                for q in response.strip().split("\n")
                if q.strip() and not q.strip().startswith("#")
            ]

            # Strip numbering prefixes (such as "1. ", "1) ", etc.)
            cleaned_queries = []
            for q in sub_queries:
                q = q.lstrip("0123456789. )")
                q = q.strip()
                if q:
                    cleaned_queries.append(q)

            cleaned_queries = cleaned_queries[:self.max_sub_queries]

            if not cleaned_queries:
                logger.warning("⚠️ No sub-questions generated; using the original question")
                cleaned_queries = [question]

            return cleaned_queries

        except Exception as e:
            logger.error(f"⚠️ Error while generating sub-questions: {e}")
            return [question]

    def _generate_hypothetical_document(self, sub_query: str) -> str:
        """
        Generate a hypothetical document for a sub-question (same as HyDERAG)

        Args:
            sub_query: sub-question

        Returns:
            Hypothetical document text
        """
        is_chinese = PromptFormatter.detect_language(sub_query) == "zh"

        if is_chinese:
            prompt = f"""請針對以下問題,寫出一段約 {self.hypothetical_length} 字的專業技術檔案內容。
這段內容應包含該領域常見的專業術語與原理說明,以便用於後續的語義檢索。
請使用專業的術語和概念,即使你對某些細節不確定,也要包含相關的專業詞彙。

問題: {sub_query}

專業技術內容:"""
        else:
            prompt = f"""Please write a professional technical document of approximately {self.hypothetical_length} words in response to the following question.
This content should include common professional terminology and principle explanations in this field, to be used for subsequent semantic retrieval.
Please use professional terms and concepts, and include relevant professional vocabulary even if you are uncertain about some details.

Question: {sub_query}

Professional technical content:"""

        try:
            hypothetical_doc = self.llm.generate(
                prompt=prompt,
                temperature=self.temperature_hyde,
                max_tokens=500
            )

            hypothetical_doc = hypothetical_doc.strip()

            if not hypothetical_doc:
                logger.warning(f"⚠️ Hypothetical document for sub-question '{sub_query}' is empty; using the sub-question itself")
                return sub_query

            logger.debug(f"✅ Generated hypothetical document for sub-question (length: {len(hypothetical_doc)} characters)")
            return hypothetical_doc

        except Exception as e:
            logger.error(f"⚠️ Error while generating the hypothetical document: {e}")
            return sub_query

    def _get_doc_id(self, doc: Dict) -> str:
        """
        Build a unique identifier for a document

        Args:
            doc: document dict

        Returns:
            Unique ID
        """
        metadata = doc.get("metadata", {})
        content = doc.get("content", "")

        if "arxiv_id" in metadata and "chunk_index" in metadata:
            return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
        elif "file_path" in metadata and "chunk_index" in metadata:
            return f"{metadata['file_path']}_{metadata['chunk_index']}"
        else:
            content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
            return f"doc_{content_hash}"

    def _process_subquery_with_hyde(
        self,
        sub_query: str,
        metadata_filter: Optional[Dict] = None
    ) -> tuple:
        """
        Process a single sub-question: generate the hypothetical document and retrieve

        Args:
            sub_query: sub-question
            metadata_filter: optional metadata filter

        Returns:
            (retrieval results, hypothetical document)
        """
        try:
            # Generate the hypothetical document
            hypothetical_doc = self._generate_hypothetical_document(sub_query)

            # Retrieve using the hypothetical document
            results = self.vector_retriever.retrieve(
                query=hypothetical_doc,  # use the hypothetical document rather than the sub-question
                top_k=self.top_k_per_subquery,
                metadata_filter=metadata_filter
            )

            return results, hypothetical_doc

        except Exception as e:
            logger.error(f"⚠️ Error while processing sub-question '{sub_query}': {e}")
            return [], ""

    def query(
        self,
        question: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        return_sub_queries: bool = False,
        return_hypothetical: bool = False
    ) -> Dict:
        """
        Run the hybrid RAG retrieval (without answer generation)

        Args:
            question: original question
            top_k: number of results to return
            metadata_filter: optional metadata filter
            return_sub_queries: whether to return the sub-question list
            return_hypothetical: whether to return the hypothetical-document dict (sub-question -> document)

        Returns:
            Dict with retrieval results and statistics
        """
        start_time = time.time()

        # Step 1: generate the sub-questions
        logger.info(f"🔍 Decomposing question: '{question}'")
        sub_queries = self._generate_sub_queries(question)
        logger.info(f"✅ Generated {len(sub_queries)} sub-questions")

        # Step 2: generate a hypothetical document per sub-question and retrieve
        logger.info(f"📚 Generating hypothetical documents and retrieving per sub-question...")
        unique_docs = {}
        hypothetical_docs = {}

        if self.enable_parallel and len(sub_queries) > 1:
            # Parallel processing
            logger.info(f"🔄 Processing {len(sub_queries)} sub-questions in parallel...")
            with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor:
                future_to_query = {
                    executor.submit(self._process_subquery_with_hyde, sq, metadata_filter): sq
                    for sq in sub_queries
                }

                for future in as_completed(future_to_query):
                    sub_query = future_to_query[future]
                    try:
                        results, hypo_doc = future.result()
                        hypothetical_docs[sub_query] = hypo_doc

                        logger.debug(f"✅ Sub-question '{sub_query}' found {len(results)} results")

                        for doc in results:
                            doc_id = self._get_doc_id(doc)
                            if doc_id not in unique_docs:
                                unique_docs[doc_id] = doc
                            else:
                                # Keep the higher-scoring copy
                                existing_score = unique_docs[doc_id].get('score', 0)
                                new_score = doc.get('score', 0)
                                if new_score > existing_score:
                                    unique_docs[doc_id] = doc
                    except Exception as e:
                        logger.error(f"⚠️ Error while processing sub-question '{sub_query}': {e}")
        else:
            # Sequential processing
            logger.info(f"🔄 Processing {len(sub_queries)} sub-questions sequentially...")
            for sub_query in sub_queries:
                results, hypo_doc = self._process_subquery_with_hyde(sub_query, metadata_filter)
                hypothetical_docs[sub_query] = hypo_doc

                logger.debug(f"✅ Sub-question '{sub_query}' found {len(results)} results")

                for doc in results:
                    doc_id = self._get_doc_id(doc)
                    if doc_id not in unique_docs:
                        unique_docs[doc_id] = doc
                    else:
                        existing_score = unique_docs[doc_id].get('score', 0)
                        new_score = doc.get('score', 0)
                        if new_score > existing_score:
                            unique_docs[doc_id] = doc

        # Step 3: sort and return the top_k
        result_list = list(unique_docs.values())
        result_list.sort(key=lambda x: x.get('score', 0), reverse=True)
        final_results = result_list[:top_k]

        elapsed_time = time.time() - start_time
        logger.info(f"✅ Found {len(final_results)} unique documents (after dedup, out of {len(result_list)} total)")

        return {
            "results": final_results,
            "total_docs_found": len(result_list),
            "sub_queries": sub_queries if return_sub_queries else None,
            "hypothetical_documents": hypothetical_docs if return_hypothetical else None,
            "elapsed_time": elapsed_time
        }

    def generate_answer(
        self,
        question: str,
        formatter: PromptFormatter,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        document_type: str = "general",
        return_sub_queries: bool = False,
        return_hypothetical: bool = False
    ) -> Dict:
        """
        Full hybrid RAG flow: retrieval + answer generation

        Args:
            question: original question
            formatter: prompt formatter
            top_k: number of documents used for answer generation
            metadata_filter: optional metadata filter
            document_type: document type ("paper", "cv", "general")
            return_sub_queries: whether to return the sub-question list
            return_hypothetical: whether to return the hypothetical-document dict

        Returns:
            Dict with retrieval results, the generated answer, and statistics
        """
        start_time = time.time()

        # Retrieve
        retrieval_result = self.query(
            question=question,
            top_k=top_k,
            metadata_filter=metadata_filter,
            return_sub_queries=return_sub_queries,
            return_hypothetical=return_hypothetical
        )

        if not retrieval_result["results"]:
            return {
                **retrieval_result,
                "answer": "抱歉,未找到相關文檔來回答此問題。",
                "formatted_context": None,
                "answer_time": 0.0,
                "total_time": retrieval_result["elapsed_time"]
            }

        # Format the context
        formatted_context = formatter.format_context(
            retrieval_result["results"],
            document_type=document_type
        )

        # Build the prompt (using the original question)
        prompt = formatter.create_prompt(
            question,
            formatted_context,
            document_type=document_type
        )

        # Generate the answer
        logger.info("🤖 Generating the answer...")
        answer_start = time.time()
        try:
            answer = self.llm.generate(
                prompt=prompt,
                temperature=0.7,
                max_tokens=2048
            )
            answer_time = time.time() - answer_start
            logger.info(f"✅ Answer generated (took {answer_time:.2f}s)")
        except Exception as e:
            logger.error(f"❌ Error while generating the answer: {e}")
            answer = f"生成回答時出錯: {e}"
            answer_time = time.time() - answer_start

        total_time = time.time() - start_time

        return {
            **retrieval_result,
            "answer": answer,
            "formatted_context": formatted_context,
            "answer_time": answer_time,
            "total_time": total_time
        }
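A usage sketch; the rag_pipeline, vector_retriever, llm, and formatter objects are assumed to be constructed from the other modules added in this commit (their constructors are not shown on this page, so the construction details are an assumption):

# Sketch; rag_pipeline, vector_retriever, llm, and formatter are assumed instances
# of RAGPipeline, VectorRetriever, OllamaLLM, and PromptFormatter.
from src.hybrid_subquery_hyde_rag import HybridSubqueryHyDERAG

hybrid = HybridSubqueryHyDERAG(rag_pipeline, vector_retriever, llm, max_sub_queries=3)
out = hybrid.generate_answer("How does HyDE improve retrieval?", formatter=formatter,
                             top_k=5, return_sub_queries=True)
print(out["sub_queries"])
print(out["answer"])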
src/hyde_rag.py
ADDED
@@ -0,0 +1,235 @@
"""
HyDE (Hypothetical Document Embeddings) RAG: improve retrieval with hypothetical documents
"""
from typing import List, Dict, Optional
from .retrievers.reranker import RAGPipeline
from .retrievers.vector_retriever import VectorRetriever
from .prompt_formatter import PromptFormatter
from .llm_integration import OllamaLLM
import time
import logging

logger = logging.getLogger(__name__)


class HyDERAG:
    """RAG system using HyDE (Hypothetical Document Embeddings)"""

    def __init__(
        self,
        rag_pipeline: RAGPipeline,
        vector_retriever: VectorRetriever,
        llm: OllamaLLM,
        hypothetical_length: int = 200,
        temperature: float = 0.7
    ):
        """
        Initialize HyDE RAG

        Args:
            rag_pipeline: RAG pipeline instance (used for final answer generation)
            vector_retriever: vector retriever (used for retrieval based on the hypothetical document)
            llm: LLM instance (used to generate the hypothetical document)
            hypothetical_length: target length of the hypothetical document (characters)
            temperature: temperature for generating the hypothetical document (0.7 recommended, to surface more domain terminology)
        """
        self.rag_pipeline = rag_pipeline
        self.vector_retriever = vector_retriever
        self.llm = llm
        self.hypothetical_length = hypothetical_length
        self.temperature = temperature

    def _generate_hypothetical_document(self, question: str) -> str:
        """
        Generate a hypothetical document

        Args:
            question: user question

        Returns:
            hypothetical document text
        """
        # Detect the language
        is_chinese = PromptFormatter.detect_language(question) == "zh"

        if is_chinese:
            prompt = f"""請針對以下問題,寫出一段約 {self.hypothetical_length} 字的專業技術檔案內容。
這段內容應包含該領域常見的專業術語與原理說明,以便用於後續的語義檢索。
請使用專業的術語和概念,即使你對某些細節不確定,也要包含相關的專業詞彙。

問題: {question}

專業技術內容:"""
        else:
            prompt = f"""Please write a professional technical document of approximately {self.hypothetical_length} words in response to the following question.
This content should include common professional terminology and principle explanations in this field, to be used for subsequent semantic retrieval.
Please use professional terms and concepts, and include relevant professional vocabulary even if you are uncertain about some details.

Question: {question}

Professional technical content:"""

        try:
            hypothetical_doc = self.llm.generate(
                prompt=prompt,
                temperature=self.temperature,  # higher temperature to surface more domain terminology
                max_tokens=500
            )

            # Clean up the output
            hypothetical_doc = hypothetical_doc.strip()

            if not hypothetical_doc:
                logger.warning("⚠️ Generated hypothetical document is empty; falling back to the original question")
                return question

            logger.info(f"✅ Generated hypothetical document (length: {len(hypothetical_doc)} characters)")
            return hypothetical_doc

        except Exception as e:
            logger.error(f"⚠️ Error while generating the hypothetical document: {e}")
            # Fall back to the original question
            return question

    def query(
        self,
        question: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        return_hypothetical: bool = False
    ) -> Dict:
        """
        Run HyDE retrieval (without answer generation)

        Args:
            question: original question
            top_k: return the top k results
            metadata_filter: optional metadata filter conditions
            return_hypothetical: whether to include the hypothetical document in the result

        Returns:
            dict with retrieval results and statistics
        """
        start_time = time.time()

        # Step 1: generate the hypothetical document
        logger.info(f"🔍 Generating hypothetical document: '{question}'")
        hypothetical_doc = self._generate_hypothetical_document(question)

        # Step 2: retrieve using the hypothetical document
        logger.info("📚 Retrieving with the hypothetical document...")
        results = self.vector_retriever.retrieve(
            query=hypothetical_doc,  # use the hypothetical document instead of the original question
            top_k=top_k,
            metadata_filter=metadata_filter
        )

        elapsed_time = time.time() - start_time
        logger.info(f"✅ Found {len(results)} results (elapsed: {elapsed_time:.2f}s)")

        result = {
            "results": results,
            "total_docs_found": len(results),
            "hypothetical_document": hypothetical_doc if return_hypothetical else None,
            "elapsed_time": elapsed_time
        }

        return result

    def generate_answer(
        self,
        question: str,
        formatter: PromptFormatter,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        document_type: str = "general",
        return_hypothetical: bool = False
    ) -> Dict:
        """
        Full HyDE RAG flow: generate hypothetical document -> retrieve -> generate answer

        Args:
            question: original question
            formatter: prompt formatter
            top_k: number of documents used for answer generation
            metadata_filter: optional metadata filter conditions
            document_type: document type ("paper", "cv", "general")
            return_hypothetical: whether to include the hypothetical document in the result

        Returns:
            dict with retrieval results, the generated answer, and statistics
        """
        start_time = time.time()

        # Step 1: generate the hypothetical document
        logger.info(f"🔍 Generating hypothetical document: '{question}'")
        hypothetical_start = time.time()
        hypothetical_doc = self._generate_hypothetical_document(question)
        hypothetical_time = time.time() - hypothetical_start

        # Step 2: retrieve using the hypothetical document
        logger.info("📚 Retrieving with the hypothetical document...")
        retrieval_start = time.time()
        results = self.vector_retriever.retrieve(
            query=hypothetical_doc,  # use the hypothetical document instead of the original question
            top_k=top_k,
            metadata_filter=metadata_filter
        )
        retrieval_time = time.time() - retrieval_start

        if not results:
            return {
                "results": [],
                "total_docs_found": 0,
                "hypothetical_document": hypothetical_doc if return_hypothetical else None,
                "elapsed_time": retrieval_time + hypothetical_time,
                "answer": "Sorry, no relevant documents were found to answer this question.",
                "formatted_context": None,
                "answer_time": 0.0,
                "total_time": retrieval_time + hypothetical_time
            }

        # Step 3: format the context
        formatted_context = formatter.format_context(
            results,
            document_type=document_type
        )

        # Step 4: create the prompt (using the original question, not the hypothetical document)
        prompt = formatter.create_prompt(
            question,  # generate the answer from the original question
            formatted_context,
            document_type=document_type
        )

        # Step 5: generate the answer
        logger.info("🤖 Generating answer...")
        answer_start = time.time()
        try:
            answer = self.llm.generate(
                prompt=prompt,
                temperature=0.7,
                max_tokens=2048
            )
            answer_time = time.time() - answer_start
            logger.info(f"✅ Answer generated (elapsed: {answer_time:.2f}s)")
        except Exception as e:
            logger.error(f"❌ Error while generating the answer: {e}")
            answer = f"Error while generating the answer: {e}"
            answer_time = time.time() - answer_start

        total_time = time.time() - start_time

        return {
            "results": results,
            "total_docs_found": len(results),
            "hypothetical_document": hypothetical_doc if return_hypothetical else None,
            "elapsed_time": retrieval_time + hypothetical_time,
            "hypothetical_time": hypothetical_time,
            "retrieval_time": retrieval_time,
            "answer": answer,
            "formatted_context": formatted_context,
            "answer_time": answer_time,
            "total_time": total_time
        }
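A minimal usage sketch for the HyDE module above (not part of the commit). It assumes `rag_pipeline` and `vector_retriever` have already been built from the retriever modules added below; those two names are placeholders for objects constructed elsewhere:

from src.hyde_rag import HyDERAG
from src.prompt_formatter import PromptFormatter
from src.llm_integration import OllamaLLM

llm = HyDERAG.__module__ and OllamaLLM(model_name="llama3.2:3b")
hyde = HyDERAG(rag_pipeline, vector_retriever, llm)  # retrievers built elsewhere

# Retrieval only: the query embedding comes from a generated hypothetical document
hits = hyde.query("How does LoRA fine-tuning work?", top_k=5, return_hypothetical=True)
print(hits["hypothetical_document"])

# Full flow: hypothetical doc -> retrieve -> answer generated from the *original* question
out = hyde.generate_answer("How does LoRA fine-tuning work?", PromptFormatter())
print(out["answer"])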
src/llm_integration.py
ADDED
@@ -0,0 +1,246 @@
"""
LLM integration module: local LLM inference via Ollama
"""
from typing import Optional, Dict, List
import logging
import requests
import json

logger = logging.getLogger(__name__)


class OllamaLLM:
    """Local LLM inference via Ollama"""

    # Recommended models for a 16GB MacBook Air
    RECOMMENDED_MODELS = {
        "deepseek-r1:7b": {
            "name": "deepseek-r1:7b",
            "description": "DeepSeek R1 7B - large model, high quality",
            "memory_required": "~8GB",
            "quality": "excellent"
        },
        "llama3.2:3b": {
            "name": "llama3.2:3b",
            "description": "Meta Llama 3.2 3B - lightweight, fits 16GB RAM",
            "memory_required": "~4GB",
            "quality": "good"
        },
        "llama3.2:1b": {
            "name": "llama3.2:1b",
            "description": "Meta Llama 3.2 1B - ultra-lightweight, fast responses",
            "memory_required": "~2GB",
            "quality": "basic"
        },
        "phi3:mini": {
            "name": "phi3:mini",
            "description": "Microsoft Phi-3 Mini - small model, high quality",
            "memory_required": "~3GB",
            "quality": "good"
        },
        "gemma:2b": {
            "name": "gemma:2b",
            "description": "Google Gemma 2B - lightweight, open source",
            "memory_required": "~3GB",
            "quality": "good"
        },
        "mistral:7b": {
            "name": "mistral:7b",
            "description": "Mistral 7B - larger but high quality (if memory allows)",
            "memory_required": "~8GB",
            "quality": "excellent"
        }
    }

    def __init__(
        self,
        model_name: str = "llama3.2:3b",
        base_url: str = "http://localhost:11434",
        timeout: int = 120
    ):
        """
        Initialize the Ollama LLM

        Args:
            model_name: Ollama model name (default: llama3.2:3b)
            base_url: base URL of the Ollama API
            timeout: request timeout in seconds
        """
        self.model_name = model_name
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.api_url = f"{self.base_url}/api"

        # Check whether the model is in the recommended list
        if model_name not in self.RECOMMENDED_MODELS:
            logger.warning(
                f"⚠️ Model '{model_name}' is not in the recommended list. "
                f"Recommended models: {', '.join(self.RECOMMENDED_MODELS.keys())}"
            )

        logger.info(f"✅ Ollama LLM initialized (model: {model_name})")

    def _check_ollama_connection(self) -> bool:
        """
        Check whether the Ollama service is reachable

        Returns:
            whether the connection succeeded
        """
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            return response.status_code == 200
        except Exception as e:
            logger.error(f"❌ Unable to connect to Ollama: {e}")
            logger.error("   Please make sure Ollama is running: ollama serve")
            return False

    def _check_model_available(self) -> bool:
        """
        Check whether the model has been downloaded

        Returns:
            whether the model is available
        """
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                models = response.json().get('models', [])
                model_names = [m.get('name', '') for m in models]
                return any(self.model_name in name for name in model_names)
            return False
        except Exception as e:
            logger.error(f"❌ Error while checking the model: {e}")
            return False

    def generate(
        self,
        prompt: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        stream: bool = False
    ) -> str:
        """
        Generate a response

        Args:
            prompt: input prompt
            temperature: temperature (0.0-1.0), controls randomness
            max_tokens: maximum number of tokens to generate (None uses the model default)
            stream: whether to stream the output

        Returns:
            the generated response
        """
        # Check the connection
        if not self._check_ollama_connection():
            raise ConnectionError(
                f"Unable to connect to the Ollama service ({self.base_url})\n"
                f"Please make sure Ollama is running:\n"
                f"  1. Install Ollama: https://ollama.ai\n"
                f"  2. Start the service: ollama serve\n"
                f"  3. Download the model: ollama pull {self.model_name}"
            )

        # Check the model
        if not self._check_model_available():
            logger.warning(
                f"⚠️ Model '{self.model_name}' may not be downloaded. "
                f"Please run: ollama pull {self.model_name}"
            )

        # Prepare the request payload
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": stream,
            "options": {
                "temperature": temperature,
            }
        }

        if max_tokens:
            payload["options"]["num_predict"] = max_tokens

        try:
            # Send the request
            response = requests.post(
                f"{self.api_url}/generate",
                json=payload,
                timeout=self.timeout,
                stream=stream
            )

            if response.status_code != 200:
                error_msg = response.text
                raise RuntimeError(f"Ollama API error: {error_msg}")

            if stream:
                # Streaming mode
                full_response = ""
                for line in response.iter_lines():
                    if line:
                        try:
                            data = json.loads(line)
                            if 'response' in data:
                                chunk = data['response']
                                full_response += chunk
                                print(chunk, end='', flush=True)
                            if data.get('done', False):
                                break
                        except json.JSONDecodeError:
                            continue
                print()  # newline
                return full_response
            else:
                # Non-streaming mode
                data = response.json()
                return data.get('response', '')

        except requests.exceptions.Timeout:
            raise TimeoutError(
                f"Request timed out ({self.timeout}s). "
                f"Try increasing the timeout or using a smaller model."
            )
        except requests.exceptions.ConnectionError:
            raise ConnectionError(
                "Unable to connect to the Ollama service. "
                "Please make sure Ollama is running: ollama serve"
            )
        except Exception as e:
            logger.error(f"❌ Error while generating a response: {e}")
            raise

    def list_available_models(self) -> List[str]:
        """
        List the locally available models

        Returns:
            list of available model names
        """
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                models = response.json().get('models', [])
                return [m.get('name', '') for m in models]
            return []
        except Exception as e:
            logger.error(f"❌ Error while fetching the model list: {e}")
            return []

    @classmethod
    def print_recommended_models(cls):
        """Print the recommended model list"""
        print("\n" + "="*60)
        print("Recommended Ollama models for a 16GB MacBook Air")
        print("="*60)
        print()

        for model_key, info in cls.RECOMMENDED_MODELS.items():
            print(f"📦 {info['name']}")
            print(f"   Description: {info['description']}")
            print(f"   Memory required: {info['memory_required']}")
            print(f"   Quality: {info['quality']}")
            print(f"   Pull command: ollama pull {info['name']}")
            print()
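A quick smoke test for the Ollama wrapper above (not part of the commit), assuming `ollama serve` is running locally on the default port and the model has been pulled:

from src.llm_integration import OllamaLLM

OllamaLLM.print_recommended_models()

llm = OllamaLLM(model_name="llama3.2:3b", timeout=120)
print(llm.list_available_models())

# Non-streaming call; max_tokens maps to Ollama's num_predict option
text = llm.generate("Explain BM25 in one sentence.", temperature=0.2, max_tokens=128)
print(text)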
src/prompt_formatter.py
ADDED
@@ -0,0 +1,395 @@
"""
Prompt formatting module: format retrieval results into LLM-readable context
"""
from typing import List, Dict, Optional
import re


class PromptFormatter:
    """Formats retrieval results for LLM consumption"""

    def __init__(
        self,
        include_metadata: bool = True,
        format_style: str = "detailed",
        max_context_length: Optional[int] = None,
        auto_detect_language: bool = True
    ):
        """
        Initialize the prompt formatter

        Args:
            include_metadata: whether to include source information
            format_style: format style ("detailed", "simple", "minimal")
            max_context_length: maximum context length in characters; None means unlimited
            auto_detect_language: whether to auto-detect the language and answer in the same language
        """
        self.include_metadata = include_metadata
        self.format_style = format_style
        self.max_context_length = max_context_length
        self.auto_detect_language = auto_detect_language

    @staticmethod
    def detect_language(text: str) -> str:
        """
        Detect the dominant language of a text

        Args:
            text: input text

        Returns:
            "zh" for Chinese, "en" for English
        """
        # Check for Chinese characters (CJK Unified Ideographs ranges)
        chinese_pattern = re.compile(r'[\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff]')
        chinese_chars = len(chinese_pattern.findall(text))

        # Compute the ratio of Chinese characters
        total_chars = len([c for c in text if c.isalnum() or c.isspace()])

        if total_chars == 0:
            return "en"  # default to English

        chinese_ratio = chinese_chars / total_chars if total_chars > 0 else 0

        # Treat the text as Chinese if over 20% of characters are Chinese
        if chinese_ratio > 0.2:
            return "zh"
        else:
            return "en"

    def get_system_prompt(self, language: str = "zh", document_type: str = "general") -> str:
        """
        Get the system prompt for a given language and document type

        Args:
            language: language code ("zh" or "en")
            document_type: document type ("paper", "cv", "general")
                "paper": academic paper
                "cv": CV/resume
                "general": generic document (default)

        Returns:
            system prompt string
        """
        if language == "zh":
            if document_type == "paper":
                return (
                    "你是一個專業的 AI 研究助手,專門回答關於機器學習、"
                    "深度學習和自然語言處理的問題。\n\n"
                    "請基於以下提供的學術論文片段來回答用戶的問題。"
                    "每個片段都標註了來源論文的資訊。\n\n"
                    "回答要求:\n"
                    "1. 基於提供的上下文回答問題\n"
                    "2. 如果上下文不足以回答,請明確說明\n"
                    "3. 在回答中引用具體的論文來源(使用 arXiv ID)\n"
                    "4. 如果不同論文有不同觀點,請分別說明\n"
                    "5. 保持回答簡潔、準確、專業\n"
                    "6. **重要:請使用與用戶問題相同的語言回答**\n"
                )
            elif document_type == "cv":
                return (
                    "你是一個專業的 AI 助手,專門幫助分析和介紹簡歷(CV)內容。\n\n"
                    "請基於以下提供的文檔片段來回答用戶的問題。"
                    "這些片段來自一份簡歷或履歷表。\n\n"
                    "回答要求:\n"
                    "1. 基於提供的上下文回答問題\n"
                    "2. 如果上下文不足以回答,請明確說明\n"
                    "3. 在回答中引用具體的文檔內容\n"
                    "4. 保持回答簡潔、準確、專業\n"
                    "5. **重要:請使用與用戶問題相同的語言回答**\n"
                    "6. **請理解:這些片段就是簡歷的內容,請直接基於這些內容回答問題**\n"
                )
            else:  # general
                return (
                    "你是一個專業的 AI 助手。\n\n"
                    "請基於以下提供的文檔片段來回答用戶的問題。"
                    "每個片段都標註了來源資訊。\n\n"
                    "回答要求:\n"
                    "1. 基於提供的上下文回答問題\n"
                    "2. 如果上下文不足以回答,請明確說明\n"
                    "3. 在回答中引用具體的文檔內容\n"
                    "4. 保持回答簡潔、準確、專業\n"
                    "5. **重要:請使用與用戶問題相同的語言回答**\n"
                )
        else:  # English
            if document_type == "paper":
                return (
                    "You are a professional AI research assistant specializing in "
                    "machine learning, deep learning, and natural language processing.\n\n"
                    "Please answer the user's question based on the provided academic paper excerpts. "
                    "Each excerpt is labeled with source paper information.\n\n"
                    "Answer requirements:\n"
                    "1. Answer the question based on the provided context\n"
                    "2. If the context is insufficient, clearly state so\n"
                    "3. Cite specific paper sources in your answer (using arXiv ID)\n"
                    "4. If different papers have different viewpoints, explain them separately\n"
                    "5. Keep answers concise, accurate, and professional\n"
                    "6. **Important: Please answer in the same language as the user's question**\n"
                )
            elif document_type == "cv":
                return (
                    "You are a professional AI assistant specializing in analyzing and introducing CV (Curriculum Vitae) content.\n\n"
                    "Please answer the user's question based on the provided document excerpts. "
                    "These excerpts are from a CV or resume.\n\n"
                    "Answer requirements:\n"
                    "1. Answer the question based on the provided context\n"
                    "2. If the context is insufficient, clearly state so\n"
                    "3. Cite specific document content in your answer\n"
                    "4. Keep answers concise, accurate, and professional\n"
                    "5. **Important: Please answer in the same language as the user's question**\n"
                    "6. **Please understand: These excerpts ARE the CV content. Please answer directly based on this content.**\n"
                )
            else:  # general
                return (
                    "You are a professional AI assistant.\n\n"
                    "Please answer the user's question based on the provided document excerpts. "
                    "Each excerpt is labeled with source information.\n\n"
                    "Answer requirements:\n"
                    "1. Answer the question based on the provided context\n"
                    "2. If the context is insufficient, clearly state so\n"
                    "3. Cite specific document content in your answer\n"
                    "4. Keep answers concise, accurate, and professional\n"
                    "5. **Important: Please answer in the same language as the user's question**\n"
                )

    def format_context(
        self,
        results: List[Dict],
        include_metadata: Optional[bool] = None,
        format_style: Optional[str] = None,
        document_type: str = "general"
    ) -> str:
        """
        Format retrieval results into LLM-readable context

        Args:
            results: list of retrieval results
            include_metadata: whether to include source info (overrides the init parameter)
            format_style: format style (overrides the init parameter)
            document_type: document type ("paper", "cv", "general"), used to adjust the layout

        Returns:
            formatted context string
        """
        if include_metadata is None:
            include_metadata = self.include_metadata
        if format_style is None:
            format_style = self.format_style

        if not results:
            # Pick a language based on the format style
            if format_style == "detailed" or format_style == "simple":
                return "(未找到相關文檔片段)"
            else:
                return "(No relevant excerpts found)"

        formatted_parts = []

        for i, result in enumerate(results, 1):
            content = result.get("content", "")
            metadata = result.get("metadata", {})

            if not include_metadata:
                # No source info; use the content directly
                formatted_parts.append(f"{content}\n")
            elif format_style == "detailed":
                # Detailed format: adjust the displayed fields by document type
                if document_type == "cv":
                    # CV format: show file title and path
                    source_info = (
                        f"[來源 {i}]\n"
                        f"檔案標題: {metadata.get('title', 'N/A')}\n"
                    )
                    if 'file_path' in metadata:
                        source_info += f"檔案路徑: {metadata.get('file_path', 'N/A')}\n"
                    if 'file_type' in metadata:
                        source_info += f"檔案類型: {metadata.get('file_type', 'N/A')}\n"
                elif document_type == "paper":
                    # Paper format: show paper information
                    authors = metadata.get('authors', [])
                    if isinstance(authors, str):
                        authors_str = authors
                    elif isinstance(authors, list):
                        authors_str = ', '.join(authors[:3])  # show at most 3 authors
                        if len(authors) > 3:
                            authors_str += f" 等 {len(authors)} 位作者"
                    else:
                        authors_str = 'N/A'

                    source_info = (
                        f"[來源 {i}]\n"
                        f"論文標題: {metadata.get('title', 'N/A')}\n"
                        f"arXiv ID: {metadata.get('arxiv_id', 'N/A')}\n"
                        f"作者: {authors_str}\n"
                        f"發布日期: {metadata.get('published', 'N/A')}\n"
                    )
                else:
                    # Generic format: show whatever information is available
                    source_info = f"[來源 {i}]\n"
                    if 'title' in metadata:
                        source_info += f"標題: {metadata.get('title', 'N/A')}\n"
                    if 'file_path' in metadata:
                        source_info += f"檔案: {metadata.get('file_path', 'N/A')}\n"
                    if 'arxiv_id' in metadata:
                        source_info += f"arXiv ID: {metadata.get('arxiv_id', 'N/A')}\n"

                # Add the relevance score, if available
                rerank_score = result.get('rerank_score')
                hybrid_score = result.get('hybrid_score')
                if rerank_score is not None:
                    source_info += f"相關性分數: {rerank_score:.4f}\n"
                elif hybrid_score is not None:
                    source_info += f"相關性分數: {hybrid_score:.4f}\n"

                source_info += f"---\n{content}\n"
                formatted_parts.append(source_info)

            elif format_style == "simple":
                # Simple format: key information only
                title = metadata.get('title', 'N/A')
                if document_type == "paper" and 'arxiv_id' in metadata:
                    arxiv_id = metadata.get('arxiv_id', 'N/A')
                    source_info = (
                        f"[來源 {i}: {title} "
                        f"(arXiv:{arxiv_id})]\n"
                        f"{content}\n"
                    )
                elif document_type == "cv" and 'file_path' in metadata:
                    file_path = metadata.get('file_path', 'N/A')
                    source_info = (
                        f"[來源 {i}: {title} "
                        f"({file_path})]\n"
                        f"{content}\n"
                    )
                else:
                    source_info = (
                        f"[來源 {i}: {title}]\n"
                        f"{content}\n"
                    )
                formatted_parts.append(source_info)
            else:  # minimal
                # Minimal format: source tag only
                if document_type == "paper" and 'arxiv_id' in metadata:
                    arxiv_id = metadata.get('arxiv_id', 'N/A')
                    source_info = (
                        f"[arXiv:{arxiv_id}]\n"
                        f"{content}\n"
                    )
                elif 'title' in metadata:
                    title = metadata.get('title', 'N/A')
                    source_info = (
                        f"[{title}]\n"
                        f"{content}\n"
                    )
                else:
                    source_info = (
                        f"[來源 {i}]\n"
                        f"{content}\n"
                    )
                formatted_parts.append(source_info)

        formatted_text = "\n" + "="*60 + "\n".join(formatted_parts)

        # Truncate if a maximum length is set
        if self.max_context_length and len(formatted_text) > self.max_context_length:
            # Truncate from the end, preserving the layout
            formatted_text = formatted_text[:self.max_context_length]
            # Make sure the last source entry is complete
            last_separator = formatted_text.rfind("="*60)
            if last_separator > 0:
                formatted_text = formatted_text[:last_separator] + "\n(內容已截斷...)"

        return formatted_text

    def create_prompt(
        self,
        query: str,
        context: str,
        system_prompt: Optional[str] = None,
        document_type: str = "general"
    ) -> str:
        """
        Create the full LLM prompt

        Args:
            query: user query
            context: formatted context
            system_prompt: optional system prompt (if None, one is chosen automatically by language and document type)
            document_type: document type ("paper", "cv", "general")

        Returns:
            full prompt string
        """
        # Auto-detect the language and pick the matching system prompt
        if system_prompt is None and self.auto_detect_language:
            detected_language = self.detect_language(query)
            system_prompt = self.get_system_prompt(detected_language, document_type)
        elif system_prompt is None:
            # If auto-detection is disabled, default to Chinese
            system_prompt = self.get_system_prompt("zh", document_type)

        # Pick the prompt layout based on the detected language
        detected_language = self.detect_language(query) if self.auto_detect_language else "zh"

        # Pick a closing instruction based on the document type
        if document_type == "paper":
            if detected_language == "zh":
                ending = "## 請基於上述文獻片段回答問題,並在回答中引用具體的論文來源。"
            else:
                ending = "## Please answer the question based on the above document excerpts and cite specific paper sources in your answer."
        else:
            if detected_language == "zh":
                ending = "## 請基於上述文檔片段回答問題,並在回答中引用具體的文檔內容。"
            else:
                ending = "## Please answer the question based on the above document excerpts and cite specific document content in your answer."

        if detected_language == "zh":
            prompt = f"""{system_prompt}

## 相關文檔片段:

{context}

## 用戶問題:

{query}

{ending}"""
        else:  # English
            prompt = f"""{system_prompt}

## Relevant Document Excerpts:

{context}

## User Question:

{query}

{ending}"""

        return prompt

    def format_for_llm(
        self,
        query: str,
        results: List[Dict],
        system_prompt: Optional[str] = None,
        document_type: str = "general"
    ) -> str:
        """
        One-stop method: format retrieval results and create the full prompt

        Args:
            query: user query
            results: list of retrieval results
            system_prompt: optional system prompt
            document_type: document type ("paper", "cv", "general")

        Returns:
            full prompt string
        """
        context = self.format_context(results, document_type=document_type)
        return self.create_prompt(query, context, system_prompt, document_type)
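A short sketch of the formatter's language routing and context assembly (not part of the commit), with a hand-built result list standing in for real retrieval output:

from src.prompt_formatter import PromptFormatter

# Language detection: >20% CJK characters routes to the Chinese system prompt
assert PromptFormatter.detect_language("什麼是向量檢索?") == "zh"
assert PromptFormatter.detect_language("What is vector retrieval?") == "en"

fake_results = [{
    "content": "BM25 ranks documents by term frequency and inverse document frequency.",
    "metadata": {"title": "IR Notes", "file_path": "notes/ir.md"},
    "rerank_score": 0.91,   # picked up as the relevance score in "detailed" style
}]

formatter = PromptFormatter(format_style="detailed", max_context_length=4000)
prompt = formatter.format_for_llm("What is BM25?", fake_results, document_type="general")
print(prompt)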
src/retrievers/__init__.py
ADDED
@@ -0,0 +1,17 @@
"""
Retriever modules
"""
from .base import BaseRetriever
from .bm25_retriever import BM25Retriever
from .vector_retriever import VectorRetriever
from .hybrid_search import HybridSearch
from .reranker import Reranker, RAGPipeline

__all__ = [
    "BaseRetriever",
    "BM25Retriever",
    "VectorRetriever",
    "HybridSearch",
    "Reranker",
    "RAGPipeline",
]
src/retrievers/base.py
ADDED
@@ -0,0 +1,32 @@
"""
Abstract base class for the retriever modules
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Optional


class BaseRetriever(ABC):
    """Abstract base class for retrievers"""

    @abstractmethod
    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None
    ) -> List[Dict]:
        """
        Retrieve relevant documents and return scored results.

        Args:
            query: query text
            top_k: return the top k results
            metadata_filter: optional dict of metadata filter conditions.
                For example: {"arxiv_id": "1234.5678"} or {"title": "Machine Learning"}.
                Multiple conditions are supported; all must hold simultaneously (AND logic).

        Returns:
            List of relevant documents. Each document dict should include a "score" key,
            where a higher score means higher relevance. Results are filtered by metadata_filter.
        """
        pass
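To make the contract concrete, here is a minimal sketch (not part of the commit) of what the interface demands of a concrete retriever; the keyword-overlap scoring is a toy stand-in, not one of this repo's retrievers:

from typing import List, Dict, Optional
from src.retrievers.base import BaseRetriever

class KeywordCountRetriever(BaseRetriever):
    """Toy retriever: scores by raw keyword overlap (illustrative only)."""

    def __init__(self, documents: List[Dict]):
        self.documents = documents

    def retrieve(self, query: str, top_k: int = 5,
                 metadata_filter: Optional[Dict] = None) -> List[Dict]:
        terms = set(query.lower().split())
        scored = []
        for doc in self.documents:
            # AND logic over all metadata filter conditions
            if metadata_filter and any(
                doc["metadata"].get(k) != v for k, v in metadata_filter.items()
            ):
                continue
            score = sum(t in doc["content"].lower() for t in terms)
            scored.append({**doc, "score": float(score)})
        # Higher score means higher relevance, as the base class requires
        return sorted(scored, key=lambda d: d["score"], reverse=True)[:top_k]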
src/retrievers/bm25_retriever.py
ADDED
@@ -0,0 +1,127 @@
"""
BM25 retriever module
"""
from typing import List, Dict, Optional
from rank_bm25 import BM25Okapi
import re
from .base import BaseRetriever


class BM25Retriever(BaseRetriever):
    """Keyword retrieval using the BM25 algorithm"""

    def __init__(self, documents: List[Dict]):
        """
        Initialize the BM25 retriever

        Args:
            documents: list of documents, each containing "content" and "metadata"
        """
        self.documents = documents
        self.texts = [doc["content"] for doc in documents]

        # Tokenize the texts (simple word splitting)
        tokenized_texts = [self._tokenize(text) for text in self.texts]

        # Initialize BM25
        self.bm25 = BM25Okapi(tokenized_texts)

    def _tokenize(self, text: str) -> List[str]:
        """
        Convert text into tokens (simple implementation)

        Args:
            text: input text

        Returns:
            list of tokens
        """
        # Lowercase and split
        text = text.lower()
        # Split with a regular expression (keep letters and digits)
        tokens = re.findall(r'\b\w+\b', text)
        return tokens

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None
    ) -> List[Dict]:
        """
        Retrieve relevant documents, with optional metadata filtering.

        Args:
            query: query text
            top_k: return the top k results
            metadata_filter: optional dict of metadata filter conditions.
                For example: {"arxiv_id": "1234.5678"} retrieves only chunks of a specific paper,
                and {"title": "Machine Learning"} retrieves only papers with that title.
                Multiple conditions are supported; all must hold simultaneously (AND logic).
                Note: BM25 filtering happens after retrieval, so fewer than top_k results may come back.

        Returns:
            List of relevant documents, each containing "content", "metadata", "score".
            Results are filtered by metadata_filter.
        """
        # Tokenize the query
        tokenized_query = self._tokenize(query)

        # Compute BM25 scores
        scores = self.bm25.get_scores(tokenized_query)

        # Fetch and sort candidates (fetch extra results to compensate for post-filtering losses)
        # Without a filter only top_k are needed; with a filter more candidates are required
        candidate_k = top_k * 3 if metadata_filter else top_k

        # Candidate indices, sorted by score in descending order
        sorted_indices = sorted(
            range(len(scores)),
            key=lambda i: scores[i],
            reverse=True
        )[:candidate_k]

        # Build the candidate results
        candidate_results = []
        for idx in sorted_indices:
            candidate_results.append({
                "content": self.documents[idx]["content"],
                "metadata": self.documents[idx]["metadata"],
                "score": float(scores[idx]),
            })

        # Apply metadata_filter if provided
        if metadata_filter:
            filtered_results = []
            for result in candidate_results:
                # Check whether this result's metadata satisfies every filter condition
                metadata = result.get("metadata", {})
                matches_all = True

                for filter_key, filter_value in metadata_filter.items():
                    # Look up the corresponding metadata value on the document
                    doc_value = metadata.get(filter_key)

                    # Check for a match.
                    # Strings support exact and substring matching (when filter_value and doc_value are both strings)
                    if isinstance(filter_value, str) and isinstance(doc_value, str):
                        # String matching: exact or containment
                        if filter_value.lower() not in doc_value.lower():
                            matches_all = False
                            break
                    else:
                        # Other types (numbers, booleans, etc.) use exact matching
                        if doc_value != filter_value:
                            matches_all = False
                            break

                # Keep the result if all conditions hold
                if matches_all:
                    filtered_results.append(result)

            # Return the filtered results (at most top_k)
            return filtered_results[:top_k]
        else:
            # No filter; return the candidates directly
            return candidate_results
src/retrievers/hybrid_search.py
ADDED
@@ -0,0 +1,298 @@
"""
Hybrid search module: combines BM25 and vector retrieval
Supports two fusion methods: weighted sum and Reciprocal Rank Fusion (RRF)
"""
from typing import List, Dict, Optional, Literal
from .base import BaseRetriever
import numpy as np


class HybridSearch(BaseRetriever):
    """Hybrid search combining sparse and dense retrieval"""

    def __init__(
        self,
        sparse_retriever: BaseRetriever,
        dense_retriever: BaseRetriever,
        sparse_weight: float = 0.4,
        dense_weight: float = 0.6,
        fusion_method: Literal["weighted_sum", "rrf"] = "rrf",
        rrf_k: int = 60,
    ):
        """
        Initialize hybrid search

        Args:
            sparse_retriever: sparse retriever (e.g. BM25)
            dense_retriever: dense retriever (e.g. vector retrieval)
            sparse_weight: weight of the sparse retrieval score (weighted_sum only)
            dense_weight: weight of the dense retrieval score (weighted_sum only)
            fusion_method: fusion method, either "weighted_sum" or "rrf"
                - "weighted_sum": weighted sum; requires score normalization and weights
                - "rrf": Reciprocal Rank Fusion;
                  needs no score normalization and is more robust to differing score distributions
            rrf_k: the constant k in RRF, usually 60 (rrf only).
                A larger k gives lower-ranked documents more weight.
        """
        self.sparse_retriever = sparse_retriever
        self.dense_retriever = dense_retriever
        self.fusion_method = fusion_method
        self.rrf_k = rrf_k

        # Weights are only used by the weighted_sum method
        if fusion_method == "weighted_sum":
            self.sparse_weight = sparse_weight
            self.dense_weight = dense_weight

            # Make sure the weights sum to 1
            total_weight = sparse_weight + dense_weight
            if abs(total_weight - 1.0) > 1e-6:
                self.sparse_weight = sparse_weight / total_weight
                self.dense_weight = dense_weight / total_weight

    def _normalize_scores(self, results: List[Dict]) -> List[Dict]:
        """
        Normalize scores into the [0, 1] range.
        Used only by the weighted_sum method.

        Args:
            results: list of retrieval results, each dict containing 'score'

        Returns:
            list of results with normalized scores
        """
        scores = [res.get("score", 0.0) for res in results]
        if not scores:
            return results

        scores_array = np.array(scores)
        min_score = scores_array.min()
        max_score = scores_array.max()

        if max_score == min_score:
            # If all scores are identical, set them all to 1.0
            normalized_scores = [1.0] * len(scores)
        else:
            normalized_scores = ((scores_array - min_score) / (max_score - min_score)).tolist()

        for i, res in enumerate(results):
            res["score"] = normalized_scores[i]

        return results

    def _get_doc_id(self, doc: Dict) -> str:
        """
        Extract a unique identifier from a document

        Args:
            doc: document dict

        Returns:
            unique document ID
        """
        metadata = doc.get("metadata", {})
        return f"{metadata.get('arxiv_id', 'unknown')}_{metadata.get('chunk_index', 0)}"

    def _apply_rrf(
        self,
        sparse_results: List[Dict],
        dense_results: List[Dict]
    ) -> List[Dict]:
        """
        Apply Reciprocal Rank Fusion (RRF)

        RRF formula: RRF(d) = Σ(1 / (k + rank_i(d)))
        where:
        - d is the document
        - rank_i(d) is the document's rank in the i-th result list (starting at 1)
        - k is a constant (default 60)

        Advantages of RRF:
        1. No score normalization needed; more robust to retrievers with differing score distributions
        2. Depends only on rank positions, not on score values
        3. Automatically handles score-distribution mismatch

        Args:
            sparse_results: list of sparse retrieval results
            dense_results: list of dense retrieval results

        Returns:
            fused result list, sorted by RRF score
        """
        # Map document IDs to RRF scores
        doc_to_rrf_score = {}

        # Process the sparse retrieval results (BM25)
        for rank, result in enumerate(sparse_results, start=1):
            doc_id = self._get_doc_id(result)
            if doc_id not in doc_to_rrf_score:
                doc_to_rrf_score[doc_id] = {
                    "doc": result,
                    "rrf_score": 0.0,
                    "sparse_rank": None,
                    "dense_rank": None
                }
            # RRF contribution: 1 / (k + rank)
            doc_to_rrf_score[doc_id]["rrf_score"] += 1.0 / (self.rrf_k + rank)
            doc_to_rrf_score[doc_id]["sparse_rank"] = rank

        # Process the dense retrieval results (vectors)
        for rank, result in enumerate(dense_results, start=1):
            doc_id = self._get_doc_id(result)
            if doc_id not in doc_to_rrf_score:
                doc_to_rrf_score[doc_id] = {
                    "doc": result,
                    "rrf_score": 0.0,
                    "sparse_rank": None,
                    "dense_rank": None
                }
            # RRF contribution: 1 / (k + rank)
            doc_to_rrf_score[doc_id]["rrf_score"] += 1.0 / (self.rrf_k + rank)
            doc_to_rrf_score[doc_id]["dense_rank"] = rank

        # Build the result list
        rrf_results = []
        for doc_id, data in doc_to_rrf_score.items():
            result = data["doc"].copy()
            result["hybrid_score"] = data["rrf_score"]
            result["rrf_score"] = data["rrf_score"]
            result["sparse_rank"] = data["sparse_rank"]
            result["dense_rank"] = data["dense_rank"]

            # Pull the original scores from the raw results for reference
            if data["sparse_rank"] is not None:
                # Original score from the sparse retrieval results
                for sparse_res in sparse_results:
                    if self._get_doc_id(sparse_res) == doc_id:
                        result["sparse_score"] = sparse_res.get("score", 0.0)
                        break
            else:
                result["sparse_score"] = None

            if data["dense_rank"] is not None:
                # Original score from the dense retrieval results
                for dense_res in dense_results:
                    if self._get_doc_id(dense_res) == doc_id:
                        result["dense_score"] = dense_res.get("score", 0.0)
                        break
            else:
                result["dense_score"] = None

            rrf_results.append(result)

        # Sort by RRF score, descending
        rrf_results.sort(key=lambda x: x["rrf_score"], reverse=True)

        return rrf_results

    def _apply_weighted_sum(
        self,
        sparse_results: List[Dict],
        dense_results: List[Dict]
    ) -> List[Dict]:
        """
        Apply the weighted-sum fusion method

        This method requires:
        1. Normalizing both score sets into the same range
        2. Taking a weighted sum according to the configured weights

        Args:
            sparse_results: list of sparse retrieval results
            dense_results: list of dense retrieval results

        Returns:
            fused result list, sorted by hybrid score
        """
        # Normalize both score sets
        normalized_sparse = self._normalize_scores(sparse_results)
        normalized_dense = self._normalize_scores(dense_results)

        # Combine the scores
        doc_to_scores = {}

        # Process the sparse retrieval results
        for res in normalized_sparse:
            doc_id = self._get_doc_id(res)
            if doc_id not in doc_to_scores:
                doc_to_scores[doc_id] = {"doc": res, "sparse": 0.0, "dense": 0.0}
            doc_to_scores[doc_id]["sparse"] = res["score"]

        # Process the dense retrieval results
        for res in normalized_dense:
            doc_id = self._get_doc_id(res)
            if doc_id not in doc_to_scores:
                doc_to_scores[doc_id] = {"doc": res, "sparse": 0.0, "dense": 0.0}
            doc_to_scores[doc_id]["dense"] = res["score"]

        # Compute the hybrid scores and sort
        hybrid_results = []
        for doc_id, scores in doc_to_scores.items():
            hybrid_score = (
                self.sparse_weight * scores["sparse"] +
                self.dense_weight * scores["dense"]
            )

            result = scores["doc"].copy()
            result["hybrid_score"] = hybrid_score
            result["sparse_score"] = scores["sparse"]
            result["dense_score"] = scores["dense"]
            hybrid_results.append(result)

        # Sort by hybrid score, descending
        hybrid_results.sort(key=lambda x: x["hybrid_score"], reverse=True)

        return hybrid_results

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None
    ) -> List[Dict]:
        """
        Run the hybrid search, with optional metadata filtering

        Args:
            query: query text
            top_k: return the top k results
            metadata_filter: optional dict of metadata filter conditions.
                For example: {"arxiv_id": "1234.5678"} retrieves only chunks of a specific paper,
                and {"title": "Machine Learning"} retrieves only papers with that title.
                Multiple conditions are supported; all must hold simultaneously (AND logic).
                The filter is passed down to the underlying sparse and dense retrievers.

        Returns:
            List of relevant documents, each containing "content", "metadata", "hybrid_score".
            Results are filtered by metadata_filter.

            Depending on fusion_method, results carry different score fields:
            - RRF method: "rrf_score", "sparse_rank", "dense_rank"
            - Weighted-sum method: "sparse_score", "dense_score"
        """
        # 1. Fetch results from both retrievers (request extra results to ensure coverage)
        # Pass metadata_filter down to the underlying retrievers
        sparse_results = self.sparse_retriever.retrieve(
            query,
            top_k=top_k * 2,
            metadata_filter=metadata_filter
        )
        dense_results = self.dense_retriever.retrieve(
            query,
            top_k=top_k * 2,
            metadata_filter=metadata_filter
        )

        # 2. Fuse the results with the chosen fusion method
        if self.fusion_method == "rrf":
            # Use RRF (Reciprocal Rank Fusion)
            # RRF needs no score normalization; it fuses directly on ranks
            hybrid_results = self._apply_rrf(sparse_results, dense_results)
        else:
            # Use the weighted-sum method
            # Scores must be normalized first, then combined by weight
            hybrid_results = self._apply_weighted_sum(sparse_results, dense_results)

        # 3. Return the top_k results
        return hybrid_results[:top_k]
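To make the RRF docstring concrete, the fusion arithmetic by hand with k = 60: a document ranked 1st by BM25 and 3rd by the vector retriever scores 1/61 + 1/63 ≈ 0.0323, beating a document ranked 2nd in only one list (1/62 ≈ 0.0161). A standalone sketch of the same computation (not part of the commit):

# Standalone RRF arithmetic, mirroring _apply_rrf above (k = 60).
k = 60
sparse_ranking = ["doc_a", "doc_b", "doc_c"]   # BM25 order
dense_ranking  = ["doc_c", "doc_a", "doc_d"]   # vector order

rrf = {}
for ranking in (sparse_ranking, dense_ranking):
    for rank, doc_id in enumerate(ranking, start=1):
        rrf[doc_id] = rrf.get(doc_id, 0.0) + 1.0 / (k + rank)

# doc_a (ranks 1 and 2) ends up above doc_c (ranks 3 and 1)
for doc_id, score in sorted(rrf.items(), key=lambda x: x[1], reverse=True):
    print(f"{doc_id}: {score:.4f}")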
src/retrievers/reranker.py
ADDED
@@ -0,0 +1,448 @@
"""
Reranking module: precise reranking with a Cross-Encoder
"""
from typing import List, Dict, Optional, Tuple
from sentence_transformers import CrossEncoder
import time
import logging

# Try to import torch to detect the available device
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_device() -> str:
    """
    Auto-detect and return the best available device

    Returns:
        device name: 'mps' (macOS GPU), 'cuda' (NVIDIA GPU), or 'cpu'
    """
    if not TORCH_AVAILABLE:
        return 'cpu'

    # Priority: MPS (macOS) > CUDA (NVIDIA) > CPU
    if torch.backends.mps.is_available():
        return 'mps'
    elif torch.cuda.is_available():
        return 'cuda'
    else:
        return 'cpu'


class Reranker:
    """Reranking component: precise reranking with a Cross-Encoder"""

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-base",
        device: str = None,
        max_length: int = 512,
        batch_size: int = 32,
        enable_cache: bool = True
    ):
        """
        Initialize the Cross-Encoder model

        Args:
            model_name: Cross-Encoder model name
            device: device name ('cuda', 'cpu', 'mps')
            max_length: maximum token length (model limit)
            batch_size: batch size, to optimize memory usage
            enable_cache: whether to enable model caching
        """
        try:
            # Auto-detect the device (if not specified)
            if device is None:
                device = get_device()

            device_name_map = {
                'mps': 'MPS (macOS GPU)',
                'cuda': 'CUDA (NVIDIA GPU)',
                'cpu': 'CPU'
            }
            device_display = device_name_map.get(device, device)

            self.model = CrossEncoder(
                model_name,
                device=device,
                max_length=max_length
            )
            self.max_length = max_length
            self.batch_size = batch_size
            self.model_name = model_name
            logger.info(f"✅ Reranking model {model_name} loaded (device: {device_display})")
        except Exception as e:
            logger.error(f"❌ Failed to load the model: {e}")
            raise

    def _truncate_text(self, text: str, max_chars: int = 2000) -> str:
        """
        Truncate overly long text (a rough estimate, to stay under the token limit)

        Args:
            text: original text
            max_chars: maximum number of characters (conservative estimate, roughly 500 tokens)

        Returns:
            truncated text
        """
        if len(text) <= max_chars:
            return text
        # Truncate and append an ellipsis
        return text[:max_chars - 3] + "..."

    def _prepare_pairs(
        self,
        query: str,
        documents: List[Dict]
    ) -> List[Tuple[str, str]]:
        """
        Prepare (query, document) pairs, handling text length

        Args:
            query: query text
            documents: list of documents

        Returns:
            list of (query, content) pairs
        """
        pairs = []
        truncated_indices = []  # track which documents were truncated

        # Rough estimate: roughly 0.25 tokens per character, leaving room for the query
        max_doc_chars = int((self.max_length * 0.7) - len(query))

        for i, doc in enumerate(documents):
            content = doc.get("content", "")
            original_length = len(content)

            # Truncate if the content is too long
            if len(content) > max_doc_chars:
                content = self._truncate_text(content, max_doc_chars)
                truncated_indices.append(i)

            pairs.append([query, content])

        if truncated_indices:
            logger.warning(
                f"⚠️ {len(truncated_indices)} documents were truncated for being too long "
                f"(maximum length: {max_doc_chars} characters)"
            )

        return pairs

    def rerank(
        self,
        query: str,
        documents: List[Dict],
        top_k: int = 5,
        preserve_original_scores: bool = True
    ) -> List[Dict]:
        """
        Run the precise reranking step

        Args:
            query: query text
            documents: list of documents, each should contain "content" and optionally "hybrid_score"
            top_k: return the top k results
            preserve_original_scores: whether to keep the original scores (hybrid_score)

        Returns:
            reranked list of documents, sorted by rerank_score in descending order
        """
        if not documents:
            logger.warning("⚠️ Document list is empty; returning an empty result")
            return []

        if not query or not query.strip():
            logger.warning("⚠️ Query is empty; returning documents in the original order")
            return documents[:top_k]

        start_time = time.time()
        logger.info(f"🔄 Reranking {len(documents)} documents...")

        try:
            # 1. Prepare the pairs
            pairs = self._prepare_pairs(query, documents)

            # 2. Compute scores in batches (optimizes memory usage)
            scores = []
            for i in range(0, len(pairs), self.batch_size):
                batch_pairs = pairs[i:i + self.batch_size]
                batch_scores = self.model.predict(batch_pairs)
                scores.extend(batch_scores.tolist() if hasattr(batch_scores, 'tolist') else batch_scores)

            # 3. Update the document scores
            for i, doc in enumerate(documents):
                doc = doc.copy()  # avoid mutating the original document
                doc["rerank_score"] = float(scores[i])
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
# 嘗試導入 torch 來檢測可用的設備
|
| 10 |
+
try:
|
| 11 |
+
import torch
|
| 12 |
+
TORCH_AVAILABLE = True
|
| 13 |
+
except ImportError:
|
| 14 |
+
TORCH_AVAILABLE = False
|
| 15 |
+
|
| 16 |
+
# 配置日志
|
| 17 |
+
logging.basicConfig(level=logging.INFO)
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_device() -> str:
|
| 22 |
+
"""
|
| 23 |
+
自動檢測並返回最佳可用的設備
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
設備名稱: 'mps' (macOS GPU), 'cuda' (NVIDIA GPU), 或 'cpu'
|
| 27 |
+
"""
|
| 28 |
+
if not TORCH_AVAILABLE:
|
| 29 |
+
return 'cpu'
|
| 30 |
+
|
| 31 |
+
# 優先順序: MPS (macOS) > CUDA (NVIDIA) > CPU
|
| 32 |
+
if torch.backends.mps.is_available():
|
| 33 |
+
return 'mps'
|
| 34 |
+
elif torch.cuda.is_available():
|
| 35 |
+
return 'cuda'
|
| 36 |
+
else:
|
| 37 |
+
return 'cpu'
|
| 38 |
+
|
| 39 |
+
|
class Reranker:
    """Reranking component: precise reranking with a Cross-Encoder"""

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-base",
        device: str = None,
        max_length: int = 512,
        batch_size: int = 32,
        enable_cache: bool = True  # accepted but not used by this implementation
    ):
        """
        Initialize the Cross-Encoder model

        Args:
            model_name: Cross-Encoder model name
            device: device name ('cuda', 'cpu', 'mps')
            max_length: maximum token length (model limit)
            batch_size: batch size, used to optimize memory usage
            enable_cache: whether to enable model caching
        """
        try:
            # Auto-detect the device if not specified
            if device is None:
                device = get_device()

            device_name_map = {
                'mps': 'MPS (macOS GPU)',
                'cuda': 'CUDA (NVIDIA GPU)',
                'cpu': 'CPU'
            }
            device_display = device_name_map.get(device, device)

            self.model = CrossEncoder(
                model_name,
                device=device,
                max_length=max_length
            )
            self.max_length = max_length
            self.batch_size = batch_size
            self.model_name = model_name
            logger.info(f"✅ Rerank model {model_name} loaded (device: {device_display})")
        except Exception as e:
            logger.error(f"❌ Failed to load model: {e}")
            raise

    def _truncate_text(self, text: str, max_chars: int = 2000) -> str:
        """
        Truncate overly long text (rough estimate, to stay under the token limit)

        Args:
            text: original text
            max_chars: maximum number of characters (conservative estimate, ~500 tokens)

        Returns:
            Truncated text
        """
        if len(text) <= max_chars:
            return text
        # Truncate and append an ellipsis
        return text[:max_chars - 3] + "..."

    def _prepare_pairs(
        self,
        query: str,
        documents: List[Dict]
    ) -> List[Tuple[str, str]]:
        """
        Prepare (query, document) pairs, handling text length

        Args:
            query: query text
            documents: list of documents

        Returns:
            List of (query, content) pairs
        """
        pairs = []
        truncated_indices = []  # Track which documents were truncated

        # Rough estimate: reserve room for the query within the model's length budget
        max_doc_chars = int((self.max_length * 0.7) - len(query))

        for i, doc in enumerate(documents):
            content = doc.get("content", "")

            # Truncate if the content is too long
            if len(content) > max_doc_chars:
                content = self._truncate_text(content, max_doc_chars)
                truncated_indices.append(i)

            pairs.append((query, content))

        if truncated_indices:
            logger.warning(
                f"⚠️ {len(truncated_indices)} documents were truncated for being too long "
                f"(max length: {max_doc_chars} characters)"
            )

        return pairs

    def rerank(
        self,
        query: str,
        documents: List[Dict],
        top_k: int = 5,
        preserve_original_scores: bool = True
    ) -> List[Dict]:
        """
        Perform precise reranking

        Args:
            query: query text
            documents: list of documents; each should contain "content" and optionally "hybrid_score"
            top_k: return the top k results
            preserve_original_scores: whether to keep the original scores (hybrid_score)

        Returns:
            Reranked document list, sorted by rerank_score in descending order
        """
        if not documents:
            logger.warning("⚠️ Document list is empty, returning empty result")
            return []

        if not query or not query.strip():
            logger.warning("⚠️ Query is empty, returning documents in original order")
            return documents[:top_k]

        start_time = time.time()
        logger.info(f"🔄 Reranking {len(documents)} documents...")

        try:
            # 1. Prepare pairs
            pairs = self._prepare_pairs(query, documents)

            # 2. Compute scores in batches (optimizes memory usage)
            scores = []
            for i in range(0, len(pairs), self.batch_size):
                batch_pairs = pairs[i:i + self.batch_size]
                batch_scores = self.model.predict(batch_pairs)
                scores.extend(batch_scores.tolist() if hasattr(batch_scores, 'tolist') else batch_scores)

            # 3. Update document scores
            for i, doc in enumerate(documents):
                doc = doc.copy()  # Avoid mutating the original document
                doc["rerank_score"] = float(scores[i])

                # Keep the original score for reference
                if preserve_original_scores:
                    if "hybrid_score" not in doc:
                        # If there is no hybrid_score, fall back to other scores
                        doc["original_score"] = doc.get("score", 0.0)

                documents[i] = doc

            # 4. Re-sort by rerank_score
            reranked_docs = sorted(
                documents,
                key=lambda x: x.get("rerank_score", float('-inf')),
                reverse=True
            )

            # 5. Statistics
            elapsed_time = time.time() - start_time
            avg_score = sum(scores) / len(scores) if scores else 0.0
            max_score = max(scores) if scores else 0.0
            min_score = min(scores) if scores else 0.0

            logger.info(
                f"✅ Reranking complete (elapsed: {elapsed_time:.2f}s, "
                f"avg score: {avg_score:.4f}, "
                f"range: [{min_score:.4f}, {max_score:.4f}])"
            )

            return reranked_docs[:top_k]

        except Exception as e:
            logger.error(f"❌ Error during reranking: {e}")
            # Fallback strategy: return the first top_k in original order
            logger.warning("⚠️ Using fallback strategy: returning original order")
            return documents[:top_k]

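A minimal usage sketch for the class above (illustrative, not part of the commit; the query, documents, and scores are hypothetical, and the model is downloaded on first use):

    from src.retrievers.reranker import Reranker

    reranker = Reranker(model_name="BAAI/bge-reranker-base", batch_size=16)
    docs = [
        {"content": "BM25 is a lexical ranking function.", "hybrid_score": 0.42},
        {"content": "Cross-encoders score (query, document) pairs jointly.", "hybrid_score": 0.40},
    ]
    top = reranker.rerank("how do cross-encoders rank documents?", docs, top_k=1)
    print(top[0]["rerank_score"], top[0]["content"])
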
class RAGPipeline:
    """Orchestration pipeline: manages the full RAG flow (recall + rerank)"""

    def __init__(
        self,
        hybrid_search,
        reranker,
        recall_k: int = 25,
        adaptive_recall: bool = True,
        min_recall_k: int = 10,
        max_recall_k: int = 50
    ):
        """
        Initialize the RAG pipeline

        Args:
            hybrid_search: HybridSearch instance
            reranker: Reranker instance
            recall_k: number of candidates to recall in the first stage (default value)
            adaptive_recall: whether to adjust recall_k dynamically per query
            min_recall_k: minimum recall count
            max_recall_k: maximum recall count
        """
        self.hybrid_search = hybrid_search
        self.reranker = reranker
        self.base_recall_k = recall_k
        self.adaptive_recall = adaptive_recall
        self.min_recall_k = min_recall_k
        self.max_recall_k = max_recall_k

        # Performance statistics
        self.stats = {
            "total_queries": 0,
            "avg_recall_time": 0.0,
            "avg_rerank_time": 0.0,
            "avg_total_time": 0.0
        }

    def _calculate_adaptive_recall_k(self, query: str) -> int:
        """
        Dynamically compute recall_k based on query complexity

        Args:
            query: query text

        Returns:
            Adjusted recall_k
        """
        if not self.adaptive_recall:
            return self.base_recall_k

        # Simple heuristic: adjust by query length and keyword count
        query_length = len(query.split())
        keyword_count = len(set(query.lower().split()))

        # Complex queries need more candidates
        if query_length > 10 or keyword_count > 5:
            recall_k = min(self.base_recall_k * 2, self.max_recall_k)
        elif query_length < 3:
            recall_k = max(self.base_recall_k // 2, self.min_recall_k)
        else:
            recall_k = self.base_recall_k

        return recall_k

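For reference, with the defaults above (recall_k=25, min_recall_k=10, max_recall_k=50) the heuristic resolves as follows; the sample queries are illustrative:

    # "transformers"                            -> 1 word (< 3)    -> max(25 // 2, 10) = 12
    # "what is retrieval augmented generation"  -> 5 words         -> base value 25
    # "compare BM25 and dense retrieval trade-offs
    #  for long multi-hop questions over papers"-> 12 words (> 10) -> min(25 * 2, 50) = 50
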
    def query(
        self,
        text: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        enable_rerank: bool = True,
        return_stats: bool = False
    ) -> List[Dict]:
        """
        Execute the full search flow

        Args:
            text: query text
            top_k: number of final results to return
            metadata_filter: optional metadata filter conditions
            enable_rerank: whether to enable reranking (optional, for performance testing)
            return_stats: whether to return performance statistics

        Returns:
            List of relevant documents; if return_stats=True, a (results, stats) tuple
        """
        if not text or not text.strip():
            logger.warning("⚠️ Query is empty")
            return []

        total_start = time.time()
        self.stats["total_queries"] += 1

        # Compute recall_k dynamically
        recall_k = self._calculate_adaptive_recall_k(text)
        logger.info(
            f"🔍 Searching: '{text[:50]}...' "
            f"(recall stage: {recall_k} candidates, final: {top_k} results)"
        )

        try:
            # Stage 1: hybrid search (recall stage)
            recall_start = time.time()
            initial_results = self.hybrid_search.retrieve(
                query=text,
                top_k=recall_k,
                metadata_filter=metadata_filter
            )
            recall_time = time.time() - recall_start

            if not initial_results:
                logger.warning("⚠️ Recall stage found no results")
                return []

            logger.info(
                f"✅ Recall stage complete: found {len(initial_results)} candidates "
                f"(elapsed: {recall_time:.2f}s)"
            )

            # Stage 2: reranking (fine-grained filtering)
            if enable_rerank and len(initial_results) > top_k:
                rerank_start = time.time()
                final_results = self.reranker.rerank(
                    query=text,
                    documents=initial_results,
                    top_k=top_k
                )
                rerank_time = time.time() - rerank_start

                logger.info(
                    f"✅ Rerank stage complete: selected {len(final_results)} results "
                    f"out of {len(initial_results)} candidates (elapsed: {rerank_time:.2f}s)"
                )
            else:
                # Skip reranking (for performance testing or when candidates are few)
                final_results = initial_results[:top_k]
                rerank_time = 0.0
                logger.info("⏭️ Skipping rerank stage (too few candidates or disabled)")

            # Update statistics
            total_time = time.time() - total_start
            self._update_stats(recall_time, rerank_time, total_time)

            # Optionally attach performance info to the result
            if return_stats:
                stats = {
                    "recall_time": recall_time,
                    "rerank_time": rerank_time,
                    "total_time": total_time,
                    "recall_k": recall_k,
                    "candidates_found": len(initial_results),
                    "final_results": len(final_results)
                }
                return final_results, stats

            return final_results

        except Exception as e:
            logger.error(f"❌ Error during query: {e}")
            # Fallback strategy: try using the recall stage only
            try:
                logger.warning("⚠️ Trying fallback strategy: recall results only")
                return self.hybrid_search.retrieve(text, top_k=top_k, metadata_filter=metadata_filter)
            except Exception as e2:
                logger.error(f"❌ Fallback strategy also failed: {e2}")
                return []

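A sketch of driving the two-stage flow end to end; it assumes a HybridSearch instance has already been built (see src/retrievers/hybrid_search.py), and the variable names are illustrative:

    pipeline = RAGPipeline(hybrid_search=hybrid_search, reranker=Reranker(), recall_k=25)
    results, stats = pipeline.query("What is RLHF?", top_k=5, return_stats=True)
    for r in results:
        print(f"{r['rerank_score']:.3f}  {r['content'][:60]}")
    print(f"recall: {stats['recall_time']:.2f}s, rerank: {stats['rerank_time']:.2f}s")
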
    def _update_stats(self, recall_time: float, rerank_time: float, total_time: float):
        """Update running-average performance statistics"""
        n = self.stats["total_queries"]
        self.stats["avg_recall_time"] = (
            (self.stats["avg_recall_time"] * (n - 1) + recall_time) / n
        )
        self.stats["avg_rerank_time"] = (
            (self.stats["avg_rerank_time"] * (n - 1) + rerank_time) / n
        )
        self.stats["avg_total_time"] = (
            (self.stats["avg_total_time"] * (n - 1) + total_time) / n
        )

    def get_stats(self) -> Dict:
        """Get performance statistics"""
        return self.stats.copy()

    def reset_stats(self):
        """Reset statistics"""
        self.stats = {
            "total_queries": 0,
            "avg_recall_time": 0.0,
            "avg_rerank_time": 0.0,
            "avg_total_time": 0.0
        }

    def format_results_for_llm(
        self,
        results: List[Dict],
        format_style: str = "detailed"
    ) -> str:
        """
        Format retrieval results for the LLM (requires importing PromptFormatter)

        Args:
            results: list of retrieval results
            format_style: format style ("detailed", "simple", "minimal")

        Returns:
            Formatted context string
        """
        try:
            from ..prompt_formatter import PromptFormatter
            formatter = PromptFormatter(format_style=format_style)
            return formatter.format_context(results)
        except ImportError:
            # If the import fails, fall back to a simple format
            formatted_parts = []
            for i, result in enumerate(results, 1):
                metadata = result.get("metadata", {})
                content = result.get("content", "")
                arxiv_id = metadata.get('arxiv_id', 'N/A')
                title = metadata.get('title', 'N/A')
                formatted_parts.append(
                    f"[Source {i}: {title} (arXiv:{arxiv_id})]\n{content}\n"
                )
            # Join parts with a divider line between sources
            return "\n" + ("=" * 60 + "\n").join(formatted_parts)

src/retrievers/vector_retriever.py
ADDED
@@ -0,0 +1,254 @@
"""
Vector retriever module: semantic retrieval using embeddings and a vector database

Two initialization modes are supported:
1. Auto-initialize embeddings (default): create a new embedding model from the parameters
2. Use external embeddings: accept an already-initialized embedding model (can be shared with DocumentProcessor)
"""
from typing import List, Dict, Optional, Any
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
import os
from .base import BaseRetriever

# Try to import HuggingFaceEmbeddings (free models)
try:
    from langchain_community.embeddings import HuggingFaceEmbeddings
except ImportError:
    try:
        from langchain_huggingface import HuggingFaceEmbeddings
    except ImportError:
        raise ImportError("langchain-community or langchain-huggingface is required to use Hugging Face embeddings")

# Import torch to detect the available device
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False


def get_device() -> str:
    """
    Auto-detect and return the best available device

    Returns:
        Device name: 'mps' (macOS GPU), 'cuda' (NVIDIA GPU), or 'cpu'
    """
    if not TORCH_AVAILABLE:
        return 'cpu'

    # Priority: MPS (macOS) > CUDA (NVIDIA) > CPU
    if torch.backends.mps.is_available():
        return 'mps'
    elif torch.cuda.is_available():
        return 'cuda'
    else:
        return 'cpu'

class VectorRetriever(BaseRetriever):
    """Semantic search via vector retrieval"""

    def __init__(
        self,
        documents: List[Dict],
        embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
        persist_directory: Optional[str] = "./chroma_db",
        hf_cache_dir: Optional[str] = None,
        device: Optional[str] = None,
        embeddings: Optional[Any] = None  # Optional: externally provided embedding model (takes precedence)
    ):
        """
        Initialize the vector retriever (using Hugging Face embeddings)

        Args:
            documents: list of documents, each containing "content" and "metadata"
            embedding_model: Hugging Face embedding model name (default: "sentence-transformers/all-MiniLM-L6-v2")
                             Used only when embeddings=None
            persist_directory: Chroma database persistence directory
            hf_cache_dir: Hugging Face model cache directory (e.g. a path on an external drive)
                          If None, the HF_HOME environment variable or the default ~/.cache/huggingface/ is used
                          Used only when embeddings=None
            device: device name ('mps', 'cuda', 'cpu'); if None, the best device is auto-detected
                    Used only when embeddings=None
            embeddings: optional external embedding model object
                        If provided, it takes precedence and the other parameters (embedding_model, hf_cache_dir, device) are ignored
                        This allows sharing one embedding model instance with DocumentProcessor
                        Benefits:
                        - Saves memory (the model is loaded only once)
                        - Saves time (avoids repeated initialization)
                        - Ensures consistency (chunking and retrieval use the same model)
        """
        # Prefer the shared model passed in
        if embeddings is not None:
            self.embeddings = embeddings
            print("✓ Using externally provided embeddings model (shared with DocumentProcessor)")
        else:
            # Otherwise run the original initialization logic
            print(f"Using Hugging Face embedding model: {embedding_model}")

            # Set the Hugging Face cache directory
            if hf_cache_dir:
                # If a cache directory is specified, set the environment variables
                os.environ['HF_HOME'] = hf_cache_dir
                os.environ['TRANSFORMERS_CACHE'] = hf_cache_dir
                print(f"Models will be stored in: {hf_cache_dir}")
            else:
                # Check whether the environment variable is already set
                default_cache = os.path.expanduser("~/.cache/huggingface")
                current_cache = os.getenv('HF_HOME', default_cache)
                print(f"Model cache location: {current_cache}")
                print("Tip: set the hf_cache_dir parameter or the HF_HOME environment variable to point at an external drive")

            # Auto-detect the device if not specified
            if device is None:
                device = get_device()

            device_name_map = {
                'mps': 'MPS (macOS GPU)',
                'cuda': 'CUDA (NVIDIA GPU)',
                'cpu': 'CPU'
            }
            print(f"Using device: {device_name_map.get(device, device)}")
            print("The model will be downloaded on first use, please wait...")

            # Build model_kwargs including cache directory and device
            model_kwargs = {'device': device}
            if hf_cache_dir:
                model_kwargs['cache_dir'] = hf_cache_dir

            self.embeddings = HuggingFaceEmbeddings(
                model_name=embedding_model,
                model_kwargs=model_kwargs,
                encode_kwargs={'normalize_embeddings': True}  # Normalize embeddings for better results
            )

        # Convert documents into LangChain Document format
        # Lists in metadata must be converted to strings, because ChromaDB does not accept list types
        def sanitize_metadata(metadata: Dict) -> Dict:
            """Convert lists in metadata to strings to satisfy ChromaDB's requirements"""
            sanitized = {}
            for key, value in metadata.items():
                if isinstance(value, list):
                    # Join lists into a comma-separated string
                    sanitized[key] = ", ".join(str(v) for v in value)
                elif isinstance(value, (dict, set)):
                    # Convert dicts or sets to strings
                    sanitized[key] = str(value)
                else:
                    # Other types (str, int, float, bool, None) are kept as-is
                    sanitized[key] = value
            return sanitized

        langchain_docs = [
            Document(
                page_content=doc["content"],
                metadata=sanitize_metadata(doc["metadata"])
            )
            for doc in documents
        ]

        # Create the vector database
        self.vectorstore = Chroma.from_documents(
            documents=langchain_docs,
            embedding=self.embeddings,
            persist_directory=persist_directory
        )

        # Create the retriever
        self.retriever = self.vectorstore.as_retriever()

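A sketch of the shared-embeddings path described in the docstring; `processor` stands in for whatever DocumentProcessor instance produced the chunks, and its `embeddings` attribute is an assumption here:

    retriever = VectorRetriever(
        documents=chunks,                 # e.g. output of DocumentProcessor
        embeddings=processor.embeddings,  # reuse the already-loaded model
        persist_directory="./chroma_db",
    )
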
    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None
    ) -> List[Dict]:
        """
        Retrieve relevant documents and return a normalized similarity score (higher is better).
        Supports filtering by metadata.

        Args:
            query: query text
            top_k: return the top k results
            metadata_filter: optional metadata filter dictionary.
                             E.g. {"arxiv_id": "1234.5678"} retrieves chunks of a specific paper only,
                             or {"title": "Machine Learning"} retrieves papers with a specific title only.
                             Multiple conditions are supported; all must hold simultaneously (AND logic).
                             Note: ChromaDB's where conditions support exact matching, not partial matching

        Returns:
            List of relevant documents, each containing "content", "metadata", and "score".
            Results are filtered according to metadata_filter
        """
        # Build the filter conditions.
        # If metadata_filter is provided, fetch more results first and filter in Python,
        # because support for the filter parameter in LangChain ChromaDB's
        # similarity_search_with_score method may vary by version
        if metadata_filter:
            # Fetch more results to ensure enough candidates for filtering
            results_with_scores = self.vectorstore.similarity_search_with_score(
                query,
                k=top_k * 10  # Fetch more results
            )

            # Filter in Python
            filtered_results = []
            for doc, distance_score in results_with_scores:
                metadata = doc.metadata
                matches = True

                for key, value in metadata_filter.items():
                    doc_value = metadata.get(key)

                    # Check for a match
                    if isinstance(value, dict):
                        # Operator format supported (e.g. {"$eq": "value"})
                        if "$eq" in value:
                            if doc_value != value["$eq"]:
                                matches = False
                                break
                        else:
                            # Other operators can be added here
                            matches = False
                            break
                    elif isinstance(value, str) and isinstance(doc_value, str):
                        # String matching: partial matching (containment) supported
                        if value.lower() not in doc_value.lower():
                            matches = False
                            break
                    else:
                        # Exact matching for other types
                        if doc_value != value:
                            matches = False
                            break

                if matches:
                    filtered_results.append((doc, distance_score))

            # Keep only the top_k results
            results_with_scores = filtered_results[:top_k]
        else:
            # No filter conditions; fetch results directly
            results_with_scores = self.vectorstore.similarity_search_with_score(
                query,
                k=top_k
            )

        # Build the results and convert the scores
        results = []
        for doc, distance_score in results_with_scores:
            # Since the embeddings are normalized, the squared L2 distance is 2 - 2 * cos_sim
            # -> cos_sim = 1 - (distance^2 / 2)
            # Scores lie in [-1, 1] (typically [0, 1] in practice); higher means more similar
            similarity_score = 1 - (distance_score**2 / 2)

            results.append({
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": float(similarity_score),
            })

        return results

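A quick numeric check of the distance-to-similarity conversion used above: for unit vectors, ||a - b||^2 = 2 - 2 * cos(a, b), so 1 - d^2 / 2 recovers the cosine similarity (assuming, as the code does, that the store returns an L2 distance):

    import numpy as np

    a = np.array([1.0, 0.0])           # unit-norm
    b = np.array([0.6, 0.8])           # unit-norm
    d = np.linalg.norm(a - b)          # L2 distance
    print(1 - d**2 / 2, float(a @ b))  # both print 0.6
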
src/step_back_rag.py
ADDED
@@ -0,0 +1,305 @@
"""
Step-back Prompting dual-track RAG: combines concrete facts with abstract principles
Uses Step-back Prompting to retrieve concrete facts and abstract principles at the same time, improving answer quality
"""
from typing import List, Dict, Optional
from .retrievers.reranker import RAGPipeline
from .retrievers.vector_retriever import VectorRetriever
from .prompt_formatter import PromptFormatter
from .llm_integration import OllamaLLM
import time
import logging
import hashlib
from concurrent.futures import ThreadPoolExecutor, as_completed

logger = logging.getLogger(__name__)


class StepBackRAG:
    """Dual-track RAG system using Step-back Prompting"""

    def __init__(
        self,
        rag_pipeline: RAGPipeline,
        vector_retriever: VectorRetriever,
        llm: OllamaLLM,
        step_back_temperature: float = 0.3,  # Use a lower temperature when generating the abstract question
        answer_temperature: float = 0.7,
        enable_parallel: bool = True
    ):
        """
        Initialize Step-back RAG

        Args:
            rag_pipeline: RAG pipeline instance (used for final answer generation)
            vector_retriever: vector retriever
            llm: LLM instance
            step_back_temperature: temperature for generating the abstract question (lower, more stable)
            answer_temperature: temperature for generating the answer
            enable_parallel: whether to run the dual-track retrieval in parallel
        """
        self.rag_pipeline = rag_pipeline
        self.vector_retriever = vector_retriever
        self.llm = llm
        self.step_back_temperature = step_back_temperature
        self.answer_temperature = answer_temperature
        self.enable_parallel = enable_parallel

    def _generate_step_back_question(self, question: str) -> str:
        """
        Generate the Step-back abstract question

        Args:
            question: original concrete question

        Returns:
            Abstract question
        """
        # Choose a prompt template matching the question's language
        # (the Chinese template below is intentional and kept verbatim)
        is_chinese = PromptFormatter.detect_language(question) == "zh"

        if is_chinese:
            prompt = f"""你是一個資深專家。請將以下具體問題轉換為一個更抽象、更基礎的原理性問題。
這個抽象問題應該幫助理解該領域的基礎概念和原理,而不是直接回答具體問題。

具體問題: {question}

請生成一個抽象問題,用於檢索相關的原理和背景知識:
"""
        else:
            prompt = f"""You are a senior expert. Please convert the following specific question into a more abstract, fundamental question about principles and concepts.
This abstract question should help understand the basic concepts and principles in this field, rather than directly answering the specific question.

Specific question: {question}

Please generate an abstract question for retrieving relevant principles and background knowledge:
"""

        try:
            abstract_question = self.llm.generate(
                prompt=prompt,
                temperature=self.step_back_temperature,
                max_tokens=200
            )

            abstract_question = abstract_question.strip()

            if not abstract_question:
                logger.warning("⚠️ Generated abstract question is empty, using the original question")
                return question

            logger.info(f"✅ Generated abstract question: '{abstract_question}'")
            return abstract_question

        except Exception as e:
            logger.error(f"⚠️ Error while generating the abstract question: {e}")
            return question

    def _get_doc_id(self, doc: Dict) -> str:
        """
        Generate a unique identifier for a document

        Args:
            doc: document dictionary

        Returns:
            Unique ID
        """
        metadata = doc.get("metadata", {})
        content = doc.get("content", "")

        if "arxiv_id" in metadata and "chunk_index" in metadata:
            return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
        elif "file_path" in metadata and "chunk_index" in metadata:
            return f"{metadata['file_path']}_{metadata['chunk_index']}"
        else:
            content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
            return f"doc_{content_hash}"

    def _retrieve_direct(self, question: str, top_k: int, metadata_filter: Optional[Dict] = None) -> List[Dict]:
        """Retrieve the original question directly (concrete facts)"""
        return self.vector_retriever.retrieve(
            query=question,
            top_k=top_k,
            metadata_filter=metadata_filter
        )

    def _retrieve_step_back(self, question: str, top_k: int, metadata_filter: Optional[Dict] = None) -> tuple:
        """Step-back retrieval (abstract principles)"""
        abstract_question = self._generate_step_back_question(question)
        results = self.vector_retriever.retrieve(
            query=abstract_question,
            top_k=top_k,
            metadata_filter=metadata_filter
        )
        return results, abstract_question

    def query(
        self,
        question: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        return_abstract_question: bool = False
    ) -> Dict:
        """
        Execute the dual-track retrieval (without generating an answer)

        Args:
            question: original question
            top_k: number of results to return per track
            metadata_filter: optional metadata filter conditions
            return_abstract_question: whether to return the abstract question

        Returns:
            Dictionary containing the dual-track retrieval results
        """
        start_time = time.time()

        if self.enable_parallel:
            # Run the two tracks in parallel
            logger.info(f"🔄 Running dual-track retrieval in parallel: '{question}'")
            with ThreadPoolExecutor(max_workers=2) as executor:
                direct_future = executor.submit(
                    self._retrieve_direct, question, top_k, metadata_filter
                )
                step_back_future = executor.submit(
                    self._retrieve_step_back, question, top_k, metadata_filter
                )

                specific_results = direct_future.result()
                abstract_results, abstract_question = step_back_future.result()
        else:
            # Run serially
            logger.info(f"🔄 Running dual-track retrieval serially: '{question}'")
            specific_results = self._retrieve_direct(question, top_k, metadata_filter)
            abstract_results, abstract_question = self._retrieve_step_back(question, top_k, metadata_filter)

        elapsed_time = time.time() - start_time
        logger.info(
            f"✅ Dual-track retrieval complete (elapsed: {elapsed_time:.2f}s)\n"
            f"   Concrete facts: {len(specific_results)} results\n"
            f"   Abstract principles: {len(abstract_results)} results"
        )

        return {
            "specific_context": specific_results,
            "abstract_context": abstract_results,
            "abstract_question": abstract_question if return_abstract_question else None,
            "question": question,
            "elapsed_time": elapsed_time
        }

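The dictionary returned above can be consumed directly; a sketch assuming the pipeline, retriever, and llm instances are already wired up as in the other modules:

    sb = StepBackRAG(rag_pipeline=pipeline, vector_retriever=retriever, llm=llm)
    out = sb.query("Why does LoRA reduce fine-tuning memory?", top_k=5,
                   return_abstract_question=True)
    print(out["abstract_question"])   # e.g. a general question about low-rank adaptation
    print(len(out["specific_context"]), len(out["abstract_context"]))
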
    def generate_answer(
        self,
        question: str,
        formatter: PromptFormatter,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        document_type: str = "general",
        return_abstract_question: bool = False
    ) -> Dict:
        """
        Full Step-back RAG flow: dual-track retrieval -> answer generation

        Args:
            question: original question
            formatter: prompt formatter
            top_k: number of documents per track used for answer generation
            metadata_filter: optional metadata filter conditions
            document_type: document type ("paper", "cv", "general")
            return_abstract_question: whether to return the abstract question

        Returns:
            Dictionary containing the retrieval results, the generated answer, and statistics
        """
        start_time = time.time()

        # Step 1: dual-track retrieval
        retrieval_result = self.query(
            question=question,
            top_k=top_k,
            metadata_filter=metadata_filter,
            return_abstract_question=return_abstract_question
        )

        specific_results = retrieval_result["specific_context"]
        abstract_results = retrieval_result["abstract_context"]

        if not specific_results and not abstract_results:
            return {
                **retrieval_result,
                "answer": "抱歉,未找到相關文檔來回答此問題。",
                "formatted_context": None,
                "answer_time": 0.0,
                "total_time": retrieval_result["elapsed_time"]
            }

        # Step 2: format the two contexts
        specific_context = formatter.format_context(
            specific_results,
            document_type=document_type
        ) if specific_results else "未找到相關的具體事實資料。"

        abstract_context = formatter.format_context(
            abstract_results,
            document_type=document_type
        ) if abstract_results else "未找到相關的基礎原理資料。"

        # Step 3: build the fused prompt (the key step);
        # templates are language-matched, and the Chinese one is kept verbatim
        is_chinese = PromptFormatter.detect_language(question) == "zh"

        if is_chinese:
            final_prompt = f"""你是一個資深專家。請結合以下兩類資訊來回答使用者的具體問題。

【基礎原理與背景】
{abstract_context}

【具體事實資料】
{specific_context}

使用者問題:{question}

請根據原理推導並結合事實,給出一個專業且具備邏輯的回答:
"""
        else:
            final_prompt = f"""You are a senior expert. Please answer the user's specific question by combining the following two types of information.

【Fundamental Principles and Background】
{abstract_context}

【Specific Facts and Data】
{specific_context}

User question: {question}

Please provide a professional and logical answer based on principles and facts:
"""

        # Step 4: generate the answer
        logger.info("🤖 Generating answer...")
        answer_start = time.time()
        try:
            answer = self.llm.generate(
                prompt=final_prompt,
                temperature=self.answer_temperature,
                max_tokens=2048
            )
            answer_time = time.time() - answer_start
            logger.info(f"✅ Answer generation complete (elapsed: {answer_time:.2f}s)")
        except Exception as e:
            logger.error(f"❌ Error while generating the answer: {e}")
            answer = f"生成回答時出錯: {e}"
            answer_time = time.time() - answer_start

        total_time = time.time() - start_time

        return {
            **retrieval_result,
            "answer": answer,
            "formatted_context": {
                "specific": specific_context,
                "abstract": abstract_context
            },
            "answer_time": answer_time,
            "total_time": total_time
        }

src/subquery_rag.py
ADDED
@@ -0,0 +1,361 @@
"""
Sub-query Decomposition RAG: decompose a complex question into sub-questions before retrieval
"""
from typing import List, Dict, Optional
from .retrievers.reranker import RAGPipeline
from .prompt_formatter import PromptFormatter
from .llm_integration import OllamaLLM
import hashlib
import time
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

logger = logging.getLogger(__name__)


class SubQueryDecompositionRAG:
    """RAG system using sub-question decomposition"""

    def __init__(
        self,
        rag_pipeline: RAGPipeline,
        llm: OllamaLLM,
        max_sub_queries: int = 3,
        top_k_per_subquery: int = 5,
        enable_parallel: bool = True
    ):
        """
        Initialize Sub-query Decomposition RAG

        Args:
            rag_pipeline: existing RAG pipeline instance
            llm: LLM instance (used to generate sub-questions)
            max_sub_queries: maximum number of sub-questions to generate
            top_k_per_subquery: number of results to retrieve per sub-question
            enable_parallel: whether to process sub-queries in parallel
        """
        self.rag_pipeline = rag_pipeline
        self.llm = llm
        self.max_sub_queries = max_sub_queries
        self.top_k_per_subquery = top_k_per_subquery
        self.enable_parallel = enable_parallel

    def _generate_sub_queries(self, question: str) -> List[str]:
        """
        Decompose the original question into sub-questions

        Args:
            question: original question

        Returns:
            List of sub-questions
        """
        # Detect the language; prompt templates are language-matched
        # (the Chinese template below is intentional and kept verbatim)
        is_chinese = PromptFormatter.detect_language(question) == "zh"

        if is_chinese:
            prompt = f"""你是一個專業助理。請將以下原始問題拆解成最多 {self.max_sub_queries} 個具體的子問題,以便進行資料搜尋。
每個子問題應專注於原始問題的一個特定面向。請以換行符號分隔問題。

原始問題: {question}

子問題清單:"""
        else:
            prompt = f"""You are a professional assistant. Please decompose the following original question into at most {self.max_sub_queries} specific sub-questions for information retrieval.
Each sub-question should focus on a specific aspect of the original question. Please separate questions with newlines.

Original question: {question}

Sub-question list:"""

        try:
            response = self.llm.generate(
                prompt=prompt,
                temperature=0.3,  # Lower temperature for more stable results
                max_tokens=500
            )

            # Parse the sub-questions
            sub_queries = [
                q.strip()
                for q in response.strip().split("\n")
                if q.strip() and not q.strip().startswith("#")
            ]

            # Strip numbering prefixes (e.g. "1. ", "1) ")
            cleaned_queries = []
            for q in sub_queries:
                # Strip leading numbering characters
                q = q.lstrip("0123456789. )")
                q = q.strip()
                if q:
                    cleaned_queries.append(q)

            # Cap the number of sub-questions
            cleaned_queries = cleaned_queries[:self.max_sub_queries]

            # Fall back to the original question if no sub-questions were produced
            if not cleaned_queries:
                logger.warning("⚠️ No sub-questions generated, using the original question")
                cleaned_queries = [question]

            return cleaned_queries

        except Exception as e:
            logger.error(f"⚠️ Error while generating sub-questions: {e}")
            # Fall back to the original question
            return [question]

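For illustration, given a hypothetical LLM response, the parsing above yields:

    # response = "1. What is X?\n2) How does X scale?\n# a note\n3. When does X fail?"
    # after splitting on newlines, dropping '#' lines, and lstrip("0123456789. )"):
    #   ["What is X?", "How does X scale?", "When does X fail?"]
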
    def _get_doc_id(self, doc: Dict) -> str:
        """
        Generate a unique identifier for a document

        Args:
            doc: document dictionary

        Returns:
            Unique ID
        """
        metadata = doc.get("metadata", {})
        content = doc.get("content", "")

        # Use a unique identifier from the metadata when available
        if "arxiv_id" in metadata and "chunk_index" in metadata:
            return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
        elif "file_path" in metadata and "chunk_index" in metadata:
            return f"{metadata['file_path']}_{metadata['chunk_index']}"
        else:
            # Fall back to a hash of the content
            content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
            return f"doc_{content_hash}"

    def _retrieve_for_subquery(
        self,
        sub_query: str,
        metadata_filter: Optional[Dict] = None
    ) -> List[Dict]:
        """
        Retrieve results for a single sub-question

        Args:
            sub_query: sub-question
            metadata_filter: optional metadata filter conditions

        Returns:
            List of retrieval results
        """
        try:
            results = self.rag_pipeline.query(
                text=sub_query,
                top_k=self.top_k_per_subquery,
                metadata_filter=metadata_filter,
                enable_rerank=True
            )
            return results
        except Exception as e:
            logger.error(f"⚠️ Error while retrieving for sub-question '{sub_query}': {e}")
            return []

    def _get_unique_documents(
        self,
        sub_queries: List[str],
        metadata_filter: Optional[Dict] = None
    ) -> List[Dict]:
        """
        Retrieve for all sub-questions and remove duplicate documents

        Args:
            sub_queries: list of sub-questions
            metadata_filter: optional metadata filter conditions

        Returns:
            Deduplicated list of documents
        """
        unique_docs = {}

        if self.enable_parallel and len(sub_queries) > 1:
            # Process sub-queries in parallel
            logger.info(f"🔄 Processing {len(sub_queries)} sub-queries in parallel...")
            with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor:
                future_to_query = {
                    executor.submit(self._retrieve_for_subquery, q, metadata_filter): q
                    for q in sub_queries
                }

                for future in as_completed(future_to_query):
                    sub_query = future_to_query[future]
                    try:
                        docs = future.result()
                        logger.debug(f"✅ Sub-question '{sub_query}' found {len(docs)} results")
                        for doc in docs:
                            doc_id = self._get_doc_id(doc)
                            if doc_id not in unique_docs:
                                unique_docs[doc_id] = doc
                            else:
                                # If it already exists, keep the higher-scoring copy
                                existing_score = unique_docs[doc_id].get(
                                    'rerank_score',
                                    unique_docs[doc_id].get('hybrid_score', 0)
                                )
                                new_score = doc.get(
                                    'rerank_score',
                                    doc.get('hybrid_score', 0)
                                )
                                if new_score > existing_score:
                                    unique_docs[doc_id] = doc
                    except Exception as e:
                        logger.error(f"⚠️ Error while processing sub-question '{sub_query}': {e}")
        else:
            # Process serially
            logger.info(f"🔄 Processing {len(sub_queries)} sub-queries serially...")
            for sub_query in sub_queries:
                docs = self._retrieve_for_subquery(sub_query, metadata_filter)
                logger.debug(f"✅ Sub-question '{sub_query}' found {len(docs)} results")
                for doc in docs:
                    doc_id = self._get_doc_id(doc)
                    if doc_id not in unique_docs:
                        unique_docs[doc_id] = doc
                    else:
                        # Keep the higher-scoring copy
                        existing_score = unique_docs[doc_id].get(
                            'rerank_score',
                            unique_docs[doc_id].get('hybrid_score', 0)
                        )
                        new_score = doc.get(
                            'rerank_score',
                            doc.get('hybrid_score', 0)
                        )
                        if new_score > existing_score:
                            unique_docs[doc_id] = doc

        # Sort by score
        result_list = list(unique_docs.values())
        result_list.sort(
            key=lambda x: x.get('rerank_score', x.get('hybrid_score', 0)),
            reverse=True
        )

        return result_list

    def query(
        self,
        question: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        return_sub_queries: bool = False
    ) -> Dict:
        """
        Execute a Sub-query Decomposition RAG query

        Args:
            question: original question
            top_k: return the top k results
            metadata_filter: optional metadata filter conditions
            return_sub_queries: whether to include the sub-question list in the result

        Returns:
            Dictionary containing the retrieval results and statistics
        """
        start_time = time.time()

        # Step 1: generate sub-questions
        logger.info(f"🔍 Decomposing question: '{question}'")
        sub_queries = self._generate_sub_queries(question)
        logger.info(f"✅ Generated {len(sub_queries)} sub-questions:")
        for i, sq in enumerate(sub_queries, 1):
            logger.info(f"  {i}. {sq}")

        # Step 2: retrieve and deduplicate
        logger.info("📚 Retrieving relevant documents...")
        docs = self._get_unique_documents(sub_queries, metadata_filter)
        logger.info(f"✅ Found {len(docs)} unique documents (after deduplication)")

        # Step 3: return the top_k results
        final_results = docs[:top_k]

        elapsed_time = time.time() - start_time

        result = {
            "results": final_results,
            "total_docs_found": len(docs),
            "sub_queries": sub_queries if return_sub_queries else None,
            "elapsed_time": elapsed_time
        }

        return result

    def generate_answer(
        self,
        question: str,
        formatter: PromptFormatter,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        document_type: str = "general",
        return_sub_queries: bool = False
    ) -> Dict:
        """
        Full Sub-query Decomposition RAG flow: retrieval + answer generation

        Args:
            question: original question
            formatter: prompt formatter
            top_k: number of top results used for answer generation
            metadata_filter: optional metadata filter conditions
            document_type: document type ("paper", "cv", "general")
            return_sub_queries: whether to include the sub-question list in the result

        Returns:
            Dictionary containing the retrieval results, the generated answer, and statistics
        """
        # Retrieve
        retrieval_result = self.query(
            question=question,
            top_k=top_k,
            metadata_filter=metadata_filter,
            return_sub_queries=return_sub_queries
        )

        if not retrieval_result["results"]:
            return {
                **retrieval_result,
                "answer": "抱歉,未找到相關文檔來回答此問題。",
                "formatted_context": None
            }

        # Format the context
        formatted_context = formatter.format_context(
            retrieval_result["results"],
            document_type=document_type
        )

        # Build the prompt
        prompt = formatter.create_prompt(
            question,
            formatted_context,
            document_type=document_type
        )

        # Generate the answer
        logger.info("🤖 Generating answer...")
        answer_start = time.time()
        try:
            answer = self.llm.generate(
                prompt=prompt,
                temperature=0.7,
                max_tokens=2048
            )
            answer_time = time.time() - answer_start
            logger.info(f"✅ Answer generation complete (elapsed: {answer_time:.2f}s)")
        except Exception as e:
            logger.error(f"❌ Error while generating the answer: {e}")
            answer = f"生成回答時出錯: {e}"
            answer_time = time.time() - answer_start

        return {
            **retrieval_result,
            "answer": answer,
            "formatted_context": formatted_context,
            "answer_time": answer_time,
            "total_time": retrieval_result["elapsed_time"] + answer_time
        }

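An end-to-end sketch for the class above, assuming pipeline and llm are constructed as in the other modules and that PromptFormatter's defaults suffice:

    sq = SubQueryDecompositionRAG(rag_pipeline=pipeline, llm=llm, max_sub_queries=3)
    out = sq.generate_answer(
        "Compare dense and sparse retrieval for scientific papers",
        formatter=PromptFormatter(),
        top_k=5,
        return_sub_queries=True,
    )
    print(out["sub_queries"])
    print(out["answer"][:200])
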
src/triple_hybrid_rag.py
ADDED
@@ -0,0 +1,467 @@
| 1 |
+
"""
|
| 2 |
+
Triple Hybrid RAG:融合 SubQuery + HyDE + Step-back Prompting
|
| 3 |
+
結合三種技術的優勢,實現最強大的 RAG 系統
|
| 4 |
+
"""
|
| 5 |
+
from typing import List, Dict, Optional
|
| 6 |
+
from .retrievers.reranker import RAGPipeline
|
| 7 |
+
from .retrievers.vector_retriever import VectorRetriever
|
| 8 |
+
from .prompt_formatter import PromptFormatter
|
| 9 |
+
from .llm_integration import OllamaLLM
|
| 10 |
+
import hashlib
|
| 11 |
+
import time
|
| 12 |
+
import logging
|
| 13 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TripleHybridRAG:
|
| 19 |
+
"""融合 SubQuery + HyDE + Step-back 的三重混合 RAG 系統"""
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
rag_pipeline: RAGPipeline,
|
| 24 |
+
vector_retriever: VectorRetriever,
|
| 25 |
+
llm: OllamaLLM,
|
| 26 |
+
max_sub_queries: int = 3,
|
| 27 |
+
top_k_per_subquery: int = 5,
|
| 28 |
+
hypothetical_length: int = 200,
|
| 29 |
+
temperature_subquery: float = 0.3,
|
| 30 |
+
temperature_hyde: float = 0.7,
|
| 31 |
+
temperature_stepback: float = 0.3,
|
| 32 |
+
answer_temperature: float = 0.7,
|
| 33 |
+
enable_parallel: bool = True
|
| 34 |
+
):
|
| 35 |
+
"""
|
| 36 |
+
初始化三重混合 RAG
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
rag_pipeline: RAG 管線實例
|
| 40 |
+
vector_retriever: 向量檢索器
|
| 41 |
+
llm: LLM 實例
|
| 42 |
+
max_sub_queries: 最多生成的子問題數量
|
| 43 |
+
top_k_per_subquery: 每個子問題檢索的結果數量
|
| 44 |
+
hypothetical_length: 假設性文檔目標長度(字符數)
|
| 45 |
+
temperature_subquery: 生成子問題的溫度(較低,更穩定)
|
| 46 |
+
temperature_hyde: 生成假設性文檔的溫度(較高,更多專業術語)
|
| 47 |
+
temperature_stepback: 生成抽象問題的溫度(較低,更穩定)
|
| 48 |
+
answer_temperature: 生成答案的溫度
|
| 49 |
+
enable_parallel: 是否並行處理
|
| 50 |
+
"""
|
| 51 |
+
self.rag_pipeline = rag_pipeline
|
| 52 |
+
self.vector_retriever = vector_retriever
|
| 53 |
+
self.llm = llm
|
| 54 |
+
self.max_sub_queries = max_sub_queries
|
| 55 |
+
self.top_k_per_subquery = top_k_per_subquery
|
| 56 |
+
self.hypothetical_length = hypothetical_length
|
| 57 |
+
self.temperature_subquery = temperature_subquery
|
| 58 |
+
self.temperature_hyde = temperature_hyde
|
| 59 |
+
self.temperature_stepback = temperature_stepback
|
| 60 |
+
self.answer_temperature = answer_temperature
|
| 61 |
+
self.enable_parallel = enable_parallel
|
| 62 |
+
|
| 63 |
+
def _generate_sub_queries(self, question: str) -> List[str]:
|
| 64 |
+
"""生成子問題(SubQuery)"""
|
| 65 |
+
is_chinese = PromptFormatter.detect_language(question) == "zh"
|
| 66 |
+
|
| 67 |
+
if is_chinese:
|
| 68 |
+
prompt = f"""你是一個專業助理。請將以下原始問題拆解成最多 {self.max_sub_queries} 個具體的子問題,以便進行資料搜尋。
|
| 69 |
+
每個子問題應專注於原始問題的一個特定面向。請以換行符號分隔問題。
|
| 70 |
+
|
| 71 |
+
原始問題: {question}
|
| 72 |
+
|
| 73 |
+
子問題清單:"""
|
| 74 |
+
else:
|
| 75 |
+
prompt = f"""You are a professional assistant. Please decompose the following original question into at most {self.max_sub_queries} specific sub-questions for information retrieval.
|
| 76 |
+
Each sub-question should focus on a specific aspect of the original question. Please separate questions with newlines.
|
| 77 |
+
|
| 78 |
+
Original question: {question}
|
| 79 |
+
|
| 80 |
+
Sub-question list:"""
|
| 81 |
+
|
| 82 |
+
try:
|
| 83 |
+
response = self.llm.generate(
|
| 84 |
+
prompt=prompt,
|
| 85 |
+
temperature=self.temperature_subquery,
|
| 86 |
+
max_tokens=500
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
sub_queries = [
|
| 90 |
+
q.strip()
|
| 91 |
+
for q in response.strip().split("\n")
|
| 92 |
+
if q.strip() and not q.strip().startswith("#")
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
# 移除編號前綴
|
| 96 |
+
cleaned_queries = []
|
| 97 |
+
for q in sub_queries:
|
| 98 |
+
q = q.lstrip("0123456789. )")
|
| 99 |
+
q = q.strip()
|
| 100 |
+
if q:
|
| 101 |
+
cleaned_queries.append(q)
|
| 102 |
+
|
| 103 |
+
cleaned_queries = cleaned_queries[:self.max_sub_queries]
|
| 104 |
+
|
| 105 |
+
if not cleaned_queries:
|
| 106 |
+
logger.warning("⚠️ 未生成子問題,使用原始問題")
|
| 107 |
+
cleaned_queries = [question]
|
| 108 |
+
|
| 109 |
+
return cleaned_queries
|
| 110 |
+
|
| 111 |
+
except Exception as e:
|
| 112 |
+
logger.error(f"⚠️ 生成子問題時出錯: {e}")
|
| 113 |
+
return [question]
|
| 114 |
+
|
    def _generate_hypothetical_document(self, sub_query: str) -> str:
        """Generate a hypothetical document for a sub-question (HyDE)."""
        is_chinese = PromptFormatter.detect_language(sub_query) == "zh"

        if is_chinese:
            prompt = f"""請針對以下問題,寫出一段約 {self.hypothetical_length} 字的專業技術檔案內容。
這段內容應包含該領域常見的專業術語與原理說明,以便用於後續的語義檢索。
請使用專業的術語和概念,即使你對某些細節不確定,也要包含相關的專業詞彙。

問題: {sub_query}

專業技術內容:"""
        else:
            prompt = f"""Please write a professional technical document of approximately {self.hypothetical_length} words in response to the following question.
This content should include common professional terminology and principle explanations in this field, to be used for subsequent semantic retrieval.
Please use professional terms and concepts, and include relevant professional vocabulary even if you are uncertain about some details.

Question: {sub_query}

Professional technical content:"""

        try:
            hypothetical_doc = self.llm.generate(
                prompt=prompt,
                temperature=self.temperature_hyde,
                max_tokens=500
            )

            hypothetical_doc = hypothetical_doc.strip()

            if not hypothetical_doc:
                logger.warning(f"⚠️ Hypothetical document for sub-question '{sub_query}' is empty; using the sub-question itself")
                return sub_query

            return hypothetical_doc

        except Exception as e:
            logger.error(f"⚠️ Error while generating hypothetical document: {e}")
            return sub_query

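    # Why HyDE works here: a short question and a long document chunk embed
    # into different regions of the vector space, so retrieving with a
    # generated "answer-like" passage turns the search into document-to-document
    # similarity, which tends to match chunk embeddings better. Factual errors
    # in the hypothetical text matter little, because only its embedding is
    # used -- the text itself never reaches the final answer prompt.
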
    def _generate_step_back_question(self, question: str) -> str:
        """Generate a Step-back abstract question."""
        is_chinese = PromptFormatter.detect_language(question) == "zh"

        if is_chinese:
            prompt = f"""你是一個資深專家。請將以下具體問題轉換為一個更抽象、更基礎的原理性問題。
這個抽象問題應該幫助理解該領域的基礎概念和原理,而不是直接回答具體問題。

具體問題: {question}

請生成一個抽象問題,用於檢索相關的原理和背景知識:
"""
        else:
            prompt = f"""You are a senior expert. Please convert the following specific question into a more abstract, fundamental question about principles and concepts.
This abstract question should help understand the basic concepts and principles in this field, rather than directly answering the specific question.

Specific question: {question}

Please generate an abstract question for retrieving relevant principles and background knowledge:
"""

        try:
            abstract_question = self.llm.generate(
                prompt=prompt,
                temperature=self.temperature_stepback,
                max_tokens=200
            )

            abstract_question = abstract_question.strip()

            if not abstract_question:
                logger.warning("⚠️ Generated abstract question is empty; using the original question")
                return question

            return abstract_question

        except Exception as e:
            logger.error(f"⚠️ Error while generating abstract question: {e}")
            return question

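    # Illustrative step-back transformation (hypothetical):
    #     specific: "Why does my cosine score drop after re-chunking to 256 tokens?"
    #     abstract: "How does chunk size affect embedding quality and similarity
    #                scores in vector retrieval?"
    # The abstract query pulls in principle-level chunks that the specific
    # query alone would miss.
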
    def _get_doc_id(self, doc: Dict) -> str:
        """Generate a unique identifier for a document chunk."""
        metadata = doc.get("metadata", {})
        content = doc.get("content", "")

        if "arxiv_id" in metadata and "chunk_index" in metadata:
            return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
        elif "file_path" in metadata and "chunk_index" in metadata:
            return f"{metadata['file_path']}_{metadata['chunk_index']}"
        else:
            # Content hash fallback (for deduplication only, not security)
            content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
            return f"doc_{content_hash}"

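    # Example identifiers this produces (assuming typical chunk metadata):
    #     {"metadata": {"arxiv_id": "2401.01234", "chunk_index": 3}}  -> "2401.01234_3"
    #     {"metadata": {"file_path": "docs/a.pdf", "chunk_index": 0}} -> "docs/a.pdf_0"
    #     no usable metadata                                          -> "doc_<16-char md5 prefix>"
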
    def _process_subquery_with_hyde(
        self,
        sub_query: str,
        metadata_filter: Optional[Dict] = None
    ) -> tuple:
        """Process a single sub-question: generate a hypothetical document and retrieve with it."""
        try:
            hypothetical_doc = self._generate_hypothetical_document(sub_query)
            results = self.vector_retriever.retrieve(
                query=hypothetical_doc,
                top_k=self.top_k_per_subquery,
                metadata_filter=metadata_filter
            )
            return results, hypothetical_doc
        except Exception as e:
            logger.error(f"⚠️ Error while processing sub-question '{sub_query}': {e}")
            return [], ""

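    # The (results, hypothetical_doc) tuple contract lets query() both merge
    # the retrieved chunks and, when return_hypothetical=True, surface the
    # generated text for inspection; the ([], "") fallback keeps one failed
    # sub-question from aborting the whole parallel batch.
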
    def query(
        self,
        question: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        return_sub_queries: bool = False,
        return_hypothetical: bool = False,
        return_abstract_question: bool = False
    ) -> Dict:
        """
        Run the triple hybrid RAG retrieval.

        Flow:
        1. Decompose the question into sub-questions (SubQuery)
        2. Generate a hypothetical document per sub-question and retrieve with it (HyDE)
        3. Retrieve directly with the original question (specific facts)
        4. Generate an abstract question and retrieve with it (Step-back, abstract principles)
        5. Merge all results and deduplicate
        """
        start_time = time.time()

        # Step 1: generate sub-questions
        logger.info(f"🔍 [SubQuery] Decomposing question: '{question}'")
        sub_queries = self._generate_sub_queries(question)
        logger.info(f"✅ Generated {len(sub_queries)} sub-questions")

        # Step 2: generate a hypothetical document per sub-question and retrieve (HyDE)
        logger.info("📚 [HyDE] Generating hypothetical documents and retrieving for each sub-question...")
        subquery_results = []
        hypothetical_docs = {}

        if self.enable_parallel and len(sub_queries) > 1:
            with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor:
                future_to_query = {
                    executor.submit(self._process_subquery_with_hyde, sq, metadata_filter): sq
                    for sq in sub_queries
                }

                # Results are aggregated on the main thread, so no locking is needed
                for future in as_completed(future_to_query):
                    sub_query = future_to_query[future]
                    try:
                        results, hypo_doc = future.result()
                        hypothetical_docs[sub_query] = hypo_doc
                        subquery_results.extend(results)
                    except Exception as e:
                        logger.error(f"⚠️ Error while processing sub-question '{sub_query}': {e}")
        else:
            for sub_query in sub_queries:
                results, hypo_doc = self._process_subquery_with_hyde(sub_query, metadata_filter)
                hypothetical_docs[sub_query] = hypo_doc
                subquery_results.extend(results)

        # Step 3: Step-back dual-track retrieval
        logger.info("🔍 [Step-back] Running dual-track retrieval...")

        if self.enable_parallel:
            with ThreadPoolExecutor(max_workers=2) as executor:
                direct_future = executor.submit(
                    self.vector_retriever.retrieve,
                    question, top_k, metadata_filter
                )
                abstract_question = self._generate_step_back_question(question)
                step_back_future = executor.submit(
                    self.vector_retriever.retrieve,
                    abstract_question, top_k, metadata_filter
                )

                specific_results = direct_future.result()
                abstract_results = step_back_future.result()
        else:
            specific_results = self.vector_retriever.retrieve(
                query=question,
                top_k=top_k,
                metadata_filter=metadata_filter
            )
            abstract_question = self._generate_step_back_question(question)
            abstract_results = self.vector_retriever.retrieve(
                query=abstract_question,
                top_k=top_k,
                metadata_filter=metadata_filter
            )

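        # Note on the fusion below: duplicates across the three tracks are
        # collapsed by keeping the highest raw score, which assumes all tracks
        # score on a comparable scale (true here, since every track goes
        # through the same vector_retriever). With heterogeneous retrievers, a
        # rank-based fusion such as Reciprocal Rank Fusion would be safer --
        # e.g. score(d) = sum over tracks of 1 / (60 + rank_track(d)). An
        # illustrative sketch, not used by this method:
        #
        #     from collections import defaultdict
        #     rrf = defaultdict(float)
        #     for track in (subquery_results, specific_results, abstract_results):
        #         for rank, doc in enumerate(track, start=1):
        #             rrf[self._get_doc_id(doc)] += 1.0 / (60 + rank)
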
        # Step 4: merge all results and deduplicate
        logger.info("🔄 Merging and deduplicating all retrieval results...")
        all_results = subquery_results + specific_results + abstract_results
        unique_docs = {}

        for doc in all_results:
            doc_id = self._get_doc_id(doc)
            if doc_id not in unique_docs:
                unique_docs[doc_id] = doc
            else:
                # Keep the copy with the higher score
                existing_score = unique_docs[doc_id].get('score', 0)
                new_score = doc.get('score', 0)
                if new_score > existing_score:
                    unique_docs[doc_id] = doc

        # Sort by score and return the top_k
        result_list = list(unique_docs.values())
        result_list.sort(key=lambda x: x.get('score', 0), reverse=True)
        final_results = result_list[:top_k]

        elapsed_time = time.time() - start_time
        logger.info(
            f"✅ Triple hybrid retrieval finished (elapsed: {elapsed_time:.2f}s)\n"
            f"   Sub-question retrieval: {len(subquery_results)} results\n"
            f"   Specific facts: {len(specific_results)} results\n"
            f"   Abstract principles: {len(abstract_results)} results\n"
            f"   {len(result_list)} unique after deduplication, returning top {len(final_results)}"
        )

        return {
            "results": final_results,
            "total_docs_found": len(result_list),
            "sub_queries": sub_queries if return_sub_queries else None,
            "hypothetical_documents": hypothetical_docs if return_hypothetical else None,
            "abstract_question": abstract_question if return_abstract_question else None,
            "subquery_results": subquery_results,
            "specific_context": specific_results,
            "abstract_context": abstract_results,
            "question": question,
            "elapsed_time": elapsed_time
        }

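    # Retrieval-only usage sketch (illustrative; `rag` as constructed above):
    #
    #     out = rag.query("為什麼 Transformer 需要位置編碼?", top_k=5,
    #                     return_sub_queries=True)
    #     for doc in out["results"]:
    #         print(doc["score"], doc["metadata"])
    #
    # The returned dict always carries the per-track raw results
    # ("subquery_results", "specific_context", "abstract_context"); the
    # return_* flags only control whether the intermediate LLM artifacts
    # (sub-questions, hypothetical documents, abstract question) are included.
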
    def generate_answer(
        self,
        question: str,
        formatter: PromptFormatter,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        document_type: str = "general",
        return_sub_queries: bool = False,
        return_hypothetical: bool = False,
        return_abstract_question: bool = False
    ) -> Dict:
        """
        Complete triple hybrid RAG flow: retrieval + answer generation.
        """
        start_time = time.time()

        # Retrieval
        retrieval_result = self.query(
            question=question,
            top_k=top_k,
            metadata_filter=metadata_filter,
            return_sub_queries=return_sub_queries,
            return_hypothetical=return_hypothetical,
            return_abstract_question=return_abstract_question
        )

        if not retrieval_result["results"]:
            return {
                **retrieval_result,
                "answer": "抱歉,未找到相關文檔來回答此問題。",
                "formatted_context": None,
                "answer_time": 0.0,
                "total_time": retrieval_result["elapsed_time"]
            }

        # Format the three context tracks
        subquery_context = formatter.format_context(
            retrieval_result["subquery_results"][:top_k],
            document_type=document_type
        ) if retrieval_result.get("subquery_results") else "未找到相關的子問題檢索結果。"

        specific_context = formatter.format_context(
            retrieval_result["specific_context"],
            document_type=document_type
        ) if retrieval_result.get("specific_context") else "未找到相關的具體事實資料。"

        abstract_context = formatter.format_context(
            retrieval_result["abstract_context"],
            document_type=document_type
        ) if retrieval_result.get("abstract_context") else "未找到相關的基礎原理資料。"

        # Build the fusion prompt (the key step)
        is_chinese = PromptFormatter.detect_language(question) == "zh"

        if is_chinese:
            final_prompt = f"""你是一個資深專家。請結合以下三類資訊來回答使用者的具體問題。

【基礎原理與背景】(來自 Step-back 抽象問題檢索)
{abstract_context}

【具體事實資料】(來自直接問題檢索)
{specific_context}

【子問題相關資料】(來自 SubQuery + HyDE 檢索)
{subquery_context}

使用者問題:{question}

請根據原理推導、結合具體事實,並參考子問題的相關資料,給出一個專業、全面且具備邏輯的回答:
"""
        else:
            final_prompt = f"""You are a senior expert. Please answer the user's specific question by combining the following three types of information.

【Fundamental Principles and Background】(from Step-back abstract question retrieval)
{abstract_context}

【Specific Facts and Data】(from direct question retrieval)
{specific_context}

【Sub-question Related Information】(from SubQuery + HyDE retrieval)
{subquery_context}

User question: {question}

Please provide a professional, comprehensive, and logical answer based on principles, facts, and sub-question related information:
"""

        # Generate the answer
        logger.info("🤖 Generating answer...")
        answer_start = time.time()
        try:
            answer = self.llm.generate(
                prompt=final_prompt,
                temperature=self.answer_temperature,
                max_tokens=2048
            )
            answer_time = time.time() - answer_start
            logger.info(f"✅ Answer generated (elapsed: {answer_time:.2f}s)")
        except Exception as e:
            logger.error(f"❌ Error while generating answer: {e}")
            answer = f"生成回答時出錯: {e}"
            answer_time = time.time() - answer_start

        total_time = time.time() - start_time

        return {
            **retrieval_result,
            "answer": answer,
            "formatted_context": {
                "subquery": subquery_context,
                "specific": specific_context,
                "abstract": abstract_context
            },
            "answer_time": answer_time,
            "total_time": total_time
        }
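
    # End-to-end usage sketch (illustrative; the formatter construction and the
    # document_type value are assumptions -- PromptFormatter's actual interface
    # lives in src/prompt_formatter.py):
    #
    #     formatter = PromptFormatter()
    #     result = rag.generate_answer(
    #         question="What is positional encoding in Transformers?",
    #         formatter=formatter,
    #         top_k=5,
    #         document_type="general",
    #     )
    #     print(result["answer"])
    #     print(result["total_time"], "s,", result["total_docs_found"], "unique chunks")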