wenlianghuang commited on
Commit
979763a
·
1 Parent(s): ee8f3fd

combine src of advanced RAG

Browse files
deep_agent_rag/rag/private_file_rag.py CHANGED
@@ -24,34 +24,26 @@ from langchain_core.messages import HumanMessage
24
  from .llm_adapter import LangChainLLMAdapter
25
  from .adaptive_rag_selector import AdaptiveRAGSelector, RAGMethod
26
 
27
- # 添加 Learn_RAG 到 Python 路徑
28
- # 計算 Learn_RAG 的路徑(與 Deep_Agentic_AI_Tool 在同一目錄下)
29
  current_file = Path(__file__).resolve()
30
  # 從 deep_agent_rag/rag/private_file_rag.py 向上找到 Deep_Agentic_AI_Tool 根目錄
 
31
  deep_agent_root = current_file.parent.parent.parent.parent
32
- learn_rag_path = deep_agent_root.parent / "Learn_RAG"
33
-
34
- # 如果 Learn_RAG 不在預期位置,嘗試其他可能的位置
35
- if not learn_rag_path.exists():
36
- # 嘗試當前工作目錄的父目錄
37
- cwd = Path.cwd()
38
- learn_rag_path = cwd.parent / "Learn_RAG"
39
-
40
- if not learn_rag_path.exists():
41
- # 嘗試直接使用絕對路徑
42
- learn_rag_path = Path("/Users/matthuang/Desktop/Learn_RAG")
43
 
44
- # Learn_RAG 目錄添加到 Python 路徑這樣可以導入 src 模組
45
- # 注意:需要將 Learn_RAG 目錄本身添加到路徑,因為 src 模組在 Learn_RAG/src/
46
- if learn_rag_path.exists() and learn_rag_path.is_dir():
47
- if str(learn_rag_path) not in sys.path:
48
- sys.path.insert(0, str(learn_rag_path))
49
- print(f"✓ 找到 Learn_RAG 項目: {learn_rag_path}")
50
- print(f" Python 路徑已添加: {learn_rag_path}")
 
 
51
  else:
52
- print(f"⚠️ 無法找到 Learn_RAG 目")
53
- print(f" 嘗試的路徑: {learn_rag_path}")
54
- print(f" 請確保 Learn_RAG 項目: {deep_agent_root.parent / 'Learn_RAG'}")
55
 
56
  # 嘗試導入 Learn_RAG 模組
57
  # 注意:document_processor.py 在頂層導入了 arxiv,所以需要先安裝依賴
@@ -78,13 +70,13 @@ try:
78
 
79
  if missing_deps:
80
  print(f"⚠️ 缺少以下依賴包: {', '.join(missing_deps)}")
81
- print(f"\n💡 請安裝 Learn_RAG 項目的依賴:")
82
  print(f" 方法 1: 使用 pip")
83
  print(f" pip install {' '.join(missing_deps)}")
84
- print(f"\n 方法 2: 使用 uv (推薦,如果 Learn_RAG 使用 uv)")
85
- print(f" cd {learn_rag_path}")
86
  print(f" uv sync")
87
- print(f"\n 方法 3: 安裝所有 Learn_RAG 依賴")
88
  print(f" pip install arxiv langchain-community langchain-text-splitters chromadb sentence-transformers rank-bm25 pypdf docx2txt langchain-experimental")
89
  LEARN_RAG_AVAILABLE = False
90
  else:
@@ -104,22 +96,23 @@ try:
104
  # 不再需要導入 OllamaLLM,因為我們使用 Deep_Agentic_AI_Tool 的統一 LLM 系統(get_llm())
105
  # from src.llm_integration import OllamaLLM
106
  LEARN_RAG_AVAILABLE = True
107
- print("✓ 成功導入 Learn_RAG 模組(包含進階 RAG 方法)")
108
 
109
  except ImportError as e:
110
  error_msg = str(e)
111
- print(f"⚠️ 無法導入 Learn_RAG 模組: {error_msg}")
112
- print(f"\n💡 請安裝 Learn_RAG 項目的依賴:")
113
  print(f" pip install arxiv langchain-community langchain-text-splitters chromadb sentence-transformers rank-bm25 pypdf docx2txt langchain-experimental")
114
  print(f"\n 或者:")
115
- print(f" cd {learn_rag_path}")
116
  print(f" uv sync")
117
  LEARN_RAG_AVAILABLE = False
118
  except Exception as e:
119
  error_msg = str(e)
120
- print(f"⚠️ 導入 Learn_RAG 模組時發生錯誤: {error_msg}")
121
  print(f" 當前 Python 路徑: {sys.path[:3]}")
122
- print(f" Learn_RAG 路徑: {learn_rag_path}")
 
123
  LEARN_RAG_AVAILABLE = False
124
 
125
 
@@ -1236,134 +1229,6 @@ def reset_private_rag_instance():
1236
  """重置全局實例"""
1237
  global _private_rag_instance
1238
  _private_rag_instance = None
1239
-
1240
-
1241
- """
1242
- 私有文件 RAG 系統
1243
- 集成 Learn_RAG 的功能,支持上傳私有文件(PDF、DOCX、TXT)並使用 RAG 回答問題
1244
-
1245
- LLM 使用策略:
1246
- - 優先使用 Groq API(如果配置了 API 金鑰)
1247
- - 其次使用 Ollama(如果服務正在運行)
1248
- - 最後使用 MLX 本地模型(作為備選方案)
1249
- """
1250
- import os
1251
- import sys
1252
- import time
1253
- from pathlib import Path
1254
- from typing import Optional, Dict, List, Tuple
1255
- import tempfile
1256
- import shutil
1257
-
1258
- # 導入 Deep_Agentic_AI_Tool 的 LLM 工具
1259
- # 這樣可以使用統一的 LLM 優先順序策略(Groq -> Ollama -> MLX)
1260
- from ..utils.llm_utils import get_llm
1261
- from langchain_core.messages import HumanMessage
1262
-
1263
- # 導入 LLM 適配器和智能選擇器
1264
- from .llm_adapter import LangChainLLMAdapter
1265
- from .adaptive_rag_selector import AdaptiveRAGSelector, RAGMethod
1266
-
1267
- # 添加 Learn_RAG 到 Python 路徑
1268
- # 計算 Learn_RAG 的路徑(與 Deep_Agentic_AI_Tool 在同一目錄下)
1269
- current_file = Path(__file__).resolve()
1270
- # 從 deep_agent_rag/rag/private_file_rag.py 向上找到 Deep_Agentic_AI_Tool 根目錄
1271
- deep_agent_root = current_file.parent.parent.parent.parent
1272
- learn_rag_path = deep_agent_root.parent / "Learn_RAG"
1273
-
1274
- # 如果 Learn_RAG 不在預期位置,嘗試其他可能的位置
1275
- if not learn_rag_path.exists():
1276
- # 嘗試當前工作目錄的父目錄
1277
- cwd = Path.cwd()
1278
- learn_rag_path = cwd.parent / "Learn_RAG"
1279
-
1280
- if not learn_rag_path.exists():
1281
- # 嘗試直接使用絕對路徑
1282
- learn_rag_path = Path("/Users/matthuang/Desktop/Learn_RAG")
1283
-
1284
- # 將 Learn_RAG 目錄添加到 Python 路徑(這樣可以導入 src 模組)
1285
- # 注意:需要將 Learn_RAG 目錄本身添加到路徑,因為 src 模組在 Learn_RAG/src/ 下
1286
- if learn_rag_path.exists() and learn_rag_path.is_dir():
1287
- if str(learn_rag_path) not in sys.path:
1288
- sys.path.insert(0, str(learn_rag_path))
1289
- print(f"✓ 找到 Learn_RAG 項目: {learn_rag_path}")
1290
- print(f" Python 路徑已添加: {learn_rag_path}")
1291
- else:
1292
- print(f"⚠️ 無法找到 Learn_RAG 項目")
1293
- print(f" 嘗試的路徑: {learn_rag_path}")
1294
- print(f" 請確保 Learn_RAG 項目在: {deep_agent_root.parent / 'Learn_RAG'}")
1295
-
1296
- # 嘗試導入 Learn_RAG 模組
1297
- # 注意:document_processor.py 在頂層導入了 arxiv,所以需要先安裝依賴
1298
- try:
1299
- # 先檢查必要的依賴是否已安裝
1300
- import importlib
1301
-
1302
- required_deps = {
1303
- "arxiv": "arxiv",
1304
- "langchain_community": "langchain-community",
1305
- "langchain_text_splitters": "langchain-text-splitters",
1306
- "chromadb": "chromadb",
1307
- "sentence_transformers": "sentence-transformers",
1308
- "rank_bm25": "rank-bm25",
1309
- "pypdf": "pypdf",
1310
- }
1311
-
1312
- missing_deps = []
1313
- for module_name, package_name in required_deps.items():
1314
- try:
1315
- importlib.import_module(module_name)
1316
- except ImportError:
1317
- missing_deps.append(package_name)
1318
-
1319
- if missing_deps:
1320
- print(f"⚠️ 缺少以下依賴包: {', '.join(missing_deps)}")
1321
- print(f"\n💡 請安裝 Learn_RAG 項目的依賴:")
1322
- print(f" 方法 1: 使用 pip")
1323
- print(f" pip install {' '.join(missing_deps)}")
1324
- print(f"\n 方法 2: 使用 uv (推薦,如果 Learn_RAG 使用 uv)")
1325
- print(f" cd {learn_rag_path}")
1326
- print(f" uv sync")
1327
- print(f"\n 方法 3: 安裝所有 Learn_RAG 依賴")
1328
- print(f" pip install arxiv langchain-community langchain-text-splitters chromadb sentence-transformers rank-bm25 pypdf docx2txt langchain-experimental")
1329
- LEARN_RAG_AVAILABLE = False
1330
- else:
1331
- # 所有依賴都已安裝,嘗試導入模組
1332
- from src.document_processor import DocumentProcessor
1333
- from src.retrievers.bm25_retriever import BM25Retriever
1334
- from src.retrievers.vector_retriever import VectorRetriever
1335
- from src.retrievers.hybrid_search import HybridSearch
1336
- from src.retrievers.reranker import Reranker, RAGPipeline
1337
- from src.prompt_formatter import PromptFormatter
1338
- # 導入進階 RAG 方法
1339
- from src.subquery_rag import SubQueryDecompositionRAG
1340
- from src.hyde_rag import HyDERAG
1341
- from src.step_back_rag import StepBackRAG
1342
- from src.hybrid_subquery_hyde_rag import HybridSubqueryHyDERAG
1343
- from src.triple_hybrid_rag import TripleHybridRAG
1344
- # 不再需要導入 OllamaLLM,因為我們使用 Deep_Agentic_AI_Tool 的統一 LLM 系統(get_llm())
1345
- # from src.llm_integration import OllamaLLM
1346
- LEARN_RAG_AVAILABLE = True
1347
- print("✓ 成功導入 Learn_RAG 模組(包含進階 RAG 方法)")
1348
-
1349
- except ImportError as e:
1350
- error_msg = str(e)
1351
- print(f"⚠️ 無法導入 Learn_RAG 模組: {error_msg}")
1352
- print(f"\n💡 請安裝 Learn_RAG 項目的依賴:")
1353
- print(f" pip install arxiv langchain-community langchain-text-splitters chromadb sentence-transformers rank-bm25 pypdf docx2txt langchain-experimental")
1354
- print(f"\n 或者:")
1355
- print(f" cd {learn_rag_path}")
1356
- print(f" uv sync")
1357
- LEARN_RAG_AVAILABLE = False
1358
- except Exception as e:
1359
- error_msg = str(e)
1360
- print(f"⚠️ 導入 Learn_RAG 模組時發生錯誤: {error_msg}")
1361
- print(f" 當前 Python 路徑: {sys.path[:3]}")
1362
- print(f" Learn_RAG 路徑: {learn_rag_path}")
1363
- LEARN_RAG_AVAILABLE = False
1364
-
1365
-
1366
- class PrivateFileRAG:
1367
  """
1368
  私有文件 RAG 系統管理器
1369
 
 
24
  from .llm_adapter import LangChainLLMAdapter
25
  from .adaptive_rag_selector import AdaptiveRAGSelector, RAGMethod
26
 
27
+ # 添加項目根目錄到 Python 路徑(這樣可以導入 src 模組)
28
+ # deep_agent_rag/rag/private_file_rag.py 向上找到 Deep_Agentic_AI_Tool 目錄
29
  current_file = Path(__file__).resolve()
30
  # 從 deep_agent_rag/rag/private_file_rag.py 向上找到 Deep_Agentic_AI_Tool 根目錄
31
+ # private_file_rag.py -> rag/ -> deep_agent_rag/ -> Deep_Agentic_AI_Tool/
32
  deep_agent_root = current_file.parent.parent.parent.parent
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ # 檢查 src 目錄是否存在應該在項目根目錄下
35
+ src_path = deep_agent_root / "src"
36
+ if src_path.exists() and src_path.is_dir():
37
+ # 將項目根目錄添加到 Python 路徑(不是 src 目錄本身)
38
+ # 這樣可以通過 from src.xxx import xxx 導入
39
+ if str(deep_agent_root) not in sys.path:
40
+ sys.path.insert(0, str(deep_agent_root))
41
+ print(f"✓ 找到本地 src 模組: {src_path}")
42
+ print(f" 項目根目錄已添加到 Python 路徑: {deep_agent_root}")
43
  else:
44
+ print(f"⚠️ 無法找到 src")
45
+ print(f" 預期路徑: {src_path}")
46
+ print(f" 項目根目錄: {deep_agent_root}")
47
 
48
  # 嘗試導入 Learn_RAG 模組
49
  # 注意:document_processor.py 在頂層導入了 arxiv,所以需要先安裝依賴
 
70
 
71
  if missing_deps:
72
  print(f"⚠️ 缺少以下依賴包: {', '.join(missing_deps)}")
73
+ print(f"\n💡 請安裝 RAG 系統所需的依賴:")
74
  print(f" 方法 1: 使用 pip")
75
  print(f" pip install {' '.join(missing_deps)}")
76
+ print(f"\n 方法 2: 使用 uv (推薦)")
77
+ print(f" cd {deep_agent_root}")
78
  print(f" uv sync")
79
+ print(f"\n 方法 3: 安裝所有依賴")
80
  print(f" pip install arxiv langchain-community langchain-text-splitters chromadb sentence-transformers rank-bm25 pypdf docx2txt langchain-experimental")
81
  LEARN_RAG_AVAILABLE = False
82
  else:
 
96
  # 不再需要導入 OllamaLLM,因為我們使用 Deep_Agentic_AI_Tool 的統一 LLM 系統(get_llm())
97
  # from src.llm_integration import OllamaLLM
98
  LEARN_RAG_AVAILABLE = True
99
+ print("✓ 成功導入 RAG 模組(本地集成版本,包含進階 RAG 方法)")
100
 
101
  except ImportError as e:
102
  error_msg = str(e)
103
+ print(f"⚠️ 無法導入 RAG 模組: {error_msg}")
104
+ print(f"\n💡 請安裝 RAG 系統所需的依賴:")
105
  print(f" pip install arxiv langchain-community langchain-text-splitters chromadb sentence-transformers rank-bm25 pypdf docx2txt langchain-experimental")
106
  print(f"\n 或者:")
107
+ print(f" cd {deep_agent_root}")
108
  print(f" uv sync")
109
  LEARN_RAG_AVAILABLE = False
110
  except Exception as e:
111
  error_msg = str(e)
112
+ print(f"⚠️ 導入 RAG 模組時發生錯誤: {error_msg}")
113
  print(f" 當前 Python 路徑: {sys.path[:3]}")
114
+ print(f" 項目根目錄: {deep_agent_root}")
115
+ print(f" src 目錄: {src_path}")
116
  LEARN_RAG_AVAILABLE = False
117
 
118
 
 
1229
  """重置全局實例"""
1230
  global _private_rag_instance
1231
  _private_rag_instance = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1232
  """
1233
  私有文件 RAG 系統管理器
1234
 
src/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG 系統模組套件
3
+ """
4
+ from .document_processor import DocumentProcessor
5
+ from .retrievers import (
6
+ BaseRetriever,
7
+ BM25Retriever,
8
+ VectorRetriever,
9
+ HybridSearch,
10
+ Reranker,
11
+ RAGPipeline,
12
+ )
13
+ from .prompt_formatter import PromptFormatter
14
+ from .llm_integration import OllamaLLM
15
+ from .subquery_rag import SubQueryDecompositionRAG
16
+ from .hyde_rag import HyDERAG
17
+ from .hybrid_subquery_hyde_rag import HybridSubqueryHyDERAG
18
+ from .step_back_rag import StepBackRAG
19
+ from .triple_hybrid_rag import TripleHybridRAG
20
+
21
+ __all__ = [
22
+ "DocumentProcessor",
23
+ "BaseRetriever",
24
+ "BM25Retriever",
25
+ "VectorRetriever",
26
+ "HybridSearch",
27
+ "Reranker",
28
+ "RAGPipeline",
29
+ "PromptFormatter",
30
+ "OllamaLLM",
31
+ "SubQueryDecompositionRAG",
32
+ "HyDERAG",
33
+ "HybridSubqueryHyDERAG",
34
+ "StepBackRAG",
35
+ "TripleHybridRAG",
36
+ ]
37
+
src/document_processor.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 文檔處理模組:載入 arXiv 論文並進行文字分割
3
+ 支援本地檔案:PDF, DOCX, TXT
4
+ 支援兩種分塊策略:
5
+ 1. 字符分塊(預設):基於固定字符數的分塊,速度快
6
+ 2. 語義分塊(可選):基於語義相似度的分塊,能保持語義完整性
7
+ """
8
+ from typing import List, Dict, Optional, Any
9
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
10
+ from pathlib import Path
11
+ import os
12
+ import arxiv
13
+ import re
14
+
15
+ # 嘗試導入語義分塊器(需要 langchain-experimental)
16
+ try:
17
+ from langchain_experimental.text_splitter import SemanticChunker
18
+ SEMANTIC_CHUNKER_AVAILABLE = True
19
+ except ImportError:
20
+ SEMANTIC_CHUNKER_AVAILABLE = False
21
+
22
+
23
+ class DocumentProcessor:
24
+ """
25
+ 處理 arXiv 論文文檔,進行分割和準備
26
+
27
+ 支援兩種分塊模式:
28
+ - 字符分塊(預設):快速、穩定,適合大多數場景
29
+ - 語義分塊(可選):更智能,能保持語義完整性,但需要額外依賴和計算時間
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ chunk_size: int = 1000,
35
+ chunk_overlap: int = 200,
36
+ embeddings: Optional[Any] = None, # 可選:用於語義分塊的 embedding 模型
37
+ use_semantic_chunking: bool = False, # 是否使用語義分塊
38
+ breakpoint_threshold_amount: float = 1.5, # 語義分塊敏感度(標準差倍數)
39
+ min_chunk_size: int = 100 # 語義分塊的最小 chunk 大小(字符數)
40
+ ):
41
+ """
42
+ 初始化文檔處理器
43
+
44
+ Args:
45
+ chunk_size: 每個 chunk 的大小(字符數),僅用於字符分塊模式
46
+ chunk_overlap: chunk 之間的重疊大小(字符數),僅用於字符分塊模式
47
+ embeddings: 用於計算語義距離的 embedding 模型物件(可選)
48
+ 當 use_semantic_chunking=True 時必須提供
49
+ use_semantic_chunking: 是否使用語義分塊
50
+ True: 使用語義分塊(需要提供 embeddings)
51
+ False: 使用字符分塊(預設)
52
+ breakpoint_threshold_amount: 語義分塊的敏感度參數
53
+ 數值越大,分塊越少(chunks 越大)
54
+ 數值越小,分塊越多(chunks 越小)
55
+ 建議範圍:1.0 - 2.0,預設 1.5
56
+ min_chunk_size: 語義分塊的最小 chunk 大小(字符數)
57
+ 小於此大小的 chunks 會被合併到相鄰的 chunks
58
+ 預設 100 字符
59
+ """
60
+ self.embeddings = embeddings
61
+ self.use_semantic_chunking = use_semantic_chunking
62
+ self.min_chunk_size = min_chunk_size
63
+
64
+ # 如果要求使用語義分塊
65
+ if use_semantic_chunking:
66
+ # 檢查是否安裝了必要的依賴
67
+ if not SEMANTIC_CHUNKER_AVAILABLE:
68
+ raise ImportError(
69
+ "使用語義分塊需要安裝 langchain-experimental 套件。\n"
70
+ "請執行: pip install langchain-experimental\n"
71
+ "或使用字符分塊模式(use_semantic_chunking=False)"
72
+ )
73
+
74
+ # 檢查是否提供了 embeddings
75
+ if embeddings is None:
76
+ raise ValueError(
77
+ "使用語義分塊時必須提供 embeddings 參數。\n"
78
+ "範例:\n"
79
+ " from langchain_community.embeddings import HuggingFaceEmbeddings\n"
80
+ " embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')\n"
81
+ " processor = DocumentProcessor(embeddings=embeddings, use_semantic_chunking=True)"
82
+ )
83
+
84
+ # 初始化語義分塊器
85
+ # 使用「標準差」策略:當相鄰句子之間的語義距離超過平均距離的標準差倍數時,進行切分
86
+ self.text_splitter = SemanticChunker(
87
+ embeddings,
88
+ breakpoint_threshold_type="standard_deviation",
89
+ breakpoint_threshold_amount=breakpoint_threshold_amount
90
+ )
91
+ print(f"✓ 使用語義分塊模式(敏感度: {breakpoint_threshold_amount},最小 chunk 大小: {min_chunk_size} 字符)")
92
+ else:
93
+ # 使用傳統的字符分塊(預設模式)
94
+ self.text_splitter = RecursiveCharacterTextSplitter(
95
+ chunk_size=chunk_size,
96
+ chunk_overlap=chunk_overlap,
97
+ length_function=len,
98
+ )
99
+ print(f"✓ 使用字符分塊模式(大小: {chunk_size} 字符,重疊: {chunk_overlap} 字符)")
100
+
101
+ def _post_process_chunks(self, chunks: List[str]) -> List[str]:
102
+ """
103
+ 後處理 chunks:過濾和合併太小的 chunks
104
+
105
+ 語義分塊可能會產生一些非��小的 chunks(例如只有幾個單詞),
106
+ 這些小 chunks 可能不包含足夠的上下文資訊。此方法會:
107
+ 1. 將小於 min_chunk_size 的 chunks 合併到相鄰的 chunks
108
+ 2. 確保最終的 chunks 都有足夠的大小
109
+
110
+ Args:
111
+ chunks: 原始 chunks 列表(從分塊器產生的)
112
+
113
+ Returns:
114
+ 處理後的 chunks 列表(過濾和合併後的)
115
+ """
116
+ # 如果使用字符分塊,不需要後處理(因為已經有固定大小)
117
+ if not self.use_semantic_chunking:
118
+ return chunks
119
+
120
+ # 如果沒有 chunks,直接返回
121
+ if not chunks:
122
+ return chunks
123
+
124
+ processed = []
125
+ current_small_chunk = "" # 累積的小 chunk
126
+
127
+ for chunk in chunks:
128
+ chunk_stripped = chunk.strip()
129
+ chunk_length = len(chunk_stripped)
130
+
131
+ # 如果當前 chunk 太小,嘗試與下一個合併
132
+ if chunk_length < self.min_chunk_size:
133
+ # 累積到臨時變數中
134
+ if current_small_chunk:
135
+ current_small_chunk += "\n\n" + chunk
136
+ else:
137
+ current_small_chunk = chunk
138
+ else:
139
+ # 當前 chunk 足夠大
140
+ # 如果有累積的小 chunk,先處理它
141
+ if current_small_chunk:
142
+ current_small_chunk_stripped = current_small_chunk.strip()
143
+ if len(current_small_chunk_stripped) >= self.min_chunk_size:
144
+ # 累積後足夠大,作為獨立 chunk
145
+ processed.append(current_small_chunk)
146
+ else:
147
+ # 累積後還是太小,合併到上一個 chunk(如果存在)
148
+ if processed:
149
+ processed[-1] += "\n\n" + current_small_chunk
150
+ else:
151
+ # 如果沒有上一個 chunk,還是要保留
152
+ processed.append(current_small_chunk)
153
+ current_small_chunk = ""
154
+
155
+ # 添加當前足夠大的 chunk
156
+ processed.append(chunk)
157
+
158
+ # 處理最後的累積小 chunk
159
+ if current_small_chunk:
160
+ current_small_chunk_stripped = current_small_chunk.strip()
161
+ if len(current_small_chunk_stripped) >= self.min_chunk_size:
162
+ # 足夠大,作為獨立 chunk
163
+ processed.append(current_small_chunk)
164
+ elif processed:
165
+ # 太小,合併到最後一個 chunk
166
+ processed[-1] += "\n\n" + current_small_chunk
167
+ else:
168
+ # 如果沒有其他 chunks,還是要保留
169
+ processed.append(current_small_chunk)
170
+
171
+ return processed
172
+
173
+ def fetch_papers(self, query: str, max_results: int = 10) -> List[Dict]:
174
+ """
175
+ 從 arXiv 獲取論文
176
+
177
+ Args:
178
+ query: 搜尋查詢(例如 "cat:cs.AI")
179
+ max_results: 最大結果數量
180
+
181
+ Returns:
182
+ 論文列表,每個論文包含標題、摘要等資訊
183
+ """
184
+ search = arxiv.Search(
185
+ query=query,
186
+ max_results=max_results,
187
+ sort_by=arxiv.SortCriterion.SubmittedDate
188
+ )
189
+
190
+ papers = []
191
+ for paper in search.results():
192
+ papers.append({
193
+ "title": paper.title,
194
+ "authors": [author.name for author in paper.authors],
195
+ "summary": paper.summary,
196
+ "published": str(paper.published),
197
+ "arxiv_id": paper.entry_id.split('/')[-1],
198
+ "arxiv_url": paper.entry_id,
199
+ "pdf_url": paper.pdf_url,
200
+ "categories": paper.categories,
201
+ })
202
+
203
+ return papers
204
+
205
+ def process_documents(self, papers: List[Dict]) -> List[Dict]:
206
+ """
207
+ 處理論文,將每篇論文分割成 chunks
208
+
209
+ Args:
210
+ papers: 論文列表
211
+
212
+ Returns:
213
+ 處理後的文檔 chunks,每個 chunk 包含內容和元數據
214
+ """
215
+ documents = []
216
+
217
+ for paper in papers:
218
+ # 組合論文的完整文字(標題 + 摘要)
219
+ # 保留換行符號 \n\n 作為語義斷點的結構參考
220
+ full_text = f"Title: {paper['title']}\n\nAbstract: {paper['summary']}"
221
+
222
+ # 分割文字(根據選擇的模式:字符分塊或語義分塊)
223
+ chunks = self.text_splitter.split_text(full_text)
224
+
225
+ # 後處理:過濾和合併太小的 chunks(僅語義分塊模式)
226
+ chunks = self._post_process_chunks(chunks)
227
+
228
+ # 為每個 chunk 創建文檔物件
229
+ for i, chunk in enumerate(chunks):
230
+ doc = {
231
+ "content": chunk,
232
+ "metadata": {
233
+ "title": paper['title'],
234
+ "arxiv_id": paper['arxiv_id'],
235
+ "arxiv_url": paper['arxiv_url'],
236
+ "pdf_url": paper['pdf_url'],
237
+ "authors": paper['authors'],
238
+ "published": paper['published'],
239
+ "categories": paper['categories'],
240
+ "chunk_index": i,
241
+ "total_chunks": len(chunks),
242
+ "chunking_method": "semantic" if self.use_semantic_chunking else "character"
243
+ }
244
+ }
245
+ documents.append(doc)
246
+
247
+ return documents
248
+
249
+ def get_texts_and_metadatas(self, documents: List[Dict]):
250
+ """
251
+ 從文檔列表中提取文字和元數據
252
+
253
+ Args:
254
+ documents: 文檔列表
255
+
256
+ Returns:
257
+ (texts, metadatas) 元組
258
+ """
259
+ texts = [doc["content"] for doc in documents]
260
+ metadatas = [doc["metadata"] for doc in documents]
261
+ return texts, metadatas
262
+
263
+ @staticmethod
264
+ def clean_extracted_text(text: str) -> str:
265
+ """
266
+ 清理從 PDF/DOCX 提取的文本,移除多餘的空格和修復字符換行問題
267
+
268
+ 某些 PDF 提取工具會在每個字符之間插入空格或換行,特別是中文文本。
269
+ 此方法會:
270
+ 1. 修復「每個字符一行」的問題(將單字符行合併)
271
+ 2. 移除中文字符之間的多餘空格
272
+ 3. 保留英文單詞之間的空格
273
+ 4. 保留標點符號周圍的適當空格
274
+ 5. 保留真正的段落分隔
275
+
276
+ Args:
277
+ text: 原始提取的文本
278
+
279
+ Returns:
280
+ 清理後的文本
281
+ """
282
+ if not text:
283
+ return text
284
+
285
+ # 步驟 0: 修復「每個字符一行」的問題
286
+ # 檢測模式:每行只有一個字符(可能是中文字符、標點、或單個字母/數字)
287
+ # 將這些單字符行合併成連續文本
288
+ lines = text.split('\n')
289
+ merged_lines = []
290
+ i = 0
291
+
292
+ def is_single_char_line(line: str) -> bool:
293
+ """
294
+ 判斷是否為單字符行
295
+ 考慮:去除空格後長度 <= 3(可能是單字符+標點,或單字符+空格)
296
+ """
297
+ stripped = line.strip()
298
+ if not stripped:
299
+ return False # 空行不算
300
+ # 如果去除空格後長度 <= 3,且主要是中文字符、標點或單個字母/數字
301
+ if len(stripped) <= 3:
302
+ # 檢查是否主要是單個字符(可能帶標點或空格)
303
+ # 移除所有空格後,如果長度 <= 2,認為是單字符行
304
+ no_space = stripped.replace(' ', '')
305
+ if len(no_space) <= 2:
306
+ return True
307
+ return False
308
+
309
+ while i < len(lines):
310
+ line = lines[i]
311
+ stripped_line = line.strip()
312
+
313
+ # 如果當前行是單字符行
314
+ if is_single_char_line(line):
315
+ # 收集連續的單字符行(包括空行,因為空行可能是分隔符)
316
+ merged_chars = []
317
+ j = i
318
+ consecutive_single_chars = 0
319
+
320
+ while j < len(lines):
321
+ current_line = lines[j]
322
+ current_stripped = current_line.strip()
323
+
324
+ if is_single_char_line(current_line):
325
+ # 是單字符行,收集字符(去除空格)
326
+ char = current_stripped.replace(' ', '')
327
+ if char:
328
+ merged_chars.append(char)
329
+ consecutive_single_chars += 1
330
+ j += 1
331
+ elif not current_stripped:
332
+ # 空行:如果前面有單字符,且後面可能還有單字符,跳過空行
333
+ # 檢查下一行是否也是單字符
334
+ if j + 1 < len(lines) and is_single_char_line(lines[j + 1]):
335
+ # 空行後面還有單字符,跳過空行繼續收集
336
+ j += 1
337
+ else:
338
+ # 空行後面沒有單字符了,停止收集
339
+ break
340
+ else:
341
+ # 遇到正常行,停止收集
342
+ break
343
+
344
+ # 如果收集到多個單字符,合併它們
345
+ if len(merged_chars) > 1:
346
+ merged_text = ''.join(merged_chars)
347
+ merged_lines.append(merged_text)
348
+ i = j
349
+ continue
350
+ elif len(merged_chars) == 1 and consecutive_single_chars > 1:
351
+ # 只有一個字符但有多行(可能是空格導致的),也合併
352
+ merged_text = ''.join(merged_chars)
353
+ merged_lines.append(merged_text)
354
+ i = j
355
+ continue
356
+ else:
357
+ # 只有一個單字符,且確實只有一行,保留原樣
358
+ if merged_chars:
359
+ merged_lines.append(merged_chars[0])
360
+ i = j
361
+ continue
362
+ else:
363
+ # 正常行,直接添加
364
+ if stripped_line: # 非空行
365
+ merged_lines.append(stripped_line)
366
+ i += 1
367
+
368
+ # 重新組合文本
369
+ text = '\n'.join(merged_lines)
370
+
371
+ # 步驟 0.5: 再次處理可能的殘留問題
372
+ # 如果還有單字符行(可能是第一次處理遺漏的),再次處理
373
+ lines = text.split('\n')
374
+ final_lines = []
375
+ i = 0
376
+ while i < len(lines):
377
+ line = lines[i].strip()
378
+ if is_single_char_line(line):
379
+ # 再次收集連續的單字符行
380
+ merged_chars = []
381
+ j = i
382
+ while j < len(lines) and is_single_char_line(lines[j]):
383
+ char = lines[j].strip().replace(' ', '')
384
+ if char:
385
+ merged_chars.append(char)
386
+ j += 1
387
+
388
+ if len(merged_chars) > 1:
389
+ final_lines.append(''.join(merged_chars))
390
+ i = j
391
+ else:
392
+ if merged_chars:
393
+ final_lines.append(merged_chars[0])
394
+ i = j
395
+ else:
396
+ if line:
397
+ final_lines.append(line)
398
+ i += 1
399
+
400
+ text = '\n'.join(final_lines)
401
+
402
+ # 1. 移除中文字符之間的空格
403
+ # 匹配模式:中文字符 + 空格 + 中文字符
404
+ chinese_char_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])'
405
+ text = re.sub(chinese_char_pattern, r'\1\2', text)
406
+
407
+ # 2. 移除中文和標點符號之間的多餘空格
408
+ # 中文 + 空格 + 標點符號
409
+ chinese_punct_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([,。、;:!?""''()【】《》])'
410
+ text = re.sub(chinese_punct_pattern, r'\1\2', text)
411
+
412
+ # 標點符號 + 空格 + 中文
413
+ # 使用 re.escape 來正確處理標點符號,避免轉義序列警告
414
+ punct_chars = ',。、;:!?""''()【】《》'
415
+ punct_chinese_pattern = f'([{re.escape(punct_chars)}])\\s+([\\u4e00-\\u9fff\\u3400-\\u4dbf\\uf900-\\ufaff])'
416
+ text = re.sub(punct_chinese_pattern, r'\1\2', text)
417
+
418
+ # 3. 移除數字和中文之間的多餘空格(例如:"500 公里" -> "500公里")
419
+ number_chinese_pattern = r'(\d+)\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])'
420
+ text = re.sub(number_chinese_pattern, r'\1\2', text)
421
+ chinese_number_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+(\d+)'
422
+ text = re.sub(chinese_number_pattern, r'\1\2', text)
423
+
424
+ # 4. 移除英文單詞內部的多餘空格(例如:"Nebula-X 跨次 元量" -> "Nebula-X 跨次元量")
425
+ # 但保留英文單詞之間的空格
426
+ # 匹配:非空格字符 + 空格 + 非空格字符(如果其中一個是中文,則移除空格)
427
+ mixed_space_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])'
428
+ text = re.sub(mixed_space_pattern, r'\1\2', text)
429
+
430
+ # 5. 移除多個連續空格(保留單個空格,用於英文單詞之間)
431
+ text = re.sub(r' +', ' ', text)
432
+
433
+ # 6. 清理行首行尾的空格(但保留換行符)
434
+ lines = text.split('\n')
435
+ cleaned_lines = [line.strip() for line in lines]
436
+ text = '\n'.join(cleaned_lines)
437
+
438
+ # 7. 移除多個連續的換行符(保留最多兩個,用於段落分隔)
439
+ text = re.sub(r'\n{3,}', '\n\n', text)
440
+
441
+ # 8. 修復可能的殘留問題:移除中文字符之間殘留的空格
442
+ # 再次檢查並移除中文字符之間的空格(處理可能遺漏的情況)
443
+ text = re.sub(r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])', r'\1\2', text)
444
+
445
+ return text
446
+
447
+ def load_from_file(self, file_path: str) -> Dict:
448
+ """
449
+ 從本地檔案載入文檔(支援 PDF, DOCX, TXT 等)
450
+
451
+ Args:
452
+ file_path: 檔案路徑
453
+
454
+ Returns:
455
+ 文檔字典,包含內容和元數據
456
+ """
457
+ file_path = Path(file_path)
458
+
459
+ if not file_path.exists():
460
+ raise FileNotFoundError(f"檔案不存在: {file_path}")
461
+
462
+ file_ext = file_path.suffix.lower()
463
+ file_name = file_path.stem
464
+ file_size = os.path.getsize(file_path)
465
+
466
+ # 根據檔案類型選擇不同的加載器
467
+ if file_ext == '.pdf':
468
+ try:
469
+ from langchain_community.document_loaders import PyPDFLoader
470
+ loader = PyPDFLoader(str(file_path))
471
+ pages = loader.load()
472
+ # 合併所有頁面
473
+ full_text = "\n\n".join([page.page_content for page in pages])
474
+ # 清理提取的文本(移除多餘空格)
475
+ full_text = self.clean_extracted_text(full_text)
476
+ except ImportError:
477
+ raise ImportError(
478
+ "需要安裝 pypdf 來處理 PDF 檔案: pip install pypdf"
479
+ )
480
+
481
+ elif file_ext in ['.docx', '.doc']:
482
+ try:
483
+ from langchain_community.document_loaders import Docx2txtLoader
484
+ loader = Docx2txtLoader(str(file_path))
485
+ pages = loader.load()
486
+ full_text = "\n\n".join([page.page_content for page in pages])
487
+ # 清理提取的文本(移除多餘空格)
488
+ full_text = self.clean_extracted_text(full_text)
489
+ except ImportError:
490
+ raise ImportError(
491
+ "需要安裝 docx2txt 來處理 DOCX 檔案: pip install docx2txt"
492
+ )
493
+
494
+ elif file_ext == '.txt':
495
+ # 嘗試不同的編碼
496
+ encodings = ['utf-8', 'gbk', 'big5', 'latin-1']
497
+ full_text = None
498
+ for encoding in encodings:
499
+ try:
500
+ with open(file_path, 'r', encoding=encoding) as f:
501
+ full_text = f.read()
502
+ break
503
+ except UnicodeDecodeError:
504
+ continue
505
+
506
+ if full_text is None:
507
+ raise ValueError(f"無法讀取檔案,嘗試的編碼都不適用: {encodings}")
508
+
509
+ else:
510
+ raise ValueError(
511
+ f"不支援的檔案類型: {file_ext}\n"
512
+ f"支援的格式: .pdf, .docx, .doc, .txt"
513
+ )
514
+
515
+ if not full_text or len(full_text.strip()) == 0:
516
+ raise ValueError(f"檔案為空或無法提取文字: {file_path}")
517
+
518
+ return {
519
+ "title": file_name,
520
+ "content": full_text,
521
+ "file_path": str(file_path),
522
+ "file_type": file_ext,
523
+ "file_size": file_size,
524
+ }
525
+
526
+ def process_file(self, file_path: str) -> List[Dict]:
527
+ """
528
+ 處理單個檔案,分割成 chunks
529
+
530
+ Args:
531
+ file_path: 檔案路徑
532
+
533
+ Returns:
534
+ 處理後的文檔 chunks 列表
535
+ """
536
+ # 載入檔案
537
+ file_doc = self.load_from_file(file_path)
538
+
539
+ # 分割文字(根據選擇的模式:字符分塊或語義分塊)
540
+ chunks = self.text_splitter.split_text(file_doc["content"])
541
+
542
+ # 後處理:過濾和合併太小的 chunks(僅語義分塊模式)
543
+ chunks = self._post_process_chunks(chunks)
544
+
545
+ if not chunks:
546
+ raise ValueError(f"檔案分割後沒有內容: {file_path}")
547
+
548
+ # 創建文檔 chunks
549
+ documents = []
550
+ for i, chunk in enumerate(chunks):
551
+ doc = {
552
+ "content": chunk,
553
+ "metadata": {
554
+ "title": file_doc["title"],
555
+ "file_path": file_doc["file_path"],
556
+ "file_type": file_doc["file_type"],
557
+ "file_size": file_doc["file_size"],
558
+ "chunk_index": i,
559
+ "total_chunks": len(chunks),
560
+ "chunking_method": "semantic" if self.use_semantic_chunking else "character"
561
+ }
562
+ }
563
+ documents.append(doc)
564
+
565
+ return documents
566
+
567
+ def process_files(self, file_paths: List[str]) -> List[Dict]:
568
+ """
569
+ 處理多個檔案
570
+
571
+ Args:
572
+ file_paths: 檔案路徑列表
573
+
574
+ Returns:
575
+ 所有檔案的文檔 chunks 列表
576
+ """
577
+ all_documents = []
578
+ for file_path in file_paths:
579
+ try:
580
+ print(f"處理檔案: {file_path}")
581
+ documents = self.process_file(file_path)
582
+ all_documents.extend(documents)
583
+ print(f" ✓ 創建了 {len(documents)} 個 chunks")
584
+ except Exception as e:
585
+ print(f" ✗ 處理檔案失敗: {file_path}")
586
+ print(f" 錯誤: {e}")
587
+ continue
588
+
589
+ return all_documents
590
+
src/hybrid_subquery_hyde_rag.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid Sub-query + HyDE RAG:融合 Sub-query Decomposition 和 HyDE
3
+ 結合兩種方法的優勢,提升檢索精度
4
+ """
5
+ from typing import List, Dict, Optional
6
+ from .retrievers.reranker import RAGPipeline
7
+ from .retrievers.vector_retriever import VectorRetriever
8
+ from .prompt_formatter import PromptFormatter
9
+ from .llm_integration import OllamaLLM
10
+ import hashlib
11
+ import time
12
+ import logging
13
+ from concurrent.futures import ThreadPoolExecutor, as_completed
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class HybridSubqueryHyDERAG:
19
+ """融合 Sub-query Decomposition 和 HyDE 的 RAG 系統"""
20
+
21
+ def __init__(
22
+ self,
23
+ rag_pipeline: RAGPipeline,
24
+ vector_retriever: VectorRetriever,
25
+ llm: OllamaLLM,
26
+ max_sub_queries: int = 3,
27
+ top_k_per_subquery: int = 5,
28
+ hypothetical_length: int = 200,
29
+ temperature_subquery: float = 0.3,
30
+ temperature_hyde: float = 0.7,
31
+ enable_parallel: bool = True
32
+ ):
33
+ """
34
+ 初始化融合 RAG
35
+
36
+ Args:
37
+ rag_pipeline: RAG 管線實例
38
+ vector_retriever: 向量檢索器
39
+ llm: LLM 實例
40
+ max_sub_queries: 最多生成的子問題數量
41
+ top_k_per_subquery: 每個子問題檢索的結果數量
42
+ hypothetical_length: 假設性文檔目標長度(字符數)
43
+ temperature_subquery: 生成子問題的溫度(較低,更穩定)
44
+ temperature_hyde: 生成假設性文檔的溫度(較高,更多專業術語)
45
+ enable_parallel: 是否並行處理
46
+ """
47
+ self.rag_pipeline = rag_pipeline
48
+ self.vector_retriever = vector_retriever
49
+ self.llm = llm
50
+ self.max_sub_queries = max_sub_queries
51
+ self.top_k_per_subquery = top_k_per_subquery
52
+ self.hypothetical_length = hypothetical_length
53
+ self.temperature_subquery = temperature_subquery
54
+ self.temperature_hyde = temperature_hyde
55
+ self.enable_parallel = enable_parallel
56
+
57
+ def _generate_sub_queries(self, question: str) -> List[str]:
58
+ """
59
+ 生成子問題(與 SubQueryDecompositionRAG 相同)
60
+
61
+ Args:
62
+ question: 原始問題
63
+
64
+ Returns:
65
+ 子問題列表
66
+ """
67
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
68
+
69
+ if is_chinese:
70
+ prompt = f"""你是一個專業助理。請將以下原始問題拆解成最多 {self.max_sub_queries} 個具體的子問題,以便進行資料搜尋。
71
+ 每個子問題應專注於原始問題的一個特定面向。請以換行符號分隔問題。
72
+
73
+ 原始問題: {question}
74
+
75
+ 子問題清單:"""
76
+ else:
77
+ prompt = f"""You are a professional assistant. Please decompose the following original question into at most {self.max_sub_queries} specific sub-questions for information retrieval.
78
+ Each sub-question should focus on a specific aspect of the original question. Please separate questions with newlines.
79
+
80
+ Original question: {question}
81
+
82
+ Sub-question list:"""
83
+
84
+ try:
85
+ response = self.llm.generate(
86
+ prompt=prompt,
87
+ temperature=self.temperature_subquery,
88
+ max_tokens=500
89
+ )
90
+
91
+ sub_queries = [
92
+ q.strip()
93
+ for q in response.strip().split("\n")
94
+ if q.strip() and not q.strip().startswith("#")
95
+ ]
96
+
97
+ # 移除編號前綴(如 "1. ", "1) " 等)
98
+ cleaned_queries = []
99
+ for q in sub_queries:
100
+ q = q.lstrip("0123456789. )")
101
+ q = q.strip()
102
+ if q:
103
+ cleaned_queries.append(q)
104
+
105
+ cleaned_queries = cleaned_queries[:self.max_sub_queries]
106
+
107
+ if not cleaned_queries:
108
+ logger.warning("⚠️ 未生成子問題,使用原始問題")
109
+ cleaned_queries = [question]
110
+
111
+ return cleaned_queries
112
+
113
+ except Exception as e:
114
+ logger.error(f"⚠️ 生成子問題時出錯: {e}")
115
+ return [question]
116
+
117
+ def _generate_hypothetical_document(self, sub_query: str) -> str:
118
+ """
119
+ 為子問題生成假設性文檔(與 HyDERAG 相同)
120
+
121
+ Args:
122
+ sub_query: 子問題
123
+
124
+ Returns:
125
+ 假設性文檔文本
126
+ """
127
+ is_chinese = PromptFormatter.detect_language(sub_query) == "zh"
128
+
129
+ if is_chinese:
130
+ prompt = f"""請針對以下問題,寫出一段約 {self.hypothetical_length} 字的專業技術檔案內容。
131
+ 這段內容應包含該領域常見的專業術語與原理說明,以便用於後續的語義檢索。
132
+ 請使用專業的術語和概念,即使你對某些細節不確定,也要包含相關的專業詞彙。
133
+
134
+ 問題: {sub_query}
135
+
136
+ 專業技術內容:"""
137
+ else:
138
+ prompt = f"""Please write a professional technical document of approximately {self.hypothetical_length} words in response to the following question.
139
+ This content should include common professional terminology and principle explanations in this field, to be used for subsequent semantic retrieval.
140
+ Please use professional terms and concepts, and include relevant professional vocabulary even if you are uncertain about some details.
141
+
142
+ Question: {sub_query}
143
+
144
+ Professional technical content:"""
145
+
146
+ try:
147
+ hypothetical_doc = self.llm.generate(
148
+ prompt=prompt,
149
+ temperature=self.temperature_hyde,
150
+ max_tokens=500
151
+ )
152
+
153
+ hypothetical_doc = hypothetical_doc.strip()
154
+
155
+ if not hypothetical_doc:
156
+ logger.warning(f"⚠️ 子問題 '{sub_query}' 的假設性文檔為空,使用子問題本身")
157
+ return sub_query
158
+
159
+ logger.debug(f"✅ 為子問題生成假設性文檔(長度: {len(hypothetical_doc)} 字符)")
160
+ return hypothetical_doc
161
+
162
+ except Exception as e:
163
+ logger.error(f"⚠️ 生成假設性文檔時出錯: {e}")
164
+ return sub_query
165
+
166
+ def _get_doc_id(self, doc: Dict) -> str:
167
+ """
168
+ 生成文檔的唯一標識符
169
+
170
+ Args:
171
+ doc: 文檔字典
172
+
173
+ Returns:
174
+ 唯一 ID
175
+ """
176
+ metadata = doc.get("metadata", {})
177
+ content = doc.get("content", "")
178
+
179
+ if "arxiv_id" in metadata and "chunk_index" in metadata:
180
+ return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
181
+ elif "file_path" in metadata and "chunk_index" in metadata:
182
+ return f"{metadata['file_path']}_{metadata['chunk_index']}"
183
+ else:
184
+ content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
185
+ return f"doc_{content_hash}"
186
+
187
+ def _process_subquery_with_hyde(
188
+ self,
189
+ sub_query: str,
190
+ metadata_filter: Optional[Dict] = None
191
+ ) -> tuple:
192
+ """
193
+ 處理單個子問題:生成假設性文檔並檢索
194
+
195
+ Args:
196
+ sub_query: 子問題
197
+ metadata_filter: 可選的 metadata 過濾條件
198
+
199
+ Returns:
200
+ (檢索結果列表, 假設性文檔)
201
+ """
202
+ try:
203
+ # 生成假設性文檔
204
+ hypothetical_doc = self._generate_hypothetical_document(sub_query)
205
+
206
+ # 使用假設性文檔檢索
207
+ results = self.vector_retriever.retrieve(
208
+ query=hypothetical_doc, # 使用假設性文檔而不是子問題
209
+ top_k=self.top_k_per_subquery,
210
+ metadata_filter=metadata_filter
211
+ )
212
+
213
+ return results, hypothetical_doc
214
+
215
+ except Exception as e:
216
+ logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
217
+ return [], ""
218
+
219
+ def query(
220
+ self,
221
+ question: str,
222
+ top_k: int = 5,
223
+ metadata_filter: Optional[Dict] = None,
224
+ return_sub_queries: bool = False,
225
+ return_hypothetical: bool = False
226
+ ) -> Dict:
227
+ """
228
+ 執行融合 RAG 檢索(不生成答案)
229
+
230
+ Args:
231
+ question: 原始問題
232
+ top_k: 返回前 k 個結果
233
+ metadata_filter: 可選的 metadata 過濾條件
234
+ return_sub_queries: 是否返回子問題列表
235
+ return_hypothetical: 是否返回假設性文檔字典(子問題 -> 假設性文檔)
236
+
237
+ Returns:
238
+ 包含檢索結果和統計資訊的字典
239
+ """
240
+ start_time = time.time()
241
+
242
+ # 第一步:生成子問題
243
+ logger.info(f"🔍 拆解問題: '{question}'")
244
+ sub_queries = self._generate_sub_queries(question)
245
+ logger.info(f"✅ 生成 {len(sub_queries)} 個子問題")
246
+
247
+ # 第二步:為每個子問題生成假設性文檔並檢索
248
+ logger.info(f"📚 為每個子問題生成假設性文檔並檢索...")
249
+ unique_docs = {}
250
+ hypothetical_docs = {}
251
+
252
+ if self.enable_parallel and len(sub_queries) > 1:
253
+ # 並行處理
254
+ logger.info(f"🔄 並行處理 {len(sub_queries)} 個子問題...")
255
+ with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor:
256
+ future_to_query = {
257
+ executor.submit(self._process_subquery_with_hyde, sq, metadata_filter): sq
258
+ for sq in sub_queries
259
+ }
260
+
261
+ for future in as_completed(future_to_query):
262
+ sub_query = future_to_query[future]
263
+ try:
264
+ results, hypo_doc = future.result()
265
+ hypothetical_docs[sub_query] = hypo_doc
266
+
267
+ logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(results)} 個結果")
268
+
269
+ for doc in results:
270
+ doc_id = self._get_doc_id(doc)
271
+ if doc_id not in unique_docs:
272
+ unique_docs[doc_id] = doc
273
+ else:
274
+ # 保留分數更高的
275
+ existing_score = unique_docs[doc_id].get('score', 0)
276
+ new_score = doc.get('score', 0)
277
+ if new_score > existing_score:
278
+ unique_docs[doc_id] = doc
279
+ except Exception as e:
280
+ logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
281
+ else:
282
+ # 串行處理
283
+ logger.info(f"🔄 串行處理 {len(sub_queries)} 個子問題...")
284
+ for sub_query in sub_queries:
285
+ results, hypo_doc = self._process_subquery_with_hyde(sub_query, metadata_filter)
286
+ hypothetical_docs[sub_query] = hypo_doc
287
+
288
+ logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(results)} 個結果")
289
+
290
+ for doc in results:
291
+ doc_id = self._get_doc_id(doc)
292
+ if doc_id not in unique_docs:
293
+ unique_docs[doc_id] = doc
294
+ else:
295
+ existing_score = unique_docs[doc_id].get('score', 0)
296
+ new_score = doc.get('score', 0)
297
+ if new_score > existing_score:
298
+ unique_docs[doc_id] = doc
299
+
300
+ # 第三步:排序並返回前 top_k
301
+ result_list = list(unique_docs.values())
302
+ result_list.sort(key=lambda x: x.get('score', 0), reverse=True)
303
+ final_results = result_list[:top_k]
304
+
305
+ elapsed_time = time.time() - start_time
306
+ logger.info(f"✅ 找到 {len(final_results)} 個唯一文檔(去重後,總共 {len(result_list)} 個)")
307
+
308
+ return {
309
+ "results": final_results,
310
+ "total_docs_found": len(result_list),
311
+ "sub_queries": sub_queries if return_sub_queries else None,
312
+ "hypothetical_documents": hypothetical_docs if return_hypothetical else None,
313
+ "elapsed_time": elapsed_time
314
+ }
315
+
316
+ def generate_answer(
317
+ self,
318
+ question: str,
319
+ formatter: PromptFormatter,
320
+ top_k: int = 5,
321
+ metadata_filter: Optional[Dict] = None,
322
+ document_type: str = "general",
323
+ return_sub_queries: bool = False,
324
+ return_hypothetical: bool = False
325
+ ) -> Dict:
326
+ """
327
+ 完整的融合 RAG 流程:檢索 + 生成答案
328
+
329
+ Args:
330
+ question: 原始問題
331
+ formatter: Prompt 格式化器
332
+ top_k: 用於生成答案的文檔數量
333
+ metadata_filter: 可選的 metadata 過濾條件
334
+ document_type: 文檔類型 ("paper", "cv", "general")
335
+ return_sub_queries: 是否返回子問題列表
336
+ return_hypothetical: 是否返回假設性文檔字典
337
+
338
+ Returns:
339
+ 包含檢索結果、生成的答案和統計資訊的字典
340
+ """
341
+ start_time = time.time()
342
+
343
+ # 檢索
344
+ retrieval_result = self.query(
345
+ question=question,
346
+ top_k=top_k,
347
+ metadata_filter=metadata_filter,
348
+ return_sub_queries=return_sub_queries,
349
+ return_hypothetical=return_hypothetical
350
+ )
351
+
352
+ if not retrieval_result["results"]:
353
+ return {
354
+ **retrieval_result,
355
+ "answer": "抱歉,未找到相關文檔來回答此問題。",
356
+ "formatted_context": None,
357
+ "answer_time": 0.0,
358
+ "total_time": retrieval_result["elapsed_time"]
359
+ }
360
+
361
+ # 格式化上下文
362
+ formatted_context = formatter.format_context(
363
+ retrieval_result["results"],
364
+ document_type=document_type
365
+ )
366
+
367
+ # 創建 prompt(使用原始問題)
368
+ prompt = formatter.create_prompt(
369
+ question,
370
+ formatted_context,
371
+ document_type=document_type
372
+ )
373
+
374
+ # 生成回答
375
+ logger.info("🤖 生成回答中...")
376
+ answer_start = time.time()
377
+ try:
378
+ answer = self.llm.generate(
379
+ prompt=prompt,
380
+ temperature=0.7,
381
+ max_tokens=2048
382
+ )
383
+ answer_time = time.time() - answer_start
384
+ logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
385
+ except Exception as e:
386
+ logger.error(f"❌ 生成回答時出錯: {e}")
387
+ answer = f"生成回答時出錯: {e}"
388
+ answer_time = time.time() - answer_start
389
+
390
+ total_time = time.time() - start_time
391
+
392
+ return {
393
+ **retrieval_result,
394
+ "answer": answer,
395
+ "formatted_context": formatted_context,
396
+ "answer_time": answer_time,
397
+ "total_time": total_time
398
+ }
399
+
src/hyde_rag.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HyDE (Hypothetical Document Embeddings) RAG:使用假設性文檔改善檢索
3
+ """
4
+ from typing import List, Dict, Optional
5
+ from .retrievers.reranker import RAGPipeline
6
+ from .retrievers.vector_retriever import VectorRetriever
7
+ from .prompt_formatter import PromptFormatter
8
+ from .llm_integration import OllamaLLM
9
+ import time
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class HyDERAG:
16
+ """使用 HyDE (Hypothetical Document Embeddings) 的 RAG 系統"""
17
+
18
+ def __init__(
19
+ self,
20
+ rag_pipeline: RAGPipeline,
21
+ vector_retriever: VectorRetriever,
22
+ llm: OllamaLLM,
23
+ hypothetical_length: int = 200,
24
+ temperature: float = 0.7
25
+ ):
26
+ """
27
+ 初始化 HyDE RAG
28
+
29
+ Args:
30
+ rag_pipeline: RAG 管線實例(用於最終答案生成)
31
+ vector_retriever: 向量檢索器(用於基於假設性文檔的檢索)
32
+ llm: LLM 實例(用於生成假設性文檔)
33
+ hypothetical_length: 假設性文檔的目標長度(字符數)
34
+ temperature: 生成假設性文檔時的溫度參數(建議 0.7,以獲得更多專業術語)
35
+ """
36
+ self.rag_pipeline = rag_pipeline
37
+ self.vector_retriever = vector_retriever
38
+ self.llm = llm
39
+ self.hypothetical_length = hypothetical_length
40
+ self.temperature = temperature
41
+
42
+ def _generate_hypothetical_document(self, question: str) -> str:
43
+ """
44
+ 生成假設性文檔(Hypothetical Document)
45
+
46
+ Args:
47
+ question: 用戶問題
48
+
49
+ Returns:
50
+ 假設性文檔文本
51
+ """
52
+ # 檢測語言
53
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
54
+
55
+ if is_chinese:
56
+ prompt = f"""請針對以下問題,寫出一段約 {self.hypothetical_length} 字的專業技術檔案內容。
57
+ 這段內容應包含該領域常見的專業術語與原理說明,以便用於後續的語義檢索。
58
+ 請使用專業的術語和概念,即使你對某些細節不確定,也要包含相關的專業詞彙。
59
+
60
+ 問題: {question}
61
+
62
+ 專業技術內容:"""
63
+ else:
64
+ prompt = f"""Please write a professional technical document of approximately {self.hypothetical_length} words in response to the following question.
65
+ This content should include common professional terminology and principle explanations in this field, to be used for subsequent semantic retrieval.
66
+ Please use professional terms and concepts, and include relevant professional vocabulary even if you are uncertain about some details.
67
+
68
+ Question: {question}
69
+
70
+ Professional technical content:"""
71
+
72
+ try:
73
+ hypothetical_doc = self.llm.generate(
74
+ prompt=prompt,
75
+ temperature=self.temperature, # 較高的溫度以獲得更多專業術語
76
+ max_tokens=500
77
+ )
78
+
79
+ # 清理輸出
80
+ hypothetical_doc = hypothetical_doc.strip()
81
+
82
+ if not hypothetical_doc:
83
+ logger.warning("⚠️ 生成的假設性文檔為空,使用原始問題")
84
+ return question
85
+
86
+ logger.info(f"✅ 生成假設性文檔(長度: {len(hypothetical_doc)} 字符)")
87
+ return hypothetical_doc
88
+
89
+ except Exception as e:
90
+ logger.error(f"⚠️ 生成假設性文檔時出錯: {e}")
91
+ # 回退到使用原始問題
92
+ return question
93
+
94
+ def query(
95
+ self,
96
+ question: str,
97
+ top_k: int = 5,
98
+ metadata_filter: Optional[Dict] = None,
99
+ return_hypothetical: bool = False
100
+ ) -> Dict:
101
+ """
102
+ 執行 HyDE 檢索(不生成答案)
103
+
104
+ Args:
105
+ question: 原始問題
106
+ top_k: 返回前 k 個結果
107
+ metadata_filter: 可選的 metadata 過濾條件
108
+ return_hypothetical: 是否在結果中包含假設性文檔
109
+
110
+ Returns:
111
+ 包含檢索結果和統計資訊的字典
112
+ """
113
+ start_time = time.time()
114
+
115
+ # 第一步:生成假設性文檔
116
+ logger.info(f"🔍 生成假設性文檔: '{question}'")
117
+ hypothetical_doc = self._generate_hypothetical_document(question)
118
+
119
+ # 第二步:使用假設性文檔進行檢索
120
+ logger.info(f"📚 使用假設性文檔進行檢索...")
121
+ results = self.vector_retriever.retrieve(
122
+ query=hypothetical_doc, # 使用假設性文檔而不是原始問題
123
+ top_k=top_k,
124
+ metadata_filter=metadata_filter
125
+ )
126
+
127
+ elapsed_time = time.time() - start_time
128
+ logger.info(f"✅ 找到 {len(results)} 個結果(耗時: {elapsed_time:.2f}s)")
129
+
130
+ result = {
131
+ "results": results,
132
+ "total_docs_found": len(results),
133
+ "hypothetical_document": hypothetical_doc if return_hypothetical else None,
134
+ "elapsed_time": elapsed_time
135
+ }
136
+
137
+ return result
138
+
139
+ def generate_answer(
140
+ self,
141
+ question: str,
142
+ formatter: PromptFormatter,
143
+ top_k: int = 5,
144
+ metadata_filter: Optional[Dict] = None,
145
+ document_type: str = "general",
146
+ return_hypothetical: bool = False
147
+ ) -> Dict:
148
+ """
149
+ 完整的 HyDE RAG 流程:生成假設性文檔 -> 檢索 -> 生成答案
150
+
151
+ Args:
152
+ question: 原始問題
153
+ formatter: Prompt 格式化器
154
+ top_k: 用於生成答案的文檔數量
155
+ metadata_filter: 可選的 metadata 過濾條件
156
+ document_type: 文檔類型 ("paper", "cv", "general")
157
+ return_hypothetical: 是否在結果中包含假設性文檔
158
+
159
+ Returns:
160
+ 包含檢索結果、生成的答案和統計資訊的字典
161
+ """
162
+ start_time = time.time()
163
+
164
+ # 第一步:生成假設性文檔
165
+ logger.info(f"🔍 生成假設性文檔: '{question}'")
166
+ hypothetical_start = time.time()
167
+ hypothetical_doc = self._generate_hypothetical_document(question)
168
+ hypothetical_time = time.time() - hypothetical_start
169
+
170
+ # 第二步:使用假設性文檔進行檢索
171
+ logger.info(f"📚 使用假設性文檔進行檢索...")
172
+ retrieval_start = time.time()
173
+ results = self.vector_retriever.retrieve(
174
+ query=hypothetical_doc, # 使用假設性文檔而不是原始問題
175
+ top_k=top_k,
176
+ metadata_filter=metadata_filter
177
+ )
178
+ retrieval_time = time.time() - retrieval_start
179
+
180
+ if not results:
181
+ return {
182
+ "results": [],
183
+ "total_docs_found": 0,
184
+ "hypothetical_document": hypothetical_doc if return_hypothetical else None,
185
+ "elapsed_time": retrieval_time + hypothetical_time,
186
+ "answer": "抱歉,未找到相關文檔來回答此問題。",
187
+ "formatted_context": None,
188
+ "answer_time": 0.0,
189
+ "total_time": retrieval_time + hypothetical_time
190
+ }
191
+
192
+ # 第三步:格式化上下文
193
+ formatted_context = formatter.format_context(
194
+ results,
195
+ document_type=document_type
196
+ )
197
+
198
+ # 第四步:創建 prompt(使用原始問題,而不是假設性文檔)
199
+ prompt = formatter.create_prompt(
200
+ question, # 使用原始問題生成答案
201
+ formatted_context,
202
+ document_type=document_type
203
+ )
204
+
205
+ # 第五步:生成回答
206
+ logger.info("🤖 生成回答中...")
207
+ answer_start = time.time()
208
+ try:
209
+ answer = self.llm.generate(
210
+ prompt=prompt,
211
+ temperature=0.7,
212
+ max_tokens=2048
213
+ )
214
+ answer_time = time.time() - answer_start
215
+ logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
216
+ except Exception as e:
217
+ logger.error(f"❌ 生成回答時出錯: {e}")
218
+ answer = f"生成回答時出錯: {e}"
219
+ answer_time = time.time() - answer_start
220
+
221
+ total_time = time.time() - start_time
222
+
223
+ return {
224
+ "results": results,
225
+ "total_docs_found": len(results),
226
+ "hypothetical_document": hypothetical_doc if return_hypothetical else None,
227
+ "elapsed_time": retrieval_time + hypothetical_time,
228
+ "hypothetical_time": hypothetical_time,
229
+ "retrieval_time": retrieval_time,
230
+ "answer": answer,
231
+ "formatted_context": formatted_context,
232
+ "answer_time": answer_time,
233
+ "total_time": total_time
234
+ }
235
+
src/llm_integration.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM 集成模組:使用 Ollama 進行本地 LLM 推理
3
+ """
4
+ from typing import Optional, Dict, List
5
+ import logging
6
+ import requests
7
+ import json
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class OllamaLLM:
13
+ """使用 Ollama 進行本地 LLM 推理"""
14
+
15
+ # 適合 16GB MacBook Air 的模型推薦
16
+ RECOMMENDED_MODELS = {
17
+ "deepseek-r1:7b": {
18
+ "name": "deepseek-r1:7b",
19
+ "description": "DeepSeek R1 7B - 大模型,高質量",
20
+ "memory_required": "~8GB",
21
+ "quality": "優秀"
22
+ },
23
+ "llama3.2:3b": {
24
+ "name": "llama3.2:3b",
25
+ "description": "Meta Llama 3.2 3B - 輕量級,適合 16GB 內存",
26
+ "memory_required": "~4GB",
27
+ "quality": "良好"
28
+ },
29
+ "llama3.2:1b": {
30
+ "name": "llama3.2:1b",
31
+ "description": "Meta Llama 3.2 1B - 極輕量級,快速響應",
32
+ "memory_required": "~2GB",
33
+ "quality": "基礎"
34
+ },
35
+ "phi3:mini": {
36
+ "name": "phi3:mini",
37
+ "description": "Microsoft Phi-3 Mini - 小模型,高質量",
38
+ "memory_required": "~3GB",
39
+ "quality": "良好"
40
+ },
41
+ "gemma:2b": {
42
+ "name": "gemma:2b",
43
+ "description": "Google Gemma 2B - 輕量級,開源",
44
+ "memory_required": "~3GB",
45
+ "quality": "良好"
46
+ },
47
+ "mistral:7b": {
48
+ "name": "mistral:7b",
49
+ "description": "Mistral 7B - 較大但質量高(如果內存足夠)",
50
+ "memory_required": "~8GB",
51
+ "quality": "優秀"
52
+ }
53
+ }
54
+
55
+ def __init__(
56
+ self,
57
+ model_name: str = "llama3.2:3b",
58
+ base_url: str = "http://localhost:11434",
59
+ timeout: int = 120
60
+ ):
61
+ """
62
+ 初始化 Ollama LLM
63
+
64
+ Args:
65
+ model_name: Ollama 模型名稱(預設: llama3.2:3b)
66
+ base_url: Ollama API 基礎 URL
67
+ timeout: 請求超時時間(秒)
68
+ """
69
+ self.model_name = model_name
70
+ self.base_url = base_url.rstrip('/')
71
+ self.timeout = timeout
72
+ self.api_url = f"{self.base_url}/api"
73
+
74
+ # 檢查模型是否在推薦列表中
75
+ if model_name not in self.RECOMMENDED_MODELS:
76
+ logger.warning(
77
+ f"⚠️ 模型 '{model_name}' 不在推薦列表中。"
78
+ f"推薦的模型: {', '.join(self.RECOMMENDED_MODELS.keys())}"
79
+ )
80
+
81
+ logger.info(f"✅ Ollama LLM 初始化完成 (模型: {model_name})")
82
+
83
+ def _check_ollama_connection(self) -> bool:
84
+ """
85
+ 檢查 Ollama 服務是否可用
86
+
87
+ Returns:
88
+ 是否連接成功
89
+ """
90
+ try:
91
+ response = requests.get(f"{self.base_url}/api/tags", timeout=5)
92
+ return response.status_code == 200
93
+ except Exception as e:
94
+ logger.error(f"❌ 無法連接到 Ollama: {e}")
95
+ logger.error(f" 請確保 Ollama 正在運行: ollama serve")
96
+ return False
97
+
98
+ def _check_model_available(self) -> bool:
99
+ """
100
+ 檢查模型是否已下載
101
+
102
+ Returns:
103
+ 模型是否可用
104
+ """
105
+ try:
106
+ response = requests.get(f"{self.base_url}/api/tags", timeout=5)
107
+ if response.status_code == 200:
108
+ models = response.json().get('models', [])
109
+ model_names = [m.get('name', '') for m in models]
110
+ return any(self.model_name in name for name in model_names)
111
+ return False
112
+ except Exception as e:
113
+ logger.error(f"❌ 檢查模型時出錯: {e}")
114
+ return False
115
+
116
+ def generate(
117
+ self,
118
+ prompt: str,
119
+ temperature: float = 0.7,
120
+ max_tokens: Optional[int] = None,
121
+ stream: bool = False
122
+ ) -> str:
123
+ """
124
+ 生成回答
125
+
126
+ Args:
127
+ prompt: 輸入 prompt
128
+ temperature: 溫度參數(0.0-1.0),控制隨機性
129
+ max_tokens: 最大生成 token 數(None 表示使用模型預設)
130
+ stream: 是否使用流式輸出
131
+
132
+ Returns:
133
+ 生成的回答
134
+ """
135
+ # 檢查連接
136
+ if not self._check_ollama_connection():
137
+ raise ConnectionError(
138
+ f"無法連接到 Ollama 服務 ({self.base_url})\n"
139
+ f"請確保 Ollama 正在運行:\n"
140
+ f" 1. 安裝 Ollama: https://ollama.ai\n"
141
+ f" 2. 啟動服務: ollama serve\n"
142
+ f" 3. 下載模型: ollama pull {self.model_name}"
143
+ )
144
+
145
+ # 檢查模型
146
+ if not self._check_model_available():
147
+ logger.warning(
148
+ f"⚠️ 模型 '{self.model_name}' 可能未下載。"
149
+ f"請運行: ollama pull {self.model_name}"
150
+ )
151
+
152
+ # 準備請求參數
153
+ payload = {
154
+ "model": self.model_name,
155
+ "prompt": prompt,
156
+ "stream": stream,
157
+ "options": {
158
+ "temperature": temperature,
159
+ }
160
+ }
161
+
162
+ if max_tokens:
163
+ payload["options"]["num_predict"] = max_tokens
164
+
165
+ try:
166
+ # 發送請求
167
+ response = requests.post(
168
+ f"{self.api_url}/generate",
169
+ json=payload,
170
+ timeout=self.timeout,
171
+ stream=stream
172
+ )
173
+
174
+ if response.status_code != 200:
175
+ error_msg = response.text
176
+ raise RuntimeError(f"Ollama API 錯誤: {error_msg}")
177
+
178
+ if stream:
179
+ # 流式處理
180
+ full_response = ""
181
+ for line in response.iter_lines():
182
+ if line:
183
+ try:
184
+ data = json.loads(line)
185
+ if 'response' in data:
186
+ chunk = data['response']
187
+ full_response += chunk
188
+ print(chunk, end='', flush=True)
189
+ if data.get('done', False):
190
+ break
191
+ except json.JSONDecodeError:
192
+ continue
193
+ print() # 換行
194
+ return full_response
195
+ else:
196
+ # 非流式處理
197
+ data = response.json()
198
+ return data.get('response', '')
199
+
200
+ except requests.exceptions.Timeout:
201
+ raise TimeoutError(
202
+ f"請求超時({self.timeout}秒)。"
203
+ f"可以嘗試增加 timeout 或使用更小的模型。"
204
+ )
205
+ except requests.exceptions.ConnectionError:
206
+ raise ConnectionError(
207
+ f"無法連接到 Ollama 服務。"
208
+ f"請確保 Ollama 正在運行:ollama serve"
209
+ )
210
+ except Exception as e:
211
+ logger.error(f"❌ 生成回答時出錯: {e}")
212
+ raise
213
+
214
+ def list_available_models(self) -> List[str]:
215
+ """
216
+ 列出本地可用的模型
217
+
218
+ Returns:
219
+ 可用模型名稱列表
220
+ """
221
+ try:
222
+ response = requests.get(f"{self.base_url}/api/tags", timeout=5)
223
+ if response.status_code == 200:
224
+ models = response.json().get('models', [])
225
+ return [m.get('name', '') for m in models]
226
+ return []
227
+ except Exception as e:
228
+ logger.error(f"❌ 獲取模型列表時出錯: {e}")
229
+ return []
230
+
231
+ @classmethod
232
+ def print_recommended_models(cls):
233
+ """打印推薦的模型列表"""
234
+ print("\n" + "="*60)
235
+ print("適合 16GB MacBook Air 的 Ollama 模型推薦")
236
+ print("="*60)
237
+ print()
238
+
239
+ for model_key, info in cls.RECOMMENDED_MODELS.items():
240
+ print(f"📦 {info['name']}")
241
+ print(f" 描述: {info['description']}")
242
+ print(f" 內存需求: {info['memory_required']}")
243
+ print(f" 質量: {info['quality']}")
244
+ print(f" 下載命令: ollama pull {info['name']}")
245
+ print()
246
+
src/prompt_formatter.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prompt 格式化模組:將檢索結果格式化為 LLM 可讀的上下文
3
+ """
4
+ from typing import List, Dict, Optional
5
+ import re
6
+
7
+
8
+ class PromptFormatter:
9
+ """格式化檢索結果供 LLM 使用"""
10
+
11
+ def __init__(
12
+ self,
13
+ include_metadata: bool = True,
14
+ format_style: str = "detailed",
15
+ max_context_length: Optional[int] = None,
16
+ auto_detect_language: bool = True
17
+ ):
18
+ """
19
+ 初始化 Prompt 格式化器
20
+
21
+ Args:
22
+ include_metadata: 是否包含來源資訊
23
+ format_style: 格式風格 ("detailed", "simple", "minimal")
24
+ max_context_length: 最大上下文長度(字符數),None 表示不限制
25
+ auto_detect_language: 是否自動檢測語言並相應調整回答語言
26
+ """
27
+ self.include_metadata = include_metadata
28
+ self.format_style = format_style
29
+ self.max_context_length = max_context_length
30
+ self.auto_detect_language = auto_detect_language
31
+
32
+ @staticmethod
33
+ def detect_language(text: str) -> str:
34
+ """
35
+ 檢測文本的主要語言
36
+
37
+ Args:
38
+ text: 輸入文本
39
+
40
+ Returns:
41
+ "zh" 表示中文,"en" 表示英文
42
+ """
43
+ # 檢查是否包含中文字符(CJK 統一表意文字範圍)
44
+ chinese_pattern = re.compile(r'[\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff]')
45
+ chinese_chars = len(chinese_pattern.findall(text))
46
+
47
+ # 計算中文字符比例
48
+ total_chars = len([c for c in text if c.isalnum() or c.isspace()])
49
+
50
+ if total_chars == 0:
51
+ return "en" # 預設英文
52
+
53
+ chinese_ratio = chinese_chars / total_chars if total_chars > 0 else 0
54
+
55
+ # 如果中文字符比例超過 20%,認為是中文
56
+ if chinese_ratio > 0.2:
57
+ return "zh"
58
+ else:
59
+ return "en"
60
+
61
+ def get_system_prompt(self, language: str = "zh", document_type: str = "general") -> str:
62
+ """
63
+ 根據語言和文檔類型獲取系統提示詞
64
+
65
+ Args:
66
+ language: 語言代碼 ("zh" 或 "en")
67
+ document_type: 文檔類型 ("paper", "cv", "general")
68
+ "paper": 學術論文
69
+ "cv": 履歷/履歷
70
+ "general": 通用文檔(預設)
71
+
72
+ Returns:
73
+ 系統提示詞字符串
74
+ """
75
+ if language == "zh":
76
+ if document_type == "paper":
77
+ return (
78
+ "你是一個專業的 AI 研究助手,專門回答關於機器學習、"
79
+ "深度學習和自然語言處理的問題。\n\n"
80
+ "請基於以下提供的學術論文片段來回答用戶的問題。"
81
+ "每個片段都標註了來源論文的資訊。\n\n"
82
+ "回答要求:\n"
83
+ "1. 基於提供的上下文回答問題\n"
84
+ "2. 如果上下文不足以回答,請明確說明\n"
85
+ "3. 在回答中引用具體的論文來源(使用 arXiv ID)\n"
86
+ "4. 如果不同論文有不同觀點,請分別說明\n"
87
+ "5. 保持回答簡潔、準確、專業\n"
88
+ "6. **重要:請使用與用戶問題相同的語言回答**\n"
89
+ )
90
+ elif document_type == "cv":
91
+ return (
92
+ "你是一個專業的 AI 助手,專門幫助分析和介紹簡歷(CV)內容。\n\n"
93
+ "請基於以下提供的文檔片段來回答用戶的問題。"
94
+ "這些片段來自一份簡歷或履歷表。\n\n"
95
+ "回答要求:\n"
96
+ "1. 基於提供的上下文回答問題\n"
97
+ "2. 如果上下文不足以回答,請明確說明\n"
98
+ "3. 在回答中引用具體的文檔內容\n"
99
+ "4. 保持回答簡潔、準確、專業\n"
100
+ "5. **重要:請使用與用戶問題相同的語言回答**\n"
101
+ "6. **請理解:這些片段就是簡歷的內容,請直接基於這些內容回答問題**\n"
102
+ )
103
+ else: # general
104
+ return (
105
+ "你是一個專業的 AI 助手。\n\n"
106
+ "請基於以下提供的文檔片段來回答用戶的問題。"
107
+ "每個片段都標註了來源資訊。\n\n"
108
+ "回答要求:\n"
109
+ "1. 基於提供的上下文回答問題\n"
110
+ "2. 如果上下文不足以回答,請明確說明\n"
111
+ "3. 在回答中引用具體的文檔內容\n"
112
+ "4. 保持回答簡潔、準確、專業\n"
113
+ "5. **重要:請使用與用戶問題相同的語言回答**\n"
114
+ )
115
+ else: # English
116
+ if document_type == "paper":
117
+ return (
118
+ "You are a professional AI research assistant specializing in "
119
+ "machine learning, deep learning, and natural language processing.\n\n"
120
+ "Please answer the user's question based on the provided academic paper excerpts. "
121
+ "Each excerpt is labeled with source paper information.\n\n"
122
+ "Answer requirements:\n"
123
+ "1. Answer the question based on the provided context\n"
124
+ "2. If the context is insufficient, clearly state so\n"
125
+ "3. Cite specific paper sources in your answer (using arXiv ID)\n"
126
+ "4. If different papers have different viewpoints, explain them separately\n"
127
+ "5. Keep answers concise, accurate, and professional\n"
128
+ "6. **Important: Please answer in the same language as the user's question**\n"
129
+ )
130
+ elif document_type == "cv":
131
+ return (
132
+ "You are a professional AI assistant specializing in analyzing and introducing CV (Curriculum Vitae) content.\n\n"
133
+ "Please answer the user's question based on the provided document excerpts. "
134
+ "These excerpts are from a CV or resume.\n\n"
135
+ "Answer requirements:\n"
136
+ "1. Answer the question based on the provided context\n"
137
+ "2. If the context is insufficient, clearly state so\n"
138
+ "3. Cite specific document content in your answer\n"
139
+ "4. Keep answers concise, accurate, and professional\n"
140
+ "5. **Important: Please answer in the same language as the user's question**\n"
141
+ "6. **Please understand: These excerpts ARE the CV content. Please answer directly based on this content.**\n"
142
+ )
143
+ else: # general
144
+ return (
145
+ "You are a professional AI assistant.\n\n"
146
+ "Please answer the user's question based on the provided document excerpts. "
147
+ "Each excerpt is labeled with source information.\n\n"
148
+ "Answer requirements:\n"
149
+ "1. Answer the question based on the provided context\n"
150
+ "2. If the context is insufficient, clearly state so\n"
151
+ "3. Cite specific document content in your answer\n"
152
+ "4. Keep answers concise, accurate, and professional\n"
153
+ "5. **Important: Please answer in the same language as the user's question**\n"
154
+ )
155
+
156
+ def format_context(
157
+ self,
158
+ results: List[Dict],
159
+ include_metadata: Optional[bool] = None,
160
+ format_style: Optional[str] = None,
161
+ document_type: str = "general"
162
+ ) -> str:
163
+ """
164
+ 格式化檢索結果為 LLM 可讀的上下文
165
+
166
+ Args:
167
+ results: 檢索結果列表
168
+ include_metadata: 是否包含來源資訊(覆蓋初始化參數)
169
+ format_style: 格式風格(覆蓋初始化參數)
170
+ document_type: 文檔類型 ("paper", "cv", "general"),用於調整格式
171
+
172
+ Returns:
173
+ 格式化後的上下文字符串
174
+ """
175
+ if include_metadata is None:
176
+ include_metadata = self.include_metadata
177
+ if format_style is None:
178
+ format_style = self.format_style
179
+
180
+ if not results:
181
+ # 根據格式風格選擇語言
182
+ if format_style == "detailed" or format_style == "simple":
183
+ return "(未找到相關文檔片段)"
184
+ else:
185
+ return "(No relevant excerpts found)"
186
+
187
+ formatted_parts = []
188
+
189
+ for i, result in enumerate(results, 1):
190
+ content = result.get("content", "")
191
+ metadata = result.get("metadata", {})
192
+
193
+ if not include_metadata:
194
+ # 不包含來源資訊,直接使用內容
195
+ formatted_parts.append(f"{content}\n")
196
+ elif format_style == "detailed":
197
+ # 詳細格式:根據文檔類型調整顯示資訊
198
+ if document_type == "cv":
199
+ # CV 格式:顯示檔案名和路徑
200
+ source_info = (
201
+ f"[來源 {i}]\n"
202
+ f"檔案標題: {metadata.get('title', 'N/A')}\n"
203
+ )
204
+ if 'file_path' in metadata:
205
+ source_info += f"檔案路徑: {metadata.get('file_path', 'N/A')}\n"
206
+ if 'file_type' in metadata:
207
+ source_info += f"檔案類型: {metadata.get('file_type', 'N/A')}\n"
208
+ elif document_type == "paper":
209
+ # 論文格式:顯示論文資訊
210
+ authors = metadata.get('authors', [])
211
+ if isinstance(authors, str):
212
+ authors_str = authors
213
+ elif isinstance(authors, list):
214
+ authors_str = ', '.join(authors[:3]) # 最多顯示 3 個作者
215
+ if len(authors) > 3:
216
+ authors_str += f" 等 {len(authors)} 位作者"
217
+ else:
218
+ authors_str = 'N/A'
219
+
220
+ source_info = (
221
+ f"[來源 {i}]\n"
222
+ f"論文標題: {metadata.get('title', 'N/A')}\n"
223
+ f"arXiv ID: {metadata.get('arxiv_id', 'N/A')}\n"
224
+ f"作者: {authors_str}\n"
225
+ f"發布日期: {metadata.get('published', 'N/A')}\n"
226
+ )
227
+ else:
228
+ # 通用格式:顯示可用的資訊
229
+ source_info = f"[來源 {i}]\n"
230
+ if 'title' in metadata:
231
+ source_info += f"標題: {metadata.get('title', 'N/A')}\n"
232
+ if 'file_path' in metadata:
233
+ source_info += f"檔案: {metadata.get('file_path', 'N/A')}\n"
234
+ if 'arxiv_id' in metadata:
235
+ source_info += f"arXiv ID: {metadata.get('arxiv_id', 'N/A')}\n"
236
+
237
+ # 添加相關性分數(如果有的話)
238
+ rerank_score = result.get('rerank_score')
239
+ hybrid_score = result.get('hybrid_score')
240
+ if rerank_score is not None:
241
+ source_info += f"相關性分數: {rerank_score:.4f}\n"
242
+ elif hybrid_score is not None:
243
+ source_info += f"相關性分數: {hybrid_score:.4f}\n"
244
+
245
+ source_info += f"---\n{content}\n"
246
+ formatted_parts.append(source_info)
247
+
248
+ elif format_style == "simple":
249
+ # 簡單格式:只包含關鍵資訊
250
+ title = metadata.get('title', 'N/A')
251
+ if document_type == "paper" and 'arxiv_id' in metadata:
252
+ arxiv_id = metadata.get('arxiv_id', 'N/A')
253
+ source_info = (
254
+ f"[來源 {i}: {title} "
255
+ f"(arXiv:{arxiv_id})]\n"
256
+ f"{content}\n"
257
+ )
258
+ elif document_type == "cv" and 'file_path' in metadata:
259
+ file_path = metadata.get('file_path', 'N/A')
260
+ source_info = (
261
+ f"[來源 {i}: {title} "
262
+ f"({file_path})]\n"
263
+ f"{content}\n"
264
+ )
265
+ else:
266
+ source_info = (
267
+ f"[來源 {i}: {title}]\n"
268
+ f"{content}\n"
269
+ )
270
+ formatted_parts.append(source_info)
271
+ else: # minimal
272
+ # 最小格式:只標註來源
273
+ if document_type == "paper" and 'arxiv_id' in metadata:
274
+ arxiv_id = metadata.get('arxiv_id', 'N/A')
275
+ source_info = (
276
+ f"[arXiv:{arxiv_id}]\n"
277
+ f"{content}\n"
278
+ )
279
+ elif 'title' in metadata:
280
+ title = metadata.get('title', 'N/A')
281
+ source_info = (
282
+ f"[{title}]\n"
283
+ f"{content}\n"
284
+ )
285
+ else:
286
+ source_info = (
287
+ f"[來源 {i}]\n"
288
+ f"{content}\n"
289
+ )
290
+ formatted_parts.append(source_info)
291
+
292
+ formatted_text = "\n" + "="*60 + "\n".join(formatted_parts)
293
+
294
+ # 如果設置了最大長度,進行截斷
295
+ if self.max_context_length and len(formatted_text) > self.max_context_length:
296
+ # 從後往前截斷,保留格式
297
+ formatted_text = formatted_text[:self.max_context_length]
298
+ # 確保最後一個來源資訊完整
299
+ last_separator = formatted_text.rfind("="*60)
300
+ if last_separator > 0:
301
+ formatted_text = formatted_text[:last_separator] + "\n(內容已截斷...)"
302
+
303
+ return formatted_text
304
+
305
+ def create_prompt(
306
+ self,
307
+ query: str,
308
+ context: str,
309
+ system_prompt: Optional[str] = None,
310
+ document_type: str = "general"
311
+ ) -> str:
312
+ """
313
+ 創建完整的 LLM prompt
314
+
315
+ Args:
316
+ query: 用戶查詢
317
+ context: 格式化後的上下文
318
+ system_prompt: 可選的系統提示詞(如果為 None,會根據語言和文檔類型自動選擇)
319
+ document_type: 文檔類型 ("paper", "cv", "general")
320
+
321
+ Returns:
322
+ 完整的 prompt 字符串
323
+ """
324
+ # 自動檢測語言並選擇相應的系統提示詞
325
+ if system_prompt is None and self.auto_detect_language:
326
+ detected_language = self.detect_language(query)
327
+ system_prompt = self.get_system_prompt(detected_language, document_type)
328
+ elif system_prompt is None:
329
+ # 如果禁用自動檢測,使用中文作為預設
330
+ system_prompt = self.get_system_prompt("zh", document_type)
331
+
332
+ # 根據檢測到的語言選擇提示詞格式
333
+ detected_language = self.detect_language(query) if self.auto_detect_language else "zh"
334
+
335
+ # 根據文檔類型選擇不同的提示詞結尾
336
+ if document_type == "paper":
337
+ if detected_language == "zh":
338
+ ending = "## 請基於上述文獻片段回答問題,並在回答中引用具體的論文來源。"
339
+ else:
340
+ ending = "## Please answer the question based on the above document excerpts and cite specific paper sources in your answer."
341
+ else:
342
+ if detected_language == "zh":
343
+ ending = "## 請基於上述文檔片段回答問題,並在回答中引用具體的文檔內容。"
344
+ else:
345
+ ending = "## Please answer the question based on the above document excerpts and cite specific document content in your answer."
346
+
347
+ if detected_language == "zh":
348
+ prompt = f"""{system_prompt}
349
+
350
+ ## 相關文檔片段:
351
+
352
+ {context}
353
+
354
+ ## 用戶問題:
355
+
356
+ {query}
357
+
358
+ {ending}"""
359
+ else: # English
360
+ prompt = f"""{system_prompt}
361
+
362
+ ## Relevant Document Excerpts:
363
+
364
+ {context}
365
+
366
+ ## User Question:
367
+
368
+ {query}
369
+
370
+ {ending}"""
371
+
372
+ return prompt
373
+
374
+ def format_for_llm(
375
+ self,
376
+ query: str,
377
+ results: List[Dict],
378
+ system_prompt: Optional[str] = None,
379
+ document_type: str = "general"
380
+ ) -> str:
381
+ """
382
+ 一站式方法:格式化檢索結果並創建完整的 prompt
383
+
384
+ Args:
385
+ query: 用戶查詢
386
+ results: 檢索結果列表
387
+ system_prompt: 可選的系統提示詞
388
+ document_type: 文檔類型 ("paper", "cv", "general")
389
+
390
+ Returns:
391
+ 完整的 prompt 字符串
392
+ """
393
+ context = self.format_context(results, document_type=document_type)
394
+ return self.create_prompt(query, context, system_prompt, document_type)
395
+
src/retrievers/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 檢索器模組
3
+ """
4
+ from .base import BaseRetriever
5
+ from .bm25_retriever import BM25Retriever
6
+ from .vector_retriever import VectorRetriever
7
+ from .hybrid_search import HybridSearch
8
+ from .reranker import Reranker, RAGPipeline
9
+
10
+ __all__ = [
11
+ "BaseRetriever",
12
+ "BM25Retriever",
13
+ "VectorRetriever",
14
+ "HybridSearch",
15
+ "Reranker",
16
+ "RAGPipeline",
17
+ ]
src/retrievers/base.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 檢索器模組的抽象基類
3
+ """
4
+ from abc import ABC, abstractmethod
5
+ from typing import List, Dict, Optional
6
+
7
+
8
+ class BaseRetriever(ABC):
9
+ """檢索器的抽象基類"""
10
+
11
+ @abstractmethod
12
+ def retrieve(
13
+ self,
14
+ query: str,
15
+ top_k: int = 5,
16
+ metadata_filter: Optional[Dict] = None
17
+ ) -> List[Dict]:
18
+ """
19
+ 檢索相關文檔並返回帶有分數的結果。
20
+
21
+ Args:
22
+ query: 查詢文字
23
+ top_k: 返回前 k 個結果
24
+ metadata_filter: 可選的 metadata 過濾條件字典。
25
+ 例如: {"arxiv_id": "1234.5678"} 或 {"title": "Machine Learning"}
26
+ 支援多個條件,所有條件必須同時滿足(AND 邏輯)
27
+
28
+ Returns:
29
+ 相關文檔列表,每個文檔字典都應包含 "score" 鍵,
30
+ 且分數越高代表越相關。返回的結果會根據 metadata_filter 進行過濾。
31
+ """
32
+ pass
src/retrievers/bm25_retriever.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BM25 檢索器模組
3
+ """
4
+ from typing import List, Dict, Optional
5
+ from rank_bm25 import BM25Okapi
6
+ import re
7
+ from .base import BaseRetriever
8
+
9
+
10
+ class BM25Retriever(BaseRetriever):
11
+ """使用 BM25 演算法進行關鍵字檢索"""
12
+
13
+ def __init__(self, documents: List[Dict]):
14
+ """
15
+ 初始化 BM25 檢索器
16
+
17
+ Args:
18
+ documents: 文檔列表,每個文檔包含 "content" 和 "metadata"
19
+ """
20
+ self.documents = documents
21
+ self.texts = [doc["content"] for doc in documents]
22
+
23
+ # 對文字進行 tokenization(簡單的分詞)
24
+ tokenized_texts = [self._tokenize(text) for text in self.texts]
25
+
26
+ # 初始化 BM25
27
+ self.bm25 = BM25Okapi(tokenized_texts)
28
+
29
+ def _tokenize(self, text: str) -> List[str]:
30
+ """
31
+ 將文字轉換為 tokens(簡單的實作)
32
+
33
+ Args:
34
+ text: 輸入文字
35
+
36
+ Returns:
37
+ token 列表
38
+ """
39
+ # 轉為小寫並分割
40
+ text = text.lower()
41
+ # 使用正則表達式分割(保留字母和數字)
42
+ tokens = re.findall(r'\b\w+\b', text)
43
+ return tokens
44
+
45
+ def retrieve(
46
+ self,
47
+ query: str,
48
+ top_k: int = 5,
49
+ metadata_filter: Optional[Dict] = None
50
+ ) -> List[Dict]:
51
+ """
52
+ 檢索相關文檔,支援根據 metadata 進行過濾。
53
+
54
+ Args:
55
+ query: 查詢文字
56
+ top_k: 返回前 k 個結果
57
+ metadata_filter: 可選的 metadata 過濾條件字典。
58
+ 例如: {"arxiv_id": "1234.5678"} 只檢索特定論文的 chunks
59
+ 或 {"title": "Machine Learning"} 只檢索特定標題的論文
60
+ 支援多個條件,所有條件必須同時滿足(AND 邏輯)
61
+ 注意:BM25 的過濾是在檢索後進行的,所以可能會返回少於 top_k 的結果
62
+
63
+ Returns:
64
+ 相關文檔列表,每個包含 "content", "metadata", "score"
65
+ 結果會根據 metadata_filter 進行過濾
66
+ """
67
+ # Tokenize 查詢
68
+ tokenized_query = self._tokenize(query)
69
+
70
+ # 計算 BM25 分數
71
+ scores = self.bm25.get_scores(tokenized_query)
72
+
73
+ # 獲取所有結果並排序(先獲取更多結果以應對過濾後可能減少的情況)
74
+ # 如果沒有過濾條件,只需要 top_k 個;如果有過濾條件,需要更多候選結果
75
+ candidate_k = top_k * 3 if metadata_filter else top_k
76
+
77
+ # 獲取候選結果索引(按分數降序排列)
78
+ sorted_indices = sorted(
79
+ range(len(scores)),
80
+ key=lambda i: scores[i],
81
+ reverse=True
82
+ )[:candidate_k]
83
+
84
+ # 構建候選結果
85
+ candidate_results = []
86
+ for idx in sorted_indices:
87
+ candidate_results.append({
88
+ "content": self.documents[idx]["content"],
89
+ "metadata": self.documents[idx]["metadata"],
90
+ "score": float(scores[idx]),
91
+ })
92
+
93
+ # 如果提供了 metadata_filter,則進行過濾
94
+ if metadata_filter:
95
+ filtered_results = []
96
+ for result in candidate_results:
97
+ # 檢查該結果的 metadata 是否滿足所有過濾條件
98
+ metadata = result.get("metadata", {})
99
+ matches_all = True
100
+
101
+ for filter_key, filter_value in metadata_filter.items():
102
+ # 獲取文檔中對應的 metadata 值
103
+ doc_value = metadata.get(filter_key)
104
+
105
+ # 檢查是否匹配
106
+ # 支援精確匹配和部分匹配(如果 filter_value 是字串且 doc_value 也是字串)
107
+ if isinstance(filter_value, str) and isinstance(doc_value, str):
108
+ # 字串匹配:支援精確匹配或包含匹配
109
+ if filter_value.lower() not in doc_value.lower():
110
+ matches_all = False
111
+ break
112
+ else:
113
+ # 其他類型(數字、布林值等)使用精確匹配
114
+ if doc_value != filter_value:
115
+ matches_all = False
116
+ break
117
+
118
+ # 如果所有條件都滿足,則加入結果
119
+ if matches_all:
120
+ filtered_results.append(result)
121
+
122
+ # 返回過濾後的結果(最多 top_k 個)
123
+ return filtered_results[:top_k]
124
+ else:
125
+ # 沒有過濾條件,直接返回候選結果
126
+ return candidate_results
127
+
src/retrievers/hybrid_search.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid Search 模組:結合 BM25 和向量檢索
3
+ 支援兩種融合方法:加權求和(Weighted Sum)和倒數排名融合(RRF)
4
+ """
5
+ from typing import List, Dict, Optional, Literal
6
+ from .base import BaseRetriever
7
+ import numpy as np
8
+
9
+
10
+ class HybridSearch(BaseRetriever):
11
+ """結合稀疏和密集檢索的混合搜尋"""
12
+
13
+ def __init__(
14
+ self,
15
+ sparse_retriever: BaseRetriever,
16
+ dense_retriever: BaseRetriever,
17
+ sparse_weight: float = 0.4,
18
+ dense_weight: float = 0.6,
19
+ fusion_method: Literal["weighted_sum", "rrf"] = "rrf",
20
+ rrf_k: int = 60,
21
+ ):
22
+ """
23
+ 初始化 Hybrid Search
24
+
25
+ Args:
26
+ sparse_retriever: 稀疏檢索器 (例如 BM25)
27
+ dense_retriever: 密集檢索器 (例如向量檢索)
28
+ sparse_weight: 稀疏檢索分數的權重(僅用於 weighted_sum 方法)
29
+ dense_weight: 密集檢索分數的權重(僅用於 weighted_sum 方法)
30
+ fusion_method: 融合方法,可選 "weighted_sum" 或 "rrf"
31
+ - "weighted_sum": 加權求和,需要正規化分數並設置權重
32
+ - "rrf": 倒數排名融合(Reciprocal Rank Fusion),
33
+ 不需要分數正規化,對不同分數分佈更魯棒
34
+ rrf_k: RRF 方法中的常數 k,通常設為 60(僅用於 rrf 方法)
35
+ 較大的 k 值會讓排名較低的文檔獲得更多權重
36
+ """
37
+ self.sparse_retriever = sparse_retriever
38
+ self.dense_retriever = dense_retriever
39
+ self.fusion_method = fusion_method
40
+ self.rrf_k = rrf_k
41
+
42
+ # 僅在 weighted_sum 方法中使用權重
43
+ if fusion_method == "weighted_sum":
44
+ self.sparse_weight = sparse_weight
45
+ self.dense_weight = dense_weight
46
+
47
+ # 確保權重總和為 1
48
+ total_weight = sparse_weight + dense_weight
49
+ if abs(total_weight - 1.0) > 1e-6:
50
+ self.sparse_weight = sparse_weight / total_weight
51
+ self.dense_weight = dense_weight / total_weight
52
+
53
+ def _normalize_scores(self, results: List[Dict]) -> List[Dict]:
54
+ """
55
+ 將分數正規化到 [0, 1] 區間。
56
+ 僅用於 weighted_sum 方法。
57
+
58
+ Args:
59
+ results: 檢索結果列表,每個字典包含 'score'
60
+
61
+ Returns:
62
+ 帶有正規化分數的結果列表
63
+ """
64
+ scores = [res.get("score", 0.0) for res in results]
65
+ if not scores:
66
+ return results
67
+
68
+ scores_array = np.array(scores)
69
+ min_score = scores_array.min()
70
+ max_score = scores_array.max()
71
+
72
+ if max_score == min_score:
73
+ # 如果所有分數都相同,將它們設置為 1.0
74
+ normalized_scores = [1.0] * len(scores)
75
+ else:
76
+ normalized_scores = ((scores_array - min_score) / (max_score - min_score)).tolist()
77
+
78
+ for i, res in enumerate(results):
79
+ res["score"] = normalized_scores[i]
80
+
81
+ return results
82
+
83
+ def _get_doc_id(self, doc: Dict) -> str:
84
+ """
85
+ 從文檔中提取唯一標識符
86
+
87
+ Args:
88
+ doc: 文檔字典
89
+
90
+ Returns:
91
+ 文檔的唯一 ID
92
+ """
93
+ metadata = doc.get("metadata", {})
94
+ return f"{metadata.get('arxiv_id', 'unknown')}_{metadata.get('chunk_index', 0)}"
95
+
96
+ def _apply_rrf(
97
+ self,
98
+ sparse_results: List[Dict],
99
+ dense_results: List[Dict]
100
+ ) -> List[Dict]:
101
+ """
102
+ 應用倒數排名融合(Reciprocal Rank Fusion, RRF)方法
103
+
104
+ RRF 公式:RRF(d) = Σ(1 / (k + rank_i(d)))
105
+ 其中:
106
+ - d 是文檔
107
+ - rank_i(d) 是文檔在第 i 個檢索結果中的排名(從 1 開始)
108
+ - k 是常數(預設為 60)
109
+
110
+ RRF 的優點:
111
+ 1. 不需要分數正規化,對不同分數分佈的檢索器更魯棒
112
+ 2. 只依賴排名位置,不依賴分數值
113
+ 3. 自動處理分數分佈差異的問題
114
+
115
+ Args:
116
+ sparse_results: 稀疏檢索結果列表
117
+ dense_results: 密集檢索結果列表
118
+
119
+ Returns:
120
+ 融合後的結果列表,按 RRF 分數排序
121
+ """
122
+ # 建立文檔 ID 到 RRF 分數的映射
123
+ doc_to_rrf_score = {}
124
+
125
+ # 處理稀疏檢索結果(BM25)
126
+ for rank, result in enumerate(sparse_results, start=1):
127
+ doc_id = self._get_doc_id(result)
128
+ if doc_id not in doc_to_rrf_score:
129
+ doc_to_rrf_score[doc_id] = {
130
+ "doc": result,
131
+ "rrf_score": 0.0,
132
+ "sparse_rank": None,
133
+ "dense_rank": None
134
+ }
135
+ # 計算 RRF 貢獻:1 / (k + rank)
136
+ doc_to_rrf_score[doc_id]["rrf_score"] += 1.0 / (self.rrf_k + rank)
137
+ doc_to_rrf_score[doc_id]["sparse_rank"] = rank
138
+
139
+ # 處理密集檢索結果(向量)
140
+ for rank, result in enumerate(dense_results, start=1):
141
+ doc_id = self._get_doc_id(result)
142
+ if doc_id not in doc_to_rrf_score:
143
+ doc_to_rrf_score[doc_id] = {
144
+ "doc": result,
145
+ "rrf_score": 0.0,
146
+ "sparse_rank": None,
147
+ "dense_rank": None
148
+ }
149
+ # 計算 RRF 貢獻:1 / (k + rank)
150
+ doc_to_rrf_score[doc_id]["rrf_score"] += 1.0 / (self.rrf_k + rank)
151
+ doc_to_rrf_score[doc_id]["dense_rank"] = rank
152
+
153
+ # 構建結果列表
154
+ rrf_results = []
155
+ for doc_id, data in doc_to_rrf_score.items():
156
+ result = data["doc"].copy()
157
+ result["hybrid_score"] = data["rrf_score"]
158
+ result["rrf_score"] = data["rrf_score"]
159
+ result["sparse_rank"] = data["sparse_rank"]
160
+ result["dense_rank"] = data["dense_rank"]
161
+
162
+ # 從原始結果中獲取分數以供參考
163
+ if data["sparse_rank"] is not None:
164
+ # 從稀疏檢索結果中獲取原始分數
165
+ for sparse_res in sparse_results:
166
+ if self._get_doc_id(sparse_res) == doc_id:
167
+ result["sparse_score"] = sparse_res.get("score", 0.0)
168
+ break
169
+ else:
170
+ result["sparse_score"] = None
171
+
172
+ if data["dense_rank"] is not None:
173
+ # 從密集檢索結果中獲取原始分數
174
+ for dense_res in dense_results:
175
+ if self._get_doc_id(dense_res) == doc_id:
176
+ result["dense_score"] = dense_res.get("score", 0.0)
177
+ break
178
+ else:
179
+ result["dense_score"] = None
180
+
181
+ rrf_results.append(result)
182
+
183
+ # 按 RRF 分數從高到低排序
184
+ rrf_results.sort(key=lambda x: x["rrf_score"], reverse=True)
185
+
186
+ return rrf_results
187
+
188
+ def _apply_weighted_sum(
189
+ self,
190
+ sparse_results: List[Dict],
191
+ dense_results: List[Dict]
192
+ ) -> List[Dict]:
193
+ """
194
+ 應用加權求和(Weighted Sum)方法
195
+
196
+ 此方法需要:
197
+ 1. 正規化兩組分數到相同範圍
198
+ 2. 根據權重進行加權求和
199
+
200
+ Args:
201
+ sparse_results: 稀疏檢索結果列表
202
+ dense_results: 密集檢索結果列表
203
+
204
+ Returns:
205
+ 融合後的結果列表,按混合分數排序
206
+ """
207
+ # 正規化兩組分數
208
+ normalized_sparse = self._normalize_scores(sparse_results)
209
+ normalized_dense = self._normalize_scores(dense_results)
210
+
211
+ # 結合分數
212
+ doc_to_scores = {}
213
+
214
+ # 處理稀疏檢索結果
215
+ for res in normalized_sparse:
216
+ doc_id = self._get_doc_id(res)
217
+ if doc_id not in doc_to_scores:
218
+ doc_to_scores[doc_id] = {"doc": res, "sparse": 0.0, "dense": 0.0}
219
+ doc_to_scores[doc_id]["sparse"] = res["score"]
220
+
221
+ # 處理密集檢索結果
222
+ for res in normalized_dense:
223
+ doc_id = self._get_doc_id(res)
224
+ if doc_id not in doc_to_scores:
225
+ doc_to_scores[doc_id] = {"doc": res, "sparse": 0.0, "dense": 0.0}
226
+ doc_to_scores[doc_id]["dense"] = res["score"]
227
+
228
+ # 計算混合分數並排序
229
+ hybrid_results = []
230
+ for doc_id, scores in doc_to_scores.items():
231
+ hybrid_score = (
232
+ self.sparse_weight * scores["sparse"] +
233
+ self.dense_weight * scores["dense"]
234
+ )
235
+
236
+ result = scores["doc"].copy()
237
+ result["hybrid_score"] = hybrid_score
238
+ result["sparse_score"] = scores["sparse"]
239
+ result["dense_score"] = scores["dense"]
240
+ hybrid_results.append(result)
241
+
242
+ # 按混合分數從高到低排序
243
+ hybrid_results.sort(key=lambda x: x["hybrid_score"], reverse=True)
244
+
245
+ return hybrid_results
246
+
247
+ def retrieve(
248
+ self,
249
+ query: str,
250
+ top_k: int = 5,
251
+ metadata_filter: Optional[Dict] = None
252
+ ) -> List[Dict]:
253
+ """
254
+ 執行混合搜尋,支援根據 metadata 進行過濾
255
+
256
+ Args:
257
+ query: 查詢文字
258
+ top_k: 返回前 k 個結果
259
+ metadata_filter: 可選的 metadata 過濾條件字典。
260
+ 例如: {"arxiv_id": "1234.5678"} 只檢索特定論文的 chunks
261
+ 或 {"title": "Machine Learning"} 只檢索特定標題的論文
262
+ 支援多個條件,所有條件必須同時滿足(AND 邏輯)
263
+ 此過濾條件會傳遞給底層的稀疏和密集檢索器
264
+
265
+ Returns:
266
+ 相關文檔列表,每個包含 "content", "metadata", "hybrid_score"
267
+ 結果會根據 metadata_filter 進行過濾
268
+
269
+ 根據 fusion_method 的不同,返回的結果會包含不同的分數欄位:
270
+ - RRF 方法:包含 "rrf_score", "sparse_rank", "dense_rank"
271
+ - Weighted Sum 方法:包含 "sparse_score", "dense_score"
272
+ """
273
+ # 1. 從兩個檢索器獲取結果(請求更多結果以確保覆蓋率)
274
+ # 將 metadata_filter 傳遞給底層檢索器
275
+ sparse_results = self.sparse_retriever.retrieve(
276
+ query,
277
+ top_k=top_k * 2,
278
+ metadata_filter=metadata_filter
279
+ )
280
+ dense_results = self.dense_retriever.retrieve(
281
+ query,
282
+ top_k=top_k * 2,
283
+ metadata_filter=metadata_filter
284
+ )
285
+
286
+ # 2. 根據選擇的融合方法進行結果融合
287
+ if self.fusion_method == "rrf":
288
+ # 使用 RRF(倒數排名融合)方法
289
+ # RRF 不需要分數正規化,直接基於排名進行融合
290
+ hybrid_results = self._apply_rrf(sparse_results, dense_results)
291
+ else:
292
+ # 使用加權求和方法
293
+ # 需要先正規化分數,然後根據權重進行加權求和
294
+ hybrid_results = self._apply_weighted_sum(sparse_results, dense_results)
295
+
296
+ # 3. 返回前 top_k 個結果
297
+ return hybrid_results[:top_k]
298
+
src/retrievers/reranker.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 重排序模組:使用 Cross-Encoder 進行精準重排
3
+ """
4
+ from typing import List, Dict, Optional, Tuple
5
+ from sentence_transformers import CrossEncoder
6
+ import time
7
+ import logging
8
+
9
+ # 嘗試導入 torch 來檢測可用的設備
10
+ try:
11
+ import torch
12
+ TORCH_AVAILABLE = True
13
+ except ImportError:
14
+ TORCH_AVAILABLE = False
15
+
16
+ # 配置日志
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ def get_device() -> str:
22
+ """
23
+ 自動檢測並返回最佳可用的設備
24
+
25
+ Returns:
26
+ 設備名稱: 'mps' (macOS GPU), 'cuda' (NVIDIA GPU), 或 'cpu'
27
+ """
28
+ if not TORCH_AVAILABLE:
29
+ return 'cpu'
30
+
31
+ # 優先順序: MPS (macOS) > CUDA (NVIDIA) > CPU
32
+ if torch.backends.mps.is_available():
33
+ return 'mps'
34
+ elif torch.cuda.is_available():
35
+ return 'cuda'
36
+ else:
37
+ return 'cpu'
38
+
39
+
40
+ class Reranker:
41
+ """重排序組件:使用 Cross-Encoder 進行精準重排"""
42
+
43
+ def __init__(
44
+ self,
45
+ model_name: str = "BAAI/bge-reranker-base",
46
+ device: str = None,
47
+ max_length: int = 512,
48
+ batch_size: int = 32,
49
+ enable_cache: bool = True
50
+ ):
51
+ """
52
+ 初始化 Cross-Encoder 模型
53
+
54
+ Args:
55
+ model_name: Cross-Encoder 模型名稱
56
+ device: 設備名稱 ('cuda', 'cpu', 'mps')
57
+ max_length: 最大 token 長度(模型限制)
58
+ batch_size: 批處理大小,用於優化內存使用
59
+ enable_cache: 是否啟用模型緩存
60
+ """
61
+ try:
62
+ # 自動檢測設備(如果未指定)
63
+ if device is None:
64
+ device = get_device()
65
+
66
+ device_name_map = {
67
+ 'mps': 'MPS (macOS GPU)',
68
+ 'cuda': 'CUDA (NVIDIA GPU)',
69
+ 'cpu': 'CPU'
70
+ }
71
+ device_display = device_name_map.get(device, device)
72
+
73
+ self.model = CrossEncoder(
74
+ model_name,
75
+ device=device,
76
+ max_length=max_length
77
+ )
78
+ self.max_length = max_length
79
+ self.batch_size = batch_size
80
+ self.model_name = model_name
81
+ logger.info(f"✅ 重排模型 {model_name} 已載入 (device: {device_display})")
82
+ except Exception as e:
83
+ logger.error(f"❌ 模型載入失敗: {e}")
84
+ raise
85
+
86
+ def _truncate_text(self, text: str, max_chars: int = 2000) -> str:
87
+ """
88
+ 截斷過長的文本(粗略估計,避免超過 token 限制)
89
+
90
+ Args:
91
+ text: 原始文本
92
+ max_chars: 最大字符數(保守估計,約 500 tokens)
93
+
94
+ Returns:
95
+ 截斷後的文本
96
+ """
97
+ if len(text) <= max_chars:
98
+ return text
99
+ # 截斷並添加省略號
100
+ return text[:max_chars - 3] + "..."
101
+
102
+ def _prepare_pairs(
103
+ self,
104
+ query: str,
105
+ documents: List[Dict]
106
+ ) -> List[Tuple[str, str]]:
107
+ """
108
+ 準備 (query, document) 配對,處理文本長度
109
+
110
+ Args:
111
+ query: 查詢文本
112
+ documents: 文檔列表
113
+
114
+ Returns:
115
+ (query, content) 配對列表
116
+ """
117
+ pairs = []
118
+ truncated_indices = [] # 記錄哪些文檔被截斷了
119
+
120
+ # 粗略估計:每個字符約 0.25 tokens,為 query 預留空間
121
+ max_doc_chars = int((self.max_length * 0.7) - len(query))
122
+
123
+ for i, doc in enumerate(documents):
124
+ content = doc.get("content", "")
125
+ original_length = len(content)
126
+
127
+ # 如果內容過長,進行截斷
128
+ if len(content) > max_doc_chars:
129
+ content = self._truncate_text(content, max_doc_chars)
130
+ truncated_indices.append(i)
131
+
132
+ pairs.append([query, content])
133
+
134
+ if truncated_indices:
135
+ logger.warning(
136
+ f"⚠️ 有 {len(truncated_indices)} 個文檔因過長被截斷 "
137
+ f"(最大長度: {max_doc_chars} 字符)"
138
+ )
139
+
140
+ return pairs
141
+
142
+ def rerank(
143
+ self,
144
+ query: str,
145
+ documents: List[Dict],
146
+ top_k: int = 5,
147
+ preserve_original_scores: bool = True
148
+ ) -> List[Dict]:
149
+ """
150
+ 執行精準重排
151
+
152
+ Args:
153
+ query: 查詢文本
154
+ documents: 文檔列表,每個應包含 "content" 和可選的 "hybrid_score"
155
+ top_k: 返回前 k 個結果
156
+ preserve_original_scores: 是否保留原始分數(hybrid_score)
157
+
158
+ Returns:
159
+ 重排後的文檔列表,按 rerank_score 降序排列
160
+ """
161
+ if not documents:
162
+ logger.warning("⚠️ 文檔列表為空,返回空結果")
163
+ return []
164
+
165
+ if not query or not query.strip():
166
+ logger.warning("⚠️ 查詢為空,返回原始文檔順序")
167
+ return documents[:top_k]
168
+
169
+ start_time = time.time()
170
+ logger.info(f"🔄 開始重排 {len(documents)} 個文檔...")
171
+
172
+ try:
173
+ # 1. 準備配對
174
+ pairs = self._prepare_pairs(query, documents)
175
+
176
+ # 2. 批處理計算分數(優化內存使用)
177
+ scores = []
178
+ for i in range(0, len(pairs), self.batch_size):
179
+ batch_pairs = pairs[i:i + self.batch_size]
180
+ batch_scores = self.model.predict(batch_pairs)
181
+ scores.extend(batch_scores.tolist() if hasattr(batch_scores, 'tolist') else batch_scores)
182
+
183
+ # 3. 更新文檔分數
184
+ for i, doc in enumerate(documents):
185
+ doc = doc.copy() # 避免修改原始文檔
186
+ doc["rerank_score"] = float(scores[i])
187
+
188
+ # 保留原始分數供參考
189
+ if preserve_original_scores:
190
+ if "hybrid_score" not in doc:
191
+ # 如果沒有 hybrid_score,嘗試使用其他分數
192
+ doc["original_score"] = doc.get("score", 0.0)
193
+
194
+ documents[i] = doc
195
+
196
+ # 4. 根據 rerank_score 重新排序
197
+ reranked_docs = sorted(
198
+ documents,
199
+ key=lambda x: x.get("rerank_score", float('-inf')),
200
+ reverse=True
201
+ )
202
+
203
+ # 5. 統計資訊
204
+ elapsed_time = time.time() - start_time
205
+ avg_score = sum(scores) / len(scores) if scores else 0.0
206
+ max_score = max(scores) if scores else 0.0
207
+ min_score = min(scores) if scores else 0.0
208
+
209
+ logger.info(
210
+ f"✅ 重排完成 (耗時: {elapsed_time:.2f}s, "
211
+ f"平均分數: {avg_score:.4f}, "
212
+ f"範圍: [{min_score:.4f}, {max_score:.4f}])"
213
+ )
214
+
215
+ return reranked_docs[:top_k]
216
+
217
+ except Exception as e:
218
+ logger.error(f"❌ 重排過程出錯: {e}")
219
+ # 降級策略:返回原始順序的前 top_k 個
220
+ logger.warning("⚠️ 使用降級策略:返回原始順序")
221
+ return documents[:top_k]
222
+
223
+
224
+ class RAGPipeline:
225
+ """協調管線:管理完整的 RAG 流程(召回 + 重排)"""
226
+
227
+ def __init__(
228
+ self,
229
+ hybrid_search,
230
+ reranker,
231
+ recall_k: int = 25,
232
+ adaptive_recall: bool = True,
233
+ min_recall_k: int = 10,
234
+ max_recall_k: int = 50
235
+ ):
236
+ """
237
+ 初始化 RAG 管線
238
+
239
+ Args:
240
+ hybrid_search: HybridSearch 實例
241
+ reranker: Reranker 實例
242
+ recall_k: 第一階段召回的數量(預設值)
243
+ adaptive_recall: 是否根據查詢動態調整 recall_k
244
+ min_recall_k: 最小召回數量
245
+ max_recall_k: 最大召回數量
246
+ """
247
+ self.hybrid_search = hybrid_search
248
+ self.reranker = reranker
249
+ self.base_recall_k = recall_k
250
+ self.adaptive_recall = adaptive_recall
251
+ self.min_recall_k = min_recall_k
252
+ self.max_recall_k = max_recall_k
253
+
254
+ # 性能統計
255
+ self.stats = {
256
+ "total_queries": 0,
257
+ "avg_recall_time": 0.0,
258
+ "avg_rerank_time": 0.0,
259
+ "avg_total_time": 0.0
260
+ }
261
+
262
+ def _calculate_adaptive_recall_k(self, query: str) -> int:
263
+ """
264
+ 根據查詢複雜度動態計算 recall_k
265
+
266
+ Args:
267
+ query: 查詢文本
268
+
269
+ Returns:
270
+ 調整後的 recall_k
271
+ """
272
+ if not self.adaptive_recall:
273
+ return self.base_recall_k
274
+
275
+ # 簡單啟發式:根據查詢長度和關鍵詞數量調整
276
+ query_length = len(query.split())
277
+ keyword_count = len(set(query.lower().split()))
278
+
279
+ # 複雜查詢需要更多候選
280
+ if query_length > 10 or keyword_count > 5:
281
+ recall_k = min(self.base_recall_k * 2, self.max_recall_k)
282
+ elif query_length < 3:
283
+ recall_k = max(self.base_recall_k // 2, self.min_recall_k)
284
+ else:
285
+ recall_k = self.base_recall_k
286
+
287
+ return recall_k
288
+
289
+ def query(
290
+ self,
291
+ text: str,
292
+ top_k: int = 5,
293
+ metadata_filter: Optional[Dict] = None,
294
+ enable_rerank: bool = True,
295
+ return_stats: bool = False
296
+ ) -> List[Dict]:
297
+ """
298
+ 執行完整的搜尋流程
299
+
300
+ Args:
301
+ text: 查詢文本
302
+ top_k: 最終返回的結果數量
303
+ metadata_filter: 可選的 metadata 過濾條件
304
+ enable_rerank: 是否啟用重排序(可選,用於性能測試)
305
+ return_stats: 是否返回���能統計資訊
306
+
307
+ Returns:
308
+ 相關文檔列表,如果 return_stats=True,則返回 (results, stats) 元組
309
+ """
310
+ if not text or not text.strip():
311
+ logger.warning("⚠️ 查詢為空")
312
+ return []
313
+
314
+ total_start = time.time()
315
+ self.stats["total_queries"] += 1
316
+
317
+ # 動態計算 recall_k
318
+ recall_k = self._calculate_adaptive_recall_k(text)
319
+ logger.info(
320
+ f"🔍 搜尋中: '{text[:50]}...' "
321
+ f"(召回階段: {recall_k} 筆, 最終返回: {top_k} 筆)"
322
+ )
323
+
324
+ try:
325
+ # 第一階段:混合搜尋(召回階段)
326
+ recall_start = time.time()
327
+ initial_results = self.hybrid_search.retrieve(
328
+ query=text,
329
+ top_k=recall_k,
330
+ metadata_filter=metadata_filter
331
+ )
332
+ recall_time = time.time() - recall_start
333
+
334
+ if not initial_results:
335
+ logger.warning("⚠️ 召回階段未找到任何結果")
336
+ return []
337
+
338
+ logger.info(
339
+ f"✅ 召回階段完成: 找到 {len(initial_results)} 個候選 "
340
+ f"(耗時: {recall_time:.2f}s)"
341
+ )
342
+
343
+ # 第二階段:重排序(精篩階段)
344
+ if enable_rerank and len(initial_results) > top_k:
345
+ rerank_start = time.time()
346
+ final_results = self.reranker.rerank(
347
+ query=text,
348
+ documents=initial_results,
349
+ top_k=top_k
350
+ )
351
+ rerank_time = time.time() - rerank_start
352
+
353
+ logger.info(
354
+ f"✅ 重排階段完成: 從 {len(initial_results)} 個候選中選出 "
355
+ f"{len(final_results)} 個結果 (耗時: {rerank_time:.2f}s)"
356
+ )
357
+ else:
358
+ # 跳過重排序(用於性能測試或候選數較少時)
359
+ final_results = initial_results[:top_k]
360
+ rerank_time = 0.0
361
+ logger.info("⏭️ 跳過重排序階段(候選數不足或已禁用)")
362
+
363
+ # 更新統計資訊
364
+ total_time = time.time() - total_start
365
+ self._update_stats(recall_time, rerank_time, total_time)
366
+
367
+ # 添加性能資訊到結果(可選)
368
+ if return_stats:
369
+ stats = {
370
+ "recall_time": recall_time,
371
+ "rerank_time": rerank_time,
372
+ "total_time": total_time,
373
+ "recall_k": recall_k,
374
+ "candidates_found": len(initial_results),
375
+ "final_results": len(final_results)
376
+ }
377
+ return final_results, stats
378
+
379
+ return final_results
380
+
381
+ except Exception as e:
382
+ logger.error(f"❌ 查詢過程出錯: {e}")
383
+ # 降級策略:嘗試只使用召回階段
384
+ try:
385
+ logger.warning("⚠️ 嘗試降級策略:僅使用召回結果")
386
+ return self.hybrid_search.retrieve(text, top_k=top_k, metadata_filter=metadata_filter)
387
+ except Exception as e2:
388
+ logger.error(f"❌ 降級策略也失敗: {e2}")
389
+ return []
390
+
391
+ def _update_stats(self, recall_time: float, rerank_time: float, total_time: float):
392
+ """更新性能統計資訊"""
393
+ n = self.stats["total_queries"]
394
+ self.stats["avg_recall_time"] = (
395
+ (self.stats["avg_recall_time"] * (n - 1) + recall_time) / n
396
+ )
397
+ self.stats["avg_rerank_time"] = (
398
+ (self.stats["avg_rerank_time"] * (n - 1) + rerank_time) / n
399
+ )
400
+ self.stats["avg_total_time"] = (
401
+ (self.stats["avg_total_time"] * (n - 1) + total_time) / n
402
+ )
403
+
404
+ def get_stats(self) -> Dict:
405
+ """獲取性能統計資訊"""
406
+ return self.stats.copy()
407
+
408
+ def reset_stats(self):
409
+ """重置統計資訊"""
410
+ self.stats = {
411
+ "total_queries": 0,
412
+ "avg_recall_time": 0.0,
413
+ "avg_rerank_time": 0.0,
414
+ "avg_total_time": 0.0
415
+ }
416
+
417
+ def format_results_for_llm(
418
+ self,
419
+ results: List[Dict],
420
+ format_style: str = "detailed"
421
+ ) -> str:
422
+ """
423
+ 格式化檢索結果供 LLM 使用(需要導入 PromptFormatter)
424
+
425
+ Args:
426
+ results: 檢索結果列表
427
+ format_style: 格式風格 ("detailed", "simple", "minimal")
428
+
429
+ Returns:
430
+ 格式化後的上下文字符串
431
+ """
432
+ try:
433
+ from ..prompt_formatter import PromptFormatter
434
+ formatter = PromptFormatter(format_style=format_style)
435
+ return formatter.format_context(results)
436
+ except ImportError:
437
+ # 如果無法導入,使用簡單格式
438
+ formatted_parts = []
439
+ for i, result in enumerate(results, 1):
440
+ metadata = result.get("metadata", {})
441
+ content = result.get("content", "")
442
+ arxiv_id = metadata.get('arxiv_id', 'N/A')
443
+ title = metadata.get('title', 'N/A')
444
+ formatted_parts.append(
445
+ f"[來源 {i}: {title} (arXiv:{arxiv_id})]\n{content}\n"
446
+ )
447
+ return "\n" + "="*60 + "\n".join(formatted_parts)
448
+
src/retrievers/vector_retriever.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 向量檢索器模組:使用 embedding 和向量資料庫進行語義檢索
3
+
4
+ 支援兩種初始化方式:
5
+ 1. 自動初始化 embeddings(預設):根據參數創建新的 embedding 模型
6
+ 2. 使用外部 embeddings:接收已初始化的 embedding 模型(可與 DocumentProcessor 共用)
7
+ """
8
+ from typing import List, Dict, Optional, Any
9
+ from langchain_community.vectorstores import Chroma
10
+ from langchain_core.documents import Document
11
+ import os
12
+ from .base import BaseRetriever
13
+
14
+ # 嘗試導入 HuggingFaceEmbeddings(免費模型)
15
+ try:
16
+ from langchain_community.embeddings import HuggingFaceEmbeddings
17
+ except ImportError:
18
+ try:
19
+ from langchain_huggingface import HuggingFaceEmbeddings
20
+ except ImportError:
21
+ raise ImportError("需要安裝 langchain-community 或 langchain-huggingface 才能使用 Hugging Face embeddings")
22
+
23
+ # 導入 torch 來檢測可用的設備
24
+ try:
25
+ import torch
26
+ TORCH_AVAILABLE = True
27
+ except ImportError:
28
+ TORCH_AVAILABLE = False
29
+
30
+
31
+ def get_device() -> str:
32
+ """
33
+ 自動檢測並返回最佳可用的設備
34
+
35
+ Returns:
36
+ 設備名稱: 'mps' (macOS GPU), 'cuda' (NVIDIA GPU), 或 'cpu'
37
+ """
38
+ if not TORCH_AVAILABLE:
39
+ return 'cpu'
40
+
41
+ # 優先順序: MPS (macOS) > CUDA (NVIDIA) > CPU
42
+ if torch.backends.mps.is_available():
43
+ return 'mps'
44
+ elif torch.cuda.is_available():
45
+ return 'cuda'
46
+ else:
47
+ return 'cpu'
48
+
49
+
50
+ class VectorRetriever(BaseRetriever):
51
+ """使用向量檢索進行語義搜尋"""
52
+
53
+ def __init__(
54
+ self,
55
+ documents: List[Dict],
56
+ embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
57
+ persist_directory: Optional[str] = "./chroma_db",
58
+ hf_cache_dir: Optional[str] = None,
59
+ device: Optional[str] = None,
60
+ embeddings: Optional[Any] = None # 可選:外部傳入的 embedding 模型(優先使用)
61
+ ):
62
+ """
63
+ 初始化向量檢索器(使用 Hugging Face embeddings)
64
+
65
+ Args:
66
+ documents: 文檔列表,每個文檔包含 "content" 和 "metadata"
67
+ embedding_model: Hugging Face embedding 模型名稱(預設: "sentence-transformers/all-MiniLM-L6-v2")
68
+ 僅在 embeddings=None 時使用
69
+ persist_directory: Chroma 資料庫持久化目錄
70
+ hf_cache_dir: Hugging Face 模型緩存目錄(例如外接硬碟路徑)
71
+ 如果為 None,則使用環境變數 HF_HOME 或預設位置 ~/.cache/huggingface/
72
+ 僅在 embeddings=None 時使用
73
+ device: 設備名稱 ('mps', 'cuda', 'cpu'),如果為 None 則自動檢測最佳設備
74
+ 僅在 embeddings=None 時使用
75
+ embeddings: 可選的外部 embedding 模型物件
76
+ 如果提供,將優先使用此模型,忽略其他參數(embedding_model, hf_cache_dir, device)
77
+ 這允許與 DocumentProcessor 共用同一個 embedding 模型實例
78
+ 優點:
79
+ - 節省內存(只加載一次模型)
80
+ - 節省時間(避免重複初始化)
81
+ - 確保一致性(分塊和檢索使用相同的模型)
82
+ """
83
+ # 優先使用傳入的共用模型
84
+ if embeddings is not None:
85
+ self.embeddings = embeddings
86
+ print("✓ 使用外部傳入的 embeddings 模型(與 DocumentProcessor 共用)")
87
+ else:
88
+ # 若無傳入,則執行原有的初始化邏輯
89
+ print(f"使用 Hugging Face embedding 模型: {embedding_model}")
90
+
91
+ # 設置 Hugging Face 緩存目錄
92
+ if hf_cache_dir:
93
+ # 如果指定了緩存目錄,設置環境變數
94
+ os.environ['HF_HOME'] = hf_cache_dir
95
+ os.environ['TRANSFORMERS_CACHE'] = hf_cache_dir
96
+ print(f"模型將存儲在: {hf_cache_dir}")
97
+ else:
98
+ # 檢查是否已經設置了環境變數
99
+ default_cache = os.path.expanduser("~/.cache/huggingface")
100
+ current_cache = os.getenv('HF_HOME', default_cache)
101
+ print(f"模型緩存位置: {current_cache}")
102
+ print("提示: 可以通過設置 hf_cache_dir 參數或環境變數 HF_HOME 來指定外接硬碟路徑")
103
+
104
+ # 自動檢測或使用指定的設備
105
+ if device is None:
106
+ device = get_device()
107
+
108
+ device_name_map = {
109
+ 'mps': 'MPS (macOS GPU)',
110
+ 'cuda': 'CUDA (NVIDIA GPU)',
111
+ 'cpu': 'CPU'
112
+ }
113
+ print(f"使用設備: {device_name_map.get(device, device)}")
114
+ print("首次使用時會下載模型,請稍候...")
115
+
116
+ # 構建 model_kwargs,包含緩存目錄和設備
117
+ model_kwargs = {'device': device}
118
+ if hf_cache_dir:
119
+ model_kwargs['cache_dir'] = hf_cache_dir
120
+
121
+ self.embeddings = HuggingFaceEmbeddings(
122
+ model_name=embedding_model,
123
+ model_kwargs=model_kwargs,
124
+ encode_kwargs={'normalize_embeddings': True} # 正規化 embeddings 以提升效果
125
+ )
126
+
127
+ # 將文檔轉換為 LangChain Document 格式
128
+ # 需要將 metadata 中的列表轉換為字串,因為 ChromaDB 不接受列表類型
129
+ def sanitize_metadata(metadata: Dict) -> Dict:
130
+ """將 metadata 中的列表轉換為字串,以符合 ChromaDB 的要求"""
131
+ sanitized = {}
132
+ for key, value in metadata.items():
133
+ if isinstance(value, list):
134
+ # 將列表轉換為逗號分隔的字串
135
+ sanitized[key] = ", ".join(str(v) for v in value)
136
+ elif isinstance(value, (dict, set)):
137
+ # 將字典或集合轉換為字串
138
+ sanitized[key] = str(value)
139
+ else:
140
+ # 其他類型(str, int, float, bool, None)直接保留
141
+ sanitized[key] = value
142
+ return sanitized
143
+
144
+ langchain_docs = [
145
+ Document(
146
+ page_content=doc["content"],
147
+ metadata=sanitize_metadata(doc["metadata"])
148
+ )
149
+ for doc in documents
150
+ ]
151
+
152
+ # 創建向量資料庫
153
+ self.vectorstore = Chroma.from_documents(
154
+ documents=langchain_docs,
155
+ embedding=self.embeddings,
156
+ persist_directory=persist_directory
157
+ )
158
+
159
+ # 創建 retriever
160
+ self.retriever = self.vectorstore.as_retriever()
161
+
162
+ def retrieve(
163
+ self,
164
+ query: str,
165
+ top_k: int = 5,
166
+ metadata_filter: Optional[Dict] = None
167
+ ) -> List[Dict]:
168
+ """
169
+ 檢索相關文檔,並返回標準化的相似度分數(越高越好)。
170
+ 支援根據 metadata 進行過濾。
171
+
172
+ Args:
173
+ query: 查詢文字
174
+ top_k: 返回前 k 個結果
175
+ metadata_filter: 可選的 metadata 過濾條件字典。
176
+ 例如: {"arxiv_id": "1234.5678"} 只檢索特定論文的 chunks
177
+ 或 {"title": "Machine Learning"} 只檢索特定標題的論文
178
+ 支援多個條件,所有條件必須同時滿足(AND 邏輯)
179
+ 注意:ChromaDB 的 where 條件支援精確匹配,不支援部分匹配
180
+
181
+ Returns:
182
+ 相關文檔列表,每個包含 "content", "metadata", 和 "score"
183
+ 結果會根據 metadata_filter 進行過濾
184
+ """
185
+ # 構建過濾條件
186
+ # 如果提供了 metadata_filter,先獲取更多結果,然後在 Python 中進行過濾
187
+ # 這是因為 LangChain ChromaDB 的 similarity_search_with_score 方法
188
+ # 對 filter 參數的支援可能因版本而異
189
+ if metadata_filter:
190
+ # 獲取更多結果以確保有足夠的候選進行過濾
191
+ results_with_scores = self.vectorstore.similarity_search_with_score(
192
+ query,
193
+ k=top_k * 10 # 獲取更多結果
194
+ )
195
+
196
+ # 在 Python 中進行過濾
197
+ filtered_results = []
198
+ for doc, distance_score in results_with_scores:
199
+ metadata = doc.metadata
200
+ matches = True
201
+
202
+ for key, value in metadata_filter.items():
203
+ doc_value = metadata.get(key)
204
+
205
+ # 檢查是否匹配
206
+ if isinstance(value, dict):
207
+ # 支援運算符格式(例如 {"$eq": "value"})
208
+ if "$eq" in value:
209
+ if doc_value != value["$eq"]:
210
+ matches = False
211
+ break
212
+ else:
213
+ # 其他運算符可以在此擴展
214
+ matches = False
215
+ break
216
+ elif isinstance(value, str) and isinstance(doc_value, str):
217
+ # 字串匹配:支援部分匹配(包含)
218
+ if value.lower() not in doc_value.lower():
219
+ matches = False
220
+ break
221
+ else:
222
+ # 其他類型使用精確匹配
223
+ if doc_value != value:
224
+ matches = False
225
+ break
226
+
227
+ if matches:
228
+ filtered_results.append((doc, distance_score))
229
+
230
+ # 只保留前 top_k 個結果
231
+ results_with_scores = filtered_results[:top_k]
232
+ else:
233
+ # 沒有過濾條件,直接獲取結果
234
+ results_with_scores = self.vectorstore.similarity_search_with_score(
235
+ query,
236
+ k=top_k
237
+ )
238
+
239
+ # 構建結果並轉換分數
240
+ results = []
241
+ for doc, distance_score in results_with_scores:
242
+ # 因為 embedding 已正規化,L2 距離的平方為 2 - 2 * cos_sim
243
+ # -> cos_sim = 1 - (distance^2 / 2)
244
+ # 分數範圍在 [0, 1] 之間,越高越相似
245
+ similarity_score = 1 - (distance_score**2 / 2)
246
+
247
+ results.append({
248
+ "content": doc.page_content,
249
+ "metadata": doc.metadata,
250
+ "score": float(similarity_score),
251
+ })
252
+
253
+ return results
254
+
src/step_back_rag.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Step-back Prompting 雙軌 RAG:結合具體事實與抽象原理
3
+ 使用 Step-back Prompting 技術,同時檢索具體事實和抽象原理,提升回答質量
4
+ """
5
+ from typing import List, Dict, Optional
6
+ from .retrievers.reranker import RAGPipeline
7
+ from .retrievers.vector_retriever import VectorRetriever
8
+ from .prompt_formatter import PromptFormatter
9
+ from .llm_integration import OllamaLLM
10
+ import time
11
+ import logging
12
+ import hashlib
13
+ from concurrent.futures import ThreadPoolExecutor, as_completed
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class StepBackRAG:
19
+ """使用 Step-back Prompting 的雙軌 RAG 系統"""
20
+
21
+ def __init__(
22
+ self,
23
+ rag_pipeline: RAGPipeline,
24
+ vector_retriever: VectorRetriever,
25
+ llm: OllamaLLM,
26
+ step_back_temperature: float = 0.3, # 生成抽象問題時使用較低溫度
27
+ answer_temperature: float = 0.7,
28
+ enable_parallel: bool = True
29
+ ):
30
+ """
31
+ 初始化 Step-back RAG
32
+
33
+ Args:
34
+ rag_pipeline: RAG 管線實例(用於最終答案生成)
35
+ vector_retriever: 向量檢索器
36
+ llm: LLM 實例
37
+ step_back_temperature: 生成抽象問題的溫度(較低,更穩定)
38
+ answer_temperature: 生成答案的溫度
39
+ enable_parallel: 是否並行執行雙軌檢索
40
+ """
41
+ self.rag_pipeline = rag_pipeline
42
+ self.vector_retriever = vector_retriever
43
+ self.llm = llm
44
+ self.step_back_temperature = step_back_temperature
45
+ self.answer_temperature = answer_temperature
46
+ self.enable_parallel = enable_parallel
47
+
48
+ def _generate_step_back_question(self, question: str) -> str:
49
+ """
50
+ 生成 Step-back 抽象問題
51
+
52
+ Args:
53
+ question: 原始具體問題
54
+
55
+ Returns:
56
+ 抽象問題
57
+ """
58
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
59
+
60
+ if is_chinese:
61
+ prompt = f"""你是一個資深專家。請將以下具體問題轉換為一個更抽象、更基礎的原理性問題。
62
+ 這個抽象問題應該幫助理解該領域的基礎概念和原理,而不是直接回答具體問題。
63
+
64
+ 具體問題: {question}
65
+
66
+ 請生成一個抽象問題,用於檢索相關的原理和背景知識:
67
+ """
68
+ else:
69
+ prompt = f"""You are a senior expert. Please convert the following specific question into a more abstract, fundamental question about principles and concepts.
70
+ This abstract question should help understand the basic concepts and principles in this field, rather than directly answering the specific question.
71
+
72
+ Specific question: {question}
73
+
74
+ Please generate an abstract question for retrieving relevant principles and background knowledge:
75
+ """
76
+
77
+ try:
78
+ abstract_question = self.llm.generate(
79
+ prompt=prompt,
80
+ temperature=self.step_back_temperature,
81
+ max_tokens=200
82
+ )
83
+
84
+ abstract_question = abstract_question.strip()
85
+
86
+ if not abstract_question:
87
+ logger.warning("⚠️ 生成的抽象問題為空,使用原始問題")
88
+ return question
89
+
90
+ logger.info(f"✅ 生成抽象問題: '{abstract_question}'")
91
+ return abstract_question
92
+
93
+ except Exception as e:
94
+ logger.error(f"⚠️ 生成抽象問題時出錯: {e}")
95
+ return question
96
+
97
+ def _get_doc_id(self, doc: Dict) -> str:
98
+ """
99
+ 生成文檔的唯一標識符
100
+
101
+ Args:
102
+ doc: 文檔字典
103
+
104
+ Returns:
105
+ 唯一 ID
106
+ """
107
+ metadata = doc.get("metadata", {})
108
+ content = doc.get("content", "")
109
+
110
+ if "arxiv_id" in metadata and "chunk_index" in metadata:
111
+ return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
112
+ elif "file_path" in metadata and "chunk_index" in metadata:
113
+ return f"{metadata['file_path']}_{metadata['chunk_index']}"
114
+ else:
115
+ content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
116
+ return f"doc_{content_hash}"
117
+
118
+ def _retrieve_direct(self, question: str, top_k: int, metadata_filter: Optional[Dict] = None) -> List[Dict]:
119
+ """直接檢索原始問題(具體事實)"""
120
+ return self.vector_retriever.retrieve(
121
+ query=question,
122
+ top_k=top_k,
123
+ metadata_filter=metadata_filter
124
+ )
125
+
126
+ def _retrieve_step_back(self, question: str, top_k: int, metadata_filter: Optional[Dict] = None) -> tuple:
127
+ """Step-back 檢索(抽象原理)"""
128
+ abstract_question = self._generate_step_back_question(question)
129
+ results = self.vector_retriever.retrieve(
130
+ query=abstract_question,
131
+ top_k=top_k,
132
+ metadata_filter=metadata_filter
133
+ )
134
+ return results, abstract_question
135
+
136
+ def query(
137
+ self,
138
+ question: str,
139
+ top_k: int = 5,
140
+ metadata_filter: Optional[Dict] = None,
141
+ return_abstract_question: bool = False
142
+ ) -> Dict:
143
+ """
144
+ 執行雙軌檢索(不生成答案)
145
+
146
+ Args:
147
+ question: 原始問題
148
+ top_k: 每軌返回的結果數量
149
+ metadata_filter: 可選的 metadata 過濾條件
150
+ return_abstract_question: 是否返回抽象問題
151
+
152
+ Returns:
153
+ 包含雙軌檢索結果的字典
154
+ """
155
+ start_time = time.time()
156
+
157
+ if self.enable_parallel:
158
+ # 並行執行雙軌檢索
159
+ logger.info(f"🔄 並行執行雙軌檢索: '{question}'")
160
+ with ThreadPoolExecutor(max_workers=2) as executor:
161
+ direct_future = executor.submit(
162
+ self._retrieve_direct, question, top_k, metadata_filter
163
+ )
164
+ step_back_future = executor.submit(
165
+ self._retrieve_step_back, question, top_k, metadata_filter
166
+ )
167
+
168
+ specific_results = direct_future.result()
169
+ abstract_results, abstract_question = step_back_future.result()
170
+ else:
171
+ # 串行執行
172
+ logger.info(f"🔄 串行執行雙軌檢索: '{question}'")
173
+ specific_results = self._retrieve_direct(question, top_k, metadata_filter)
174
+ abstract_results, abstract_question = self._retrieve_step_back(question, top_k, metadata_filter)
175
+
176
+ elapsed_time = time.time() - start_time
177
+ logger.info(
178
+ f"✅ 雙軌檢索完成(耗時: {elapsed_time:.2f}s)\n"
179
+ f" 具體事實: {len(specific_results)} 個結果\n"
180
+ f" 抽象原理: {len(abstract_results)} 個結果"
181
+ )
182
+
183
+ return {
184
+ "specific_context": specific_results,
185
+ "abstract_context": abstract_results,
186
+ "abstract_question": abstract_question if return_abstract_question else None,
187
+ "question": question,
188
+ "elapsed_time": elapsed_time
189
+ }
190
+
191
+ def generate_answer(
192
+ self,
193
+ question: str,
194
+ formatter: PromptFormatter,
195
+ top_k: int = 5,
196
+ metadata_filter: Optional[Dict] = None,
197
+ document_type: str = "general",
198
+ return_abstract_question: bool = False
199
+ ) -> Dict:
200
+ """
201
+ 完整的 Step-back RAG 流程:雙軌檢索 -> 生成答案
202
+
203
+ Args:
204
+ question: 原始問題
205
+ formatter: Prompt 格式化器
206
+ top_k: 每軌用於生成答案的文檔數量
207
+ metadata_filter: 可選的 metadata 過濾條件
208
+ document_type: 文檔類型 ("paper", "cv", "general")
209
+ return_abstract_question: 是否返回抽象問題
210
+
211
+ Returns:
212
+ 包含檢索結果、生成的答案和統計資訊的字典
213
+ """
214
+ start_time = time.time()
215
+
216
+ # 第一步:雙軌檢索
217
+ retrieval_result = self.query(
218
+ question=question,
219
+ top_k=top_k,
220
+ metadata_filter=metadata_filter,
221
+ return_abstract_question=return_abstract_question
222
+ )
223
+
224
+ specific_results = retrieval_result["specific_context"]
225
+ abstract_results = retrieval_result["abstract_context"]
226
+
227
+ if not specific_results and not abstract_results:
228
+ return {
229
+ **retrieval_result,
230
+ "answer": "抱歉,未找到相關文檔來回答此問題。",
231
+ "formatted_context": None,
232
+ "answer_time": 0.0,
233
+ "total_time": retrieval_result["elapsed_time"]
234
+ }
235
+
236
+ # 第二步:格式化雙軌上下文
237
+ specific_context = formatter.format_context(
238
+ specific_results,
239
+ document_type=document_type
240
+ ) if specific_results else "未找到相關的具體事實資料。"
241
+
242
+ abstract_context = formatter.format_context(
243
+ abstract_results,
244
+ document_type=document_type
245
+ ) if abstract_results else "未找到相關的基礎原理資料。"
246
+
247
+ # 第三步:創建融合提示詞(關鍵步驟)
248
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
249
+
250
+ if is_chinese:
251
+ final_prompt = f"""你是一個資深專家。請結合以下兩類資訊來回答使用者的具體問題。
252
+
253
+ 【基礎原理與背景】
254
+ {abstract_context}
255
+
256
+ 【具體事實資料】
257
+ {specific_context}
258
+
259
+ 使用者問題:{question}
260
+
261
+ 請根據原理推導並結合事實,給出一個專業且具備邏輯的回答:
262
+ """
263
+ else:
264
+ final_prompt = f"""You are a senior expert. Please answer the user's specific question by combining the following two types of information.
265
+
266
+ 【Fundamental Principles and Background】
267
+ {abstract_context}
268
+
269
+ 【Specific Facts and Data】
270
+ {specific_context}
271
+
272
+ User question: {question}
273
+
274
+ Please provide a professional and logical answer based on principles and facts:
275
+ """
276
+
277
+ # 第四步:生成回答
278
+ logger.info("🤖 生成回答中...")
279
+ answer_start = time.time()
280
+ try:
281
+ answer = self.llm.generate(
282
+ prompt=final_prompt,
283
+ temperature=self.answer_temperature,
284
+ max_tokens=2048
285
+ )
286
+ answer_time = time.time() - answer_start
287
+ logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
288
+ except Exception as e:
289
+ logger.error(f"❌ 生成回答時出錯: {e}")
290
+ answer = f"生成回答時出錯: {e}"
291
+ answer_time = time.time() - answer_start
292
+
293
+ total_time = time.time() - start_time
294
+
295
+ return {
296
+ **retrieval_result,
297
+ "answer": answer,
298
+ "formatted_context": {
299
+ "specific": specific_context,
300
+ "abstract": abstract_context
301
+ },
302
+ "answer_time": answer_time,
303
+ "total_time": total_time
304
+ }
305
+
src/subquery_rag.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sub-query Decomposition RAG:將複雜問題拆解成子問題後檢索
3
+ """
4
+ from typing import List, Dict, Optional
5
+ from .retrievers.reranker import RAGPipeline
6
+ from .prompt_formatter import PromptFormatter
7
+ from .llm_integration import OllamaLLM
8
+ import hashlib
9
+ import time
10
+ import logging
11
+ from concurrent.futures import ThreadPoolExecutor, as_completed
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class SubQueryDecompositionRAG:
17
+ """使用子問題拆解的 RAG 系統"""
18
+
19
+ def __init__(
20
+ self,
21
+ rag_pipeline: RAGPipeline,
22
+ llm: OllamaLLM,
23
+ max_sub_queries: int = 3,
24
+ top_k_per_subquery: int = 5,
25
+ enable_parallel: bool = True
26
+ ):
27
+ """
28
+ 初始化 Sub-query Decomposition RAG
29
+
30
+ Args:
31
+ rag_pipeline: 現有的 RAG 管線實例
32
+ llm: LLM 實例(用於生成子問題)
33
+ max_sub_queries: 最多生成的子問題數量
34
+ top_k_per_subquery: 每個子問題檢索的結果數量
35
+ enable_parallel: 是否並行處理子查詢
36
+ """
37
+ self.rag_pipeline = rag_pipeline
38
+ self.llm = llm
39
+ self.max_sub_queries = max_sub_queries
40
+ self.top_k_per_subquery = top_k_per_subquery
41
+ self.enable_parallel = enable_parallel
42
+
43
+ def _generate_sub_queries(self, question: str) -> List[str]:
44
+ """
45
+ 將原始問題拆解成子問題
46
+
47
+ Args:
48
+ question: 原始問題
49
+
50
+ Returns:
51
+ 子問題列表
52
+ """
53
+ # 檢測語言
54
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
55
+
56
+ if is_chinese:
57
+ prompt = f"""你是一個專業助理。請將以下原始問題拆解成最多 {self.max_sub_queries} 個具體的子問題,以便進行資料搜尋。
58
+ 每個子問題應專注於原始問題的一個特定面向。請以換行符號分隔問題。
59
+
60
+ 原始問題: {question}
61
+
62
+ 子問題清單:"""
63
+ else:
64
+ prompt = f"""You are a professional assistant. Please decompose the following original question into at most {self.max_sub_queries} specific sub-questions for information retrieval.
65
+ Each sub-question should focus on a specific aspect of the original question. Please separate questions with newlines.
66
+
67
+ Original question: {question}
68
+
69
+ Sub-question list:"""
70
+
71
+ try:
72
+ response = self.llm.generate(
73
+ prompt=prompt,
74
+ temperature=0.3, # 降低溫度以獲得更穩定的結果
75
+ max_tokens=500
76
+ )
77
+
78
+ # 解析子問題
79
+ sub_queries = [
80
+ q.strip()
81
+ for q in response.strip().split("\n")
82
+ if q.strip() and not q.strip().startswith("#")
83
+ ]
84
+
85
+ # 移除編號前綴(如 "1. ", "1) " 等)
86
+ cleaned_queries = []
87
+ for q in sub_queries:
88
+ # 移除開頭的編號
89
+ q = q.lstrip("0123456789. )")
90
+ q = q.strip()
91
+ if q:
92
+ cleaned_queries.append(q)
93
+
94
+ # 限制數量
95
+ cleaned_queries = cleaned_queries[:self.max_sub_queries]
96
+
97
+ # 如果沒有生成子問題,使用原始問題
98
+ if not cleaned_queries:
99
+ logger.warning("⚠️ 未生成子問題,使用原始問題")
100
+ cleaned_queries = [question]
101
+
102
+ return cleaned_queries
103
+
104
+ except Exception as e:
105
+ logger.error(f"⚠️ 生成子問題時出錯: {e}")
106
+ # 回退到原始問題
107
+ return [question]
108
+
109
+ def _get_doc_id(self, doc: Dict) -> str:
110
+ """
111
+ 生成文檔的唯一標識符
112
+
113
+ Args:
114
+ doc: 文檔字典
115
+
116
+ Returns:
117
+ 唯一 ID
118
+ """
119
+ metadata = doc.get("metadata", {})
120
+ content = doc.get("content", "")
121
+
122
+ # 使用 metadata 中的唯一標識(如果有的話)
123
+ if "arxiv_id" in metadata and "chunk_index" in metadata:
124
+ return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
125
+ elif "file_path" in metadata and "chunk_index" in metadata:
126
+ return f"{metadata['file_path']}_{metadata['chunk_index']}"
127
+ else:
128
+ # 回退到內容的 hash
129
+ content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
130
+ return f"doc_{content_hash}"
131
+
132
+ def _retrieve_for_subquery(
133
+ self,
134
+ sub_query: str,
135
+ metadata_filter: Optional[Dict] = None
136
+ ) -> List[Dict]:
137
+ """
138
+ 針對單個子問題進行檢索
139
+
140
+ Args:
141
+ sub_query: 子問題
142
+ metadata_filter: 可選的 metadata 過濾條件
143
+
144
+ Returns:
145
+ 檢索結果列表
146
+ """
147
+ try:
148
+ results = self.rag_pipeline.query(
149
+ text=sub_query,
150
+ top_k=self.top_k_per_subquery,
151
+ metadata_filter=metadata_filter,
152
+ enable_rerank=True
153
+ )
154
+ return results
155
+ except Exception as e:
156
+ logger.error(f"⚠️ 檢索子問題 '{sub_query}' 時出錯: {e}")
157
+ return []
158
+
159
+ def _get_unique_documents(
160
+ self,
161
+ sub_queries: List[str],
162
+ metadata_filter: Optional[Dict] = None
163
+ ) -> List[Dict]:
164
+ """
165
+ 針對所有子問題進行檢索,並移除重複的檔案
166
+
167
+ Args:
168
+ sub_queries: 子問題列表
169
+ metadata_filter: 可選的 metadata 過濾條件
170
+
171
+ Returns:
172
+ 去重後的文檔列表
173
+ """
174
+ unique_docs = {}
175
+
176
+ if self.enable_parallel and len(sub_queries) > 1:
177
+ # 並行處理子查詢
178
+ logger.info(f"🔄 並行處理 {len(sub_queries)} 個子查詢...")
179
+ with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor:
180
+ future_to_query = {
181
+ executor.submit(self._retrieve_for_subquery, q, metadata_filter): q
182
+ for q in sub_queries
183
+ }
184
+
185
+ for future in as_completed(future_to_query):
186
+ sub_query = future_to_query[future]
187
+ try:
188
+ docs = future.result()
189
+ logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(docs)} 個結果")
190
+ for doc in docs:
191
+ doc_id = self._get_doc_id(doc)
192
+ if doc_id not in unique_docs:
193
+ unique_docs[doc_id] = doc
194
+ else:
195
+ # 如果已存在,保留分數更高的
196
+ existing_score = unique_docs[doc_id].get(
197
+ 'rerank_score',
198
+ unique_docs[doc_id].get('hybrid_score', 0)
199
+ )
200
+ new_score = doc.get(
201
+ 'rerank_score',
202
+ doc.get('hybrid_score', 0)
203
+ )
204
+ if new_score > existing_score:
205
+ unique_docs[doc_id] = doc
206
+ except Exception as e:
207
+ logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
208
+ else:
209
+ # 串行處理
210
+ logger.info(f"🔄 串行處理 {len(sub_queries)} 個子查詢...")
211
+ for sub_query in sub_queries:
212
+ docs = self._retrieve_for_subquery(sub_query, metadata_filter)
213
+ logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(docs)} 個結果")
214
+ for doc in docs:
215
+ doc_id = self._get_doc_id(doc)
216
+ if doc_id not in unique_docs:
217
+ unique_docs[doc_id] = doc
218
+ else:
219
+ # 保留分數更高的
220
+ existing_score = unique_docs[doc_id].get(
221
+ 'rerank_score',
222
+ unique_docs[doc_id].get('hybrid_score', 0)
223
+ )
224
+ new_score = doc.get(
225
+ 'rerank_score',
226
+ doc.get('hybrid_score', 0)
227
+ )
228
+ if new_score > existing_score:
229
+ unique_docs[doc_id] = doc
230
+
231
+ # 按分數排序
232
+ result_list = list(unique_docs.values())
233
+ result_list.sort(
234
+ key=lambda x: x.get('rerank_score', x.get('hybrid_score', 0)),
235
+ reverse=True
236
+ )
237
+
238
+ return result_list
239
+
240
+ def query(
241
+ self,
242
+ question: str,
243
+ top_k: int = 5,
244
+ metadata_filter: Optional[Dict] = None,
245
+ return_sub_queries: bool = False
246
+ ) -> Dict:
247
+ """
248
+ 執行 Sub-query Decomposition RAG 查詢
249
+
250
+ Args:
251
+ question: 原始問題
252
+ top_k: 返回前 k 個結果
253
+ metadata_filter: 可選的 metadata 過濾條件
254
+ return_sub_queries: 是否在結果中包含子問題列表
255
+
256
+ Returns:
257
+ 包含檢索結果和統計資訊的字典
258
+ """
259
+ start_time = time.time()
260
+
261
+ # 第一步:產生子問題
262
+ logger.info(f"🔍 拆解問題: '{question}'")
263
+ sub_queries = self._generate_sub_queries(question)
264
+ logger.info(f"✅ 生成 {len(sub_queries)} 個子問題:")
265
+ for i, sq in enumerate(sub_queries, 1):
266
+ logger.info(f" {i}. {sq}")
267
+
268
+ # 第二步:檢索並去重
269
+ logger.info(f"📚 檢索相關文檔...")
270
+ docs = self._get_unique_documents(sub_queries, metadata_filter)
271
+ logger.info(f"✅ 找到 {len(docs)} 個唯一文檔(去重後)")
272
+
273
+ # 第三步:返回前 top_k 個結果
274
+ final_results = docs[:top_k]
275
+
276
+ elapsed_time = time.time() - start_time
277
+
278
+ result = {
279
+ "results": final_results,
280
+ "total_docs_found": len(docs),
281
+ "sub_queries": sub_queries if return_sub_queries else None,
282
+ "elapsed_time": elapsed_time
283
+ }
284
+
285
+ return result
286
+
287
+ def generate_answer(
288
+ self,
289
+ question: str,
290
+ formatter: PromptFormatter,
291
+ top_k: int = 5,
292
+ metadata_filter: Optional[Dict] = None,
293
+ document_type: str = "general",
294
+ return_sub_queries: bool = False
295
+ ) -> Dict:
296
+ """
297
+ 完整的 Sub-query Decomposition RAG 流程:檢索 + 生成答案
298
+
299
+ Args:
300
+ question: 原始問題
301
+ formatter: Prompt 格式化器
302
+ top_k: 返回前 k 個結果用於生成答案
303
+ metadata_filter: 可選的 metadata 過濾條件
304
+ document_type: 文檔類型 ("paper", "cv", "general")
305
+ return_sub_queries: 是否在結果中包含子問題列表
306
+
307
+ Returns:
308
+ 包含檢索結果、生成的答案和統計資訊的字典
309
+ """
310
+ # 檢索
311
+ retrieval_result = self.query(
312
+ question=question,
313
+ top_k=top_k,
314
+ metadata_filter=metadata_filter,
315
+ return_sub_queries=return_sub_queries
316
+ )
317
+
318
+ if not retrieval_result["results"]:
319
+ return {
320
+ **retrieval_result,
321
+ "answer": "抱歉,未找到相關文檔來回答此問題。",
322
+ "formatted_context": None
323
+ }
324
+
325
+ # 格式化上下文
326
+ formatted_context = formatter.format_context(
327
+ retrieval_result["results"],
328
+ document_type=document_type
329
+ )
330
+
331
+ # 創建 prompt
332
+ prompt = formatter.create_prompt(
333
+ question,
334
+ formatted_context,
335
+ document_type=document_type
336
+ )
337
+
338
+ # 生成回答
339
+ logger.info("🤖 生成回答中...")
340
+ answer_start = time.time()
341
+ try:
342
+ answer = self.llm.generate(
343
+ prompt=prompt,
344
+ temperature=0.7,
345
+ max_tokens=2048
346
+ )
347
+ answer_time = time.time() - answer_start
348
+ logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
349
+ except Exception as e:
350
+ logger.error(f"❌ 生成回答時出錯: {e}")
351
+ answer = f"生成回答時出錯: {e}"
352
+ answer_time = time.time() - answer_start
353
+
354
+ return {
355
+ **retrieval_result,
356
+ "answer": answer,
357
+ "formatted_context": formatted_context,
358
+ "answer_time": answer_time,
359
+ "total_time": retrieval_result["elapsed_time"] + answer_time
360
+ }
361
+
src/triple_hybrid_rag.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Triple Hybrid RAG:融合 SubQuery + HyDE + Step-back Prompting
3
+ 結合三種技術的優勢,實現最強大的 RAG 系統
4
+ """
5
+ from typing import List, Dict, Optional
6
+ from .retrievers.reranker import RAGPipeline
7
+ from .retrievers.vector_retriever import VectorRetriever
8
+ from .prompt_formatter import PromptFormatter
9
+ from .llm_integration import OllamaLLM
10
+ import hashlib
11
+ import time
12
+ import logging
13
+ from concurrent.futures import ThreadPoolExecutor, as_completed
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class TripleHybridRAG:
19
+ """融合 SubQuery + HyDE + Step-back 的三重混合 RAG 系統"""
20
+
21
+ def __init__(
22
+ self,
23
+ rag_pipeline: RAGPipeline,
24
+ vector_retriever: VectorRetriever,
25
+ llm: OllamaLLM,
26
+ max_sub_queries: int = 3,
27
+ top_k_per_subquery: int = 5,
28
+ hypothetical_length: int = 200,
29
+ temperature_subquery: float = 0.3,
30
+ temperature_hyde: float = 0.7,
31
+ temperature_stepback: float = 0.3,
32
+ answer_temperature: float = 0.7,
33
+ enable_parallel: bool = True
34
+ ):
35
+ """
36
+ 初始化三重混合 RAG
37
+
38
+ Args:
39
+ rag_pipeline: RAG 管線實例
40
+ vector_retriever: 向量檢索器
41
+ llm: LLM 實例
42
+ max_sub_queries: 最多生成的子問題數量
43
+ top_k_per_subquery: 每個子問題檢索的結果數量
44
+ hypothetical_length: 假設性文檔目標長度(字符數)
45
+ temperature_subquery: 生成子問題的溫度(較低,更穩定)
46
+ temperature_hyde: 生成假設性文檔的溫度(較高,更多專業術語)
47
+ temperature_stepback: 生成抽象問題的溫度(較低,更穩定)
48
+ answer_temperature: 生成答案的溫度
49
+ enable_parallel: 是否並行處理
50
+ """
51
+ self.rag_pipeline = rag_pipeline
52
+ self.vector_retriever = vector_retriever
53
+ self.llm = llm
54
+ self.max_sub_queries = max_sub_queries
55
+ self.top_k_per_subquery = top_k_per_subquery
56
+ self.hypothetical_length = hypothetical_length
57
+ self.temperature_subquery = temperature_subquery
58
+ self.temperature_hyde = temperature_hyde
59
+ self.temperature_stepback = temperature_stepback
60
+ self.answer_temperature = answer_temperature
61
+ self.enable_parallel = enable_parallel
62
+
63
+ def _generate_sub_queries(self, question: str) -> List[str]:
64
+ """生成子問題(SubQuery)"""
65
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
66
+
67
+ if is_chinese:
68
+ prompt = f"""你是一個專業助理。請將以下原始問題拆解成最多 {self.max_sub_queries} 個具體的子問題,以便進行資料搜尋。
69
+ 每個子問題應專注於原始問題的一個特定面向。請以換行符號分隔問題。
70
+
71
+ 原始問題: {question}
72
+
73
+ 子問題清單:"""
74
+ else:
75
+ prompt = f"""You are a professional assistant. Please decompose the following original question into at most {self.max_sub_queries} specific sub-questions for information retrieval.
76
+ Each sub-question should focus on a specific aspect of the original question. Please separate questions with newlines.
77
+
78
+ Original question: {question}
79
+
80
+ Sub-question list:"""
81
+
82
+ try:
83
+ response = self.llm.generate(
84
+ prompt=prompt,
85
+ temperature=self.temperature_subquery,
86
+ max_tokens=500
87
+ )
88
+
89
+ sub_queries = [
90
+ q.strip()
91
+ for q in response.strip().split("\n")
92
+ if q.strip() and not q.strip().startswith("#")
93
+ ]
94
+
95
+ # 移除編號前綴
96
+ cleaned_queries = []
97
+ for q in sub_queries:
98
+ q = q.lstrip("0123456789. )")
99
+ q = q.strip()
100
+ if q:
101
+ cleaned_queries.append(q)
102
+
103
+ cleaned_queries = cleaned_queries[:self.max_sub_queries]
104
+
105
+ if not cleaned_queries:
106
+ logger.warning("⚠️ 未生成子問題,使用原始問題")
107
+ cleaned_queries = [question]
108
+
109
+ return cleaned_queries
110
+
111
+ except Exception as e:
112
+ logger.error(f"⚠️ 生成子問題時出錯: {e}")
113
+ return [question]
114
+
115
+ def _generate_hypothetical_document(self, sub_query: str) -> str:
116
+ """為子問題生成假設性文檔(HyDE)"""
117
+ is_chinese = PromptFormatter.detect_language(sub_query) == "zh"
118
+
119
+ if is_chinese:
120
+ prompt = f"""請針對以下問題,寫出一段約 {self.hypothetical_length} 字的專業技術檔案內容。
121
+ 這段內容應包含該領域常見的專業術語與原理說明,以便用於後續的語義檢索。
122
+ 請使用專業的術語和概念,即使你對某些細節不確定,也要包含相關的專業詞彙。
123
+
124
+ 問題: {sub_query}
125
+
126
+ 專業��術內容:"""
127
+ else:
128
+ prompt = f"""Please write a professional technical document of approximately {self.hypothetical_length} words in response to the following question.
129
+ This content should include common professional terminology and principle explanations in this field, to be used for subsequent semantic retrieval.
130
+ Please use professional terms and concepts, and include relevant professional vocabulary even if you are uncertain about some details.
131
+
132
+ Question: {sub_query}
133
+
134
+ Professional technical content:"""
135
+
136
+ try:
137
+ hypothetical_doc = self.llm.generate(
138
+ prompt=prompt,
139
+ temperature=self.temperature_hyde,
140
+ max_tokens=500
141
+ )
142
+
143
+ hypothetical_doc = hypothetical_doc.strip()
144
+
145
+ if not hypothetical_doc:
146
+ logger.warning(f"⚠️ 子問題 '{sub_query}' 的假設性文檔為空,使用子問題本身")
147
+ return sub_query
148
+
149
+ return hypothetical_doc
150
+
151
+ except Exception as e:
152
+ logger.error(f"⚠️ 生成假設性文檔時出錯: {e}")
153
+ return sub_query
154
+
155
+ def _generate_step_back_question(self, question: str) -> str:
156
+ """生成 Step-back 抽象問題"""
157
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
158
+
159
+ if is_chinese:
160
+ prompt = f"""你是一個資深專家。請將以下具體問題轉換為一個更抽象、更基礎的原理性問題。
161
+ 這個抽象問題應該幫助理解該領域的基礎概念和原理,而不是直接回答具體問題。
162
+
163
+ 具體問題: {question}
164
+
165
+ 請生成一個抽象問題,用於檢索相關的原理和背景知識:
166
+ """
167
+ else:
168
+ prompt = f"""You are a senior expert. Please convert the following specific question into a more abstract, fundamental question about principles and concepts.
169
+ This abstract question should help understand the basic concepts and principles in this field, rather than directly answering the specific question.
170
+
171
+ Specific question: {question}
172
+
173
+ Please generate an abstract question for retrieving relevant principles and background knowledge:
174
+ """
175
+
176
+ try:
177
+ abstract_question = self.llm.generate(
178
+ prompt=prompt,
179
+ temperature=self.temperature_stepback,
180
+ max_tokens=200
181
+ )
182
+
183
+ abstract_question = abstract_question.strip()
184
+
185
+ if not abstract_question:
186
+ logger.warning("⚠️ 生成的抽象問題為空,使用原始問題")
187
+ return question
188
+
189
+ return abstract_question
190
+
191
+ except Exception as e:
192
+ logger.error(f"⚠️ 生成抽象問題時出錯: {e}")
193
+ return question
194
+
195
+ def _get_doc_id(self, doc: Dict) -> str:
196
+ """生成文檔的唯一標識符"""
197
+ metadata = doc.get("metadata", {})
198
+ content = doc.get("content", "")
199
+
200
+ if "arxiv_id" in metadata and "chunk_index" in metadata:
201
+ return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
202
+ elif "file_path" in metadata and "chunk_index" in metadata:
203
+ return f"{metadata['file_path']}_{metadata['chunk_index']}"
204
+ else:
205
+ content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
206
+ return f"doc_{content_hash}"
207
+
208
+ def _process_subquery_with_hyde(
209
+ self,
210
+ sub_query: str,
211
+ metadata_filter: Optional[Dict] = None
212
+ ) -> tuple:
213
+ """處理單個子問題:生成假設性文檔並檢索"""
214
+ try:
215
+ hypothetical_doc = self._generate_hypothetical_document(sub_query)
216
+ results = self.vector_retriever.retrieve(
217
+ query=hypothetical_doc,
218
+ top_k=self.top_k_per_subquery,
219
+ metadata_filter=metadata_filter
220
+ )
221
+ return results, hypothetical_doc
222
+ except Exception as e:
223
+ logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
224
+ return [], ""
225
+
226
+ def query(
227
+ self,
228
+ question: str,
229
+ top_k: int = 5,
230
+ metadata_filter: Optional[Dict] = None,
231
+ return_sub_queries: bool = False,
232
+ return_hypothetical: bool = False,
233
+ return_abstract_question: bool = False
234
+ ) -> Dict:
235
+ """
236
+ 執行三重混合 RAG 檢索
237
+
238
+ 流程:
239
+ 1. 拆解成子問題(SubQuery)
240
+ 2. 對每個子問題生成假設性文檔並檢索(HyDE)
241
+ 3. 直接檢索原始問題(具體事實)
242
+ 4. 生成抽象問題並檢索(Step-back,抽象原理)
243
+ 5. 合併所有結果並去重
244
+ """
245
+ start_time = time.time()
246
+
247
+ # 第一步:生成子問題
248
+ logger.info(f"🔍 [SubQuery] 拆解問題: '{question}'")
249
+ sub_queries = self._generate_sub_queries(question)
250
+ logger.info(f"✅ 生成 {len(sub_queries)} 個子問題")
251
+
252
+ # 第二步:為每個子問題生成假設性文檔並檢索(HyDE)
253
+ logger.info(f"📚 [HyDE] 為每個子問題生成假設性文檔並檢索...")
254
+ subquery_results = []
255
+ hypothetical_docs = {}
256
+
257
+ if self.enable_parallel and len(sub_queries) > 1:
258
+ with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor:
259
+ future_to_query = {
260
+ executor.submit(self._process_subquery_with_hyde, sq, metadata_filter): sq
261
+ for sq in sub_queries
262
+ }
263
+
264
+ for future in as_completed(future_to_query):
265
+ sub_query = future_to_query[future]
266
+ try:
267
+ results, hypo_doc = future.result()
268
+ hypothetical_docs[sub_query] = hypo_doc
269
+ subquery_results.extend(results)
270
+ except Exception as e:
271
+ logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
272
+ else:
273
+ for sub_query in sub_queries:
274
+ results, hypo_doc = self._process_subquery_with_hyde(sub_query, metadata_filter)
275
+ hypothetical_docs[sub_query] = hypo_doc
276
+ subquery_results.extend(results)
277
+
278
+ # 第三步:Step-back 雙軌檢索
279
+ logger.info(f"🔍 [Step-back] 執行雙軌檢索...")
280
+
281
+ if self.enable_parallel:
282
+ with ThreadPoolExecutor(max_workers=2) as executor:
283
+ direct_future = executor.submit(
284
+ self.vector_retriever.retrieve,
285
+ question, top_k, metadata_filter
286
+ )
287
+ abstract_question = self._generate_step_back_question(question)
288
+ step_back_future = executor.submit(
289
+ self.vector_retriever.retrieve,
290
+ abstract_question, top_k, metadata_filter
291
+ )
292
+
293
+ specific_results = direct_future.result()
294
+ abstract_results = step_back_future.result()
295
+ else:
296
+ specific_results = self.vector_retriever.retrieve(
297
+ query=question,
298
+ top_k=top_k,
299
+ metadata_filter=metadata_filter
300
+ )
301
+ abstract_question = self._generate_step_back_question(question)
302
+ abstract_results = self.vector_retriever.retrieve(
303
+ query=abstract_question,
304
+ top_k=top_k,
305
+ metadata_filter=metadata_filter
306
+ )
307
+
308
+ # 第四步:合併所有結果並去重
309
+ logger.info(f"🔄 合併並去重所有檢索結果...")
310
+ all_results = subquery_results + specific_results + abstract_results
311
+ unique_docs = {}
312
+
313
+ for doc in all_results:
314
+ doc_id = self._get_doc_id(doc)
315
+ if doc_id not in unique_docs:
316
+ unique_docs[doc_id] = doc
317
+ else:
318
+ # 保留分數更高的
319
+ existing_score = unique_docs[doc_id].get('score', 0)
320
+ new_score = doc.get('score', 0)
321
+ if new_score > existing_score:
322
+ unique_docs[doc_id] = doc
323
+
324
+ # 排序並返回前 top_k
325
+ result_list = list(unique_docs.values())
326
+ result_list.sort(key=lambda x: x.get('score', 0), reverse=True)
327
+ final_results = result_list[:top_k]
328
+
329
+ elapsed_time = time.time() - start_time
330
+ logger.info(
331
+ f"✅ 三重混合檢索完成(耗時: {elapsed_time:.2f}s)\n"
332
+ f" 子問題檢索: {len(subquery_results)} 個結果\n"
333
+ f" 具體事實: {len(specific_results)} 個結果\n"
334
+ f" 抽象原理: {len(abstract_results)} 個結果\n"
335
+ f" 去重後總計: {len(result_list)} 個,返回前 {len(final_results)} 個"
336
+ )
337
+
338
+ return {
339
+ "results": final_results,
340
+ "total_docs_found": len(result_list),
341
+ "sub_queries": sub_queries if return_sub_queries else None,
342
+ "hypothetical_documents": hypothetical_docs if return_hypothetical else None,
343
+ "abstract_question": abstract_question if return_abstract_question else None,
344
+ "subquery_results": subquery_results,
345
+ "specific_context": specific_results,
346
+ "abstract_context": abstract_results,
347
+ "question": question,
348
+ "elapsed_time": elapsed_time
349
+ }
350
+
351
+ def generate_answer(
352
+ self,
353
+ question: str,
354
+ formatter: PromptFormatter,
355
+ top_k: int = 5,
356
+ metadata_filter: Optional[Dict] = None,
357
+ document_type: str = "general",
358
+ return_sub_queries: bool = False,
359
+ return_hypothetical: bool = False,
360
+ return_abstract_question: bool = False
361
+ ) -> Dict:
362
+ """
363
+ 完整的三重混合 RAG 流程:檢索 + 生成答案
364
+ """
365
+ start_time = time.time()
366
+
367
+ # 檢索
368
+ retrieval_result = self.query(
369
+ question=question,
370
+ top_k=top_k,
371
+ metadata_filter=metadata_filter,
372
+ return_sub_queries=return_sub_queries,
373
+ return_hypothetical=return_hypothetical,
374
+ return_abstract_question=return_abstract_question
375
+ )
376
+
377
+ if not retrieval_result["results"]:
378
+ return {
379
+ **retrieval_result,
380
+ "answer": "抱歉,未找到相關文檔來回答此問題。",
381
+ "formatted_context": None,
382
+ "answer_time": 0.0,
383
+ "total_time": retrieval_result["elapsed_time"]
384
+ }
385
+
386
+ # 格式化三類上下文
387
+ subquery_context = formatter.format_context(
388
+ retrieval_result["subquery_results"][:top_k],
389
+ document_type=document_type
390
+ ) if retrieval_result.get("subquery_results") else "未找到相關的子問題檢索結果。"
391
+
392
+ specific_context = formatter.format_context(
393
+ retrieval_result["specific_context"],
394
+ document_type=document_type
395
+ ) if retrieval_result.get("specific_context") else "未找到相關的具體事實資料。"
396
+
397
+ abstract_context = formatter.format_context(
398
+ retrieval_result["abstract_context"],
399
+ document_type=document_type
400
+ ) if retrieval_result.get("abstract_context") else "未找到相關的基礎原理資料。"
401
+
402
+ # 創建融合提示詞(關鍵步驟)
403
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
404
+
405
+ if is_chinese:
406
+ final_prompt = f"""你是一個資深專家。請結合以下三類資訊來回答使用者的具體問題。
407
+
408
+ 【基礎原理與背景】(來自 Step-back 抽象問題檢索)
409
+ {abstract_context}
410
+
411
+ 【具體事實資料】(來自直接問題檢索)
412
+ {specific_context}
413
+
414
+ 【子問題相關資料】(來自 SubQuery + HyDE 檢索)
415
+ {subquery_context}
416
+
417
+ 使用者問題:{question}
418
+
419
+ 請根據原理推導、結合具體事實,並參考子問題的相關資料,給出一個專業、全面且具備邏輯的回答:
420
+ """
421
+ else:
422
+ final_prompt = f"""You are a senior expert. Please answer the user's specific question by combining the following three types of information.
423
+
424
+ 【Fundamental Principles and Background】(from Step-back abstract question retrieval)
425
+ {abstract_context}
426
+
427
+ 【Specific Facts and Data】(from direct question retrieval)
428
+ {specific_context}
429
+
430
+ 【Sub-question Related Information】(from SubQuery + HyDE retrieval)
431
+ {subquery_context}
432
+
433
+ User question: {question}
434
+
435
+ Please provide a professional, comprehensive, and logical answer based on principles, facts, and sub-question related information:
436
+ """
437
+
438
+ # 生成回答
439
+ logger.info("🤖 生成回答中...")
440
+ answer_start = time.time()
441
+ try:
442
+ answer = self.llm.generate(
443
+ prompt=final_prompt,
444
+ temperature=self.answer_temperature,
445
+ max_tokens=2048
446
+ )
447
+ answer_time = time.time() - answer_start
448
+ logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
449
+ except Exception as e:
450
+ logger.error(f"❌ 生成回答時出錯: {e}")
451
+ answer = f"生成回答時出錯: {e}"
452
+ answer_time = time.time() - answer_start
453
+
454
+ total_time = time.time() - start_time
455
+
456
+ return {
457
+ **retrieval_result,
458
+ "answer": answer,
459
+ "formatted_context": {
460
+ "subquery": subquery_context,
461
+ "specific": specific_context,
462
+ "abstract": abstract_context
463
+ },
464
+ "answer_time": answer_time,
465
+ "total_time": total_time
466
+ }
467
+